asr_websocket_service.py
11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
ASR WebSocket服务实现
从app.py中抽离的ASR相关WebSocket功能
"""
import asyncio
import json
import weakref
from typing import Dict, Any, Optional
from aiohttp import web
from logger import logger
from .websocket_service_base import WebSocketServiceBase
from .unified_websocket_manager import WebSocketSession
class ASRWebSocketService(WebSocketServiceBase):
    """WebSocket service that relays client audio to per-session ASR connections.

    The service keeps one ASR connection per session id in ``asr_connections``.
    Connections are created by ``start_asr_recognition`` messages, fed by
    ``asr_audio_data`` messages and closed on ``stop_asr_recognition``,
    session disconnect, or service cleanup.

    NOTE(review): ``self.manager``, ``add_background_task`` and
    ``broadcast_to_session`` are not defined in this class; they are presumably
    provided by ``WebSocketServiceBase`` -- confirm against the base class.
    """

    # Seconds between two heartbeat sweeps.
    HEARTBEAT_CHECK_INTERVAL = 40
    # Timeout (seconds) handed to the manager when asking for expired sessions.
    HEARTBEAT_TIMEOUT = 60
    # Seconds to wait after an unexpected sweep error before looping again.
    HEARTBEAT_ERROR_BACKOFF = 5

    def __init__(self):
        """Initialise the service with an empty ASR connection table."""
        super().__init__("asr_service")
        # session id -> ASR connection object
        self.asr_connections: Dict[str, Any] = {}
        self._heartbeat_task = None

    async def _register_message_handlers(self):
        """Register the handlers for every ASR-related message type."""
        handlers = (
            ('login', self._handle_login),
            ('heartbeat', self._handle_heartbeat),
            ('asr_audio_data', self._handle_asr_audio_data),
            ('start_asr_recognition', self._handle_start_asr_recognition),
            ('stop_asr_recognition', self._handle_stop_asr_recognition),
        )
        for message_type, handler in handlers:
            self.manager.register_message_handler(message_type, handler)

    async def _start_background_tasks(self):
        """Start the heartbeat monitoring task."""
        self._heartbeat_task = self.add_background_task(self._heartbeat_monitor())

    async def _cleanup(self):
        """Close every ASR connection and leave the connection table empty."""
        # Pop entries one at a time instead of snapshot-then-clear: a connection
        # registered while we are awaiting a close() is then closed as well,
        # rather than being silently dropped by a final clear().
        while self.asr_connections:
            session_id, asr_conn = self.asr_connections.popitem()
            await self._close_asr_connection(session_id, asr_conn, '关闭ASR连接失败')

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Release the ASR connection that belongs to a disconnected session.

        Args:
            session: The session that was disconnected.
        """
        await super()._on_session_disconnected(session)

        session_id = session.session_id
        if session_id not in self.asr_connections:
            return
        asr_conn = self.asr_connections.pop(session_id)
        if await self._close_asr_connection(session_id, asr_conn, '清理ASR连接失败'):
            logger.info(f'已清理ASR连接: {session_id}')

    async def _close_asr_connection(self, session_id: str, asr_conn: Any,
                                    error_prefix: str) -> bool:
        """Close one ASR connection, logging instead of propagating failures.

        Args:
            session_id: Session id used in the log message.
            asr_conn: Connection object; closed only if it exposes ``close``.
            error_prefix: Text placed in front of the session id when logging
                a failure.

        Returns:
            True when no exception was raised, False otherwise.
        """
        try:
            if hasattr(asr_conn, 'close'):
                await asr_conn.close()
        except Exception as e:
            logger.error(f'{error_prefix} {session_id}: {e}')
            return False
        return True

    @staticmethod
    async def _send_error(session: WebSocketSession, message: str) -> None:
        """Send an ``error`` message carrying ``message`` to ``session``."""
        await session.send_message({
            "type": "error",
            "data": {"message": message}
        })

    async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Handle a ``login`` message by re-keying the session to the client id.

        Args:
            websocket: The socket the message arrived on.
            data: Message payload; ``sessionid`` is the id requested by the
                client.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = data.get('sessionid')
        if not session_id:
            await session.send_message({
                "type": "login_response",
                "data": {
                    "status": "error",
                    "message": "缺少sessionid"
                }
            })
            return

        # Switch the session over to the id supplied by the client and keep
        # the manager's mapping in sync.
        old_session_id = session.session_id
        session.session_id = session_id
        self.manager._update_session_id(websocket, old_session_id, session_id)

        # An ASR connection opened before login is still filed under the old
        # id; no later lookup (audio, stop, disconnect) would ever find it, so
        # close it now instead of leaking it.
        if old_session_id != session_id:
            stale_conn = self.asr_connections.pop(old_session_id, None)
            if stale_conn is not None:
                await self._close_asr_connection(
                    old_session_id, stale_conn, '清理ASR连接失败')

        await session.send_message({
            "type": "login_response",
            "data": {
                "status": "success",
                "sessionid": session_id,
                "message": "登录成功"
            }
        })
        logger.info(f'用户登录成功: {session_id}')

    async def _handle_heartbeat(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Handle a ``heartbeat`` message: refresh the session and acknowledge.

        Args:
            websocket: The socket the message arrived on.
            data: Message payload (unused).
        """
        session = self.manager.get_session(websocket)
        if not session:
            return
        session.update_last_heartbeat()
        await session.send_message({
            "type": "heartbeat_response",
            "data": {"status": "ok"}
        })

    async def _handle_asr_audio_data(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Forward an audio chunk to the ASR connection of the sending session.

        An ``error`` message is sent back when the payload has no
        ``audio_data``, when no ASR connection exists for the session, or when
        forwarding raises.

        Args:
            websocket: The socket the message arrived on.
            data: Message payload; ``audio_data`` holds the audio chunk.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id
        audio_data = data.get('audio_data')
        if not audio_data:
            await self._send_error(session, "缺少音频数据")
            return

        # Connections are only created by start_asr_recognition; audio that
        # arrives without one is rejected rather than implicitly opening one.
        asr_conn = self.asr_connections.get(session_id)
        if not asr_conn:
            logger.warning(f'ASR连接不存在: {session_id}')
            await self._send_error(session, "ASR连接未建立")
            return

        try:
            await self._forward_audio_to_asr(asr_conn, audio_data)
        except Exception as e:
            logger.error(f'转发音频数据失败 {session_id}: {e}')
            await self._send_error(session, f"音频处理失败: {str(e)}")

    async def _handle_start_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Open an ASR connection for the sending session.

        If the session already has a connection, the new one replaces it and
        the previous one is closed so that it is not leaked.

        Args:
            websocket: The socket the message arrived on.
            data: Message payload (unused).
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id
        try:
            asr_conn = await self._create_asr_connection(session_id)
            if not asr_conn:
                await self._send_error(session, "创建ASR连接失败")
                return

            # Install the new connection first so incoming audio is routed to
            # it, then release the one it replaced (if any).
            previous_conn = self.asr_connections.get(session_id)
            self.asr_connections[session_id] = asr_conn
            if previous_conn is not None and previous_conn is not asr_conn:
                await self._close_asr_connection(
                    session_id, previous_conn, '关闭ASR连接失败')

            await session.send_message({
                "type": "asr_recognition_started",
                "data": {
                    "status": "success",
                    "message": "ASR识别已开始"
                }
            })
            logger.info(f'ASR识别已开始: {session_id}')
        except Exception as e:
            logger.error(f'开始ASR识别失败 {session_id}: {e}')
            await self._send_error(session, f"开始识别失败: {str(e)}")

    async def _handle_stop_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Close the ASR connection of the sending session, if there is one.

        A ``asr_recognition_stopped`` success message is sent both when a
        connection was closed and when none was running.

        Args:
            websocket: The socket the message arrived on.
            data: Message payload (unused).
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id
        if session_id not in self.asr_connections:
            await session.send_message({
                "type": "asr_recognition_stopped",
                "data": {
                    "status": "success",
                    "message": "ASR识别未在运行"
                }
            })
            return

        # Remove the entry before closing so the connection is never reused,
        # even if close() fails.
        asr_conn = self.asr_connections.pop(session_id)
        try:
            if hasattr(asr_conn, 'close'):
                await asr_conn.close()
            await session.send_message({
                "type": "asr_recognition_stopped",
                "data": {
                    "status": "success",
                    "message": "ASR识别已停止"
                }
            })
            logger.info(f'ASR识别已停止: {session_id}')
        except Exception as e:
            logger.error(f'停止ASR识别失败 {session_id}: {e}')
            await self._send_error(session, f"停止识别失败: {str(e)}")

    async def _create_asr_connection(self, session_id: str):
        """Create the ASR connection for ``session_id``.

        Currently returns a ``MockASRConnection`` whose results are delivered
        to ``_on_asr_result``.

        Args:
            session_id: Id of the session the connection belongs to.

        Returns:
            The connection object, or None if creating it raised.
        """
        # TODO: replace the mock with a real ASR backend (e.g. FunASR), for
        # example: asr_conn = await create_funasr_connection(session_id, self._on_asr_result)
        logger.info(f'创建ASR连接: {session_id}')
        try:
            return MockASRConnection(session_id, self._on_asr_result)
        except Exception as e:
            logger.error(f'创建ASR连接失败 {session_id}: {e}')
            return None

    async def _forward_audio_to_asr(self, asr_conn, audio_data):
        """Pass ``audio_data`` to ``asr_conn.send_audio`` when it is available.

        Connections without a ``send_audio`` attribute only produce a warning.
        """
        if not hasattr(asr_conn, 'send_audio'):
            logger.warning('ASR连接不支持发送音频数据')
            return
        await asr_conn.send_audio(audio_data)

    async def _on_asr_result(self, session_id: str, result: Dict[str, Any]):
        """Callback for ASR results: broadcast ``result`` as ``asr_result``.

        Failures are logged and not propagated to the caller.

        Args:
            session_id: Session the result belongs to.
            result: Recognition result payload.
        """
        try:
            await self.broadcast_to_session(session_id, 'asr_result', result)
            logger.debug(f'ASR结果已发送: {session_id}')
        except Exception as e:
            logger.error(f'发送ASR结果失败 {session_id}: {e}')

    async def _heartbeat_monitor(self):
        """Periodically close sessions that the manager reports as expired.

        Runs until cancelled. A failure while closing one session is logged
        and does not stop the remaining expired sessions from being closed.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_CHECK_INTERVAL)

                expired_sessions = self.manager.get_expired_sessions(
                    timeout=self.HEARTBEAT_TIMEOUT)
                for session in expired_sessions:
                    logger.info(f'会话心跳超时,断开连接: {session.session_id}')
                    try:
                        await session.close()
                    except asyncio.CancelledError:
                        # Let cancellation reach the outer handler.
                        raise
                    except Exception as e:
                        logger.error(f'心跳监控异常: {e}')
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f'心跳监控异常: {e}')
                await asyncio.sleep(self.HEARTBEAT_ERROR_BACKOFF)

    def get_asr_stats(self) -> Dict[str, Any]:
        """Return the number of live ASR connections and their session ids."""
        return {
            "active_asr_connections": len(self.asr_connections),
            "asr_sessions": list(self.asr_connections)
        }
class MockASRConnection:
    """In-memory stand-in for an ASR connection, intended for testing.

    Every call to ``send_audio`` waits briefly and then reports one fixed
    recognition result through ``result_callback``.
    """

    def __init__(self, session_id: str, result_callback):
        """Store the session id and the result callback.

        Args:
            session_id: Id passed back as the first callback argument.
            result_callback: Awaitable callable invoked as
                ``result_callback(session_id, result)``; a falsy value
                disables result delivery.
        """
        self.session_id = session_id
        self.result_callback = result_callback
        self.is_closed = False

    async def send_audio(self, audio_data):
        """Simulate recognition of one audio chunk.

        Nothing is reported if the connection is closed, either before the
        call or while the simulated processing is in flight.

        Args:
            audio_data: The audio chunk (ignored by the mock).
        """
        if self.is_closed:
            return

        # Simulated ASR processing latency.
        await asyncio.sleep(0.1)

        # close() may have run while we were sleeping; a closed connection
        # must not deliver results any more.
        if self.is_closed:
            return

        result = {
            "text": "模拟识别结果",
            "confidence": 0.95,
            "timestamp": asyncio.get_running_loop().time()
        }
        if self.result_callback:
            await self.result_callback(self.session_id, result)

    async def close(self):
        """Mark the connection as closed and log it."""
        self.is_closed = True
        logger.info(f'模拟ASR连接已关闭: {self.session_id}')
# Module-level ASR service instance, created when this module is imported
asr_service = ASRWebSocketService()
def get_asr_service() -> ASRWebSocketService:
    """Return the module-level ``asr_service`` instance."""
    return asr_service