冯杨

ASR WebSocket服务实现:1.FunASR本地方案 2.豆包

1.已实现音频文件的处理,包括小文件直接转换以及大文件分割识别。
2.豆包接入流式识别,但封装仍需要修改

Too many changes to show.

To preserve performance only 28 of 28+ files are displayed.

@@ -20,3 +20,4 @@ workspace/log_ngp.txt @@ -20,3 +20,4 @@ workspace/log_ngp.txt
20 models/ 20 models/
21 *.log 21 *.log
22 dist 22 dist
  23 +.vscode/launch.json
1 -python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar7 --fullbody_height 1722 --fullbody_width 1080  
  1 +python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar10 --fullbody_height 1920 --fullbody_width 1080
@@ -49,120 +49,42 @@ import gc @@ -49,120 +49,42 @@ import gc
49 import weakref 49 import weakref
50 import time 50 import time
51 51
  52 +# 注意:server_recording_api模块已移除,相关功能已迁移到其他模块
  53 +# 导入新的统一WebSocket管理架构
  54 +from core.app_websocket_migration import (
  55 + get_app_websocket_migration,
  56 + initialize_app_websocket_migration,
  57 + setup_app_websocket_routes,
  58 + broadcast_message_to_session,
  59 + handle_asr_audio_data,
  60 + handle_start_asr_recognition,
  61 + handle_stop_asr_recognition,
  62 + send_asr_result,
  63 + send_normal_asr_result
  64 +)
52 65
53 app = Flask(__name__) 66 app = Flask(__name__)
54 #sockets = Sockets(app) 67 #sockets = Sockets(app)
55 nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal 68 nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
56 -websocket_connections:Dict[int, weakref.WeakSet] = {} #sessionid:websocket_connections 69 +# WebSocket连接管理已迁移到统一架构
  70 +# websocket_connections和asr_connections现在通过迁移层管理
  71 +# 全局事件循环引用,用于跨线程异步调用
  72 +main_event_loop = None
57 opt = None 73 opt = None
58 model = None 74 model = None
59 avatar = None 75 avatar = None
  76 +# WebSocket迁移实例
  77 +websocket_migration = None
60 78
61 79
62 #####webrtc############################### 80 #####webrtc###############################
63 pcs = set() 81 pcs = set()
64 82
65 -# WebSocket消息推送函数  
66 -async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "数字人回复", model_info: str = None, request_source: str = "页面"):  
67 - """向指定会话的所有WebSocket连接推送消息"""  
68 - logger.info(f'[SessionID:{sessionid}] 开始推送消息: {message_type}, source: {source}, content: {content[:50]}...')  
69 - logger.info(f'[SessionID:{sessionid}] 当前websocket_connections keys: {list(websocket_connections.keys())}')  
70 -  
71 - if sessionid not in websocket_connections:  
72 - logger.warning(f'[SessionID:{sessionid}] 会话不存在于websocket_connections中')  
73 - return  
74 -  
75 - logger.info(f'[SessionID:{sessionid}] 找到会话,连接数量: {len(websocket_connections[sessionid])}')  
76 -  
77 - message = {  
78 - "type": "chat_message",  
79 - "data": {  
80 - "sessionid": sessionid,  
81 - "message_type": message_type,  
82 - "content": content,  
83 - "source": source,  
84 - "model_info": model_info,  
85 - "request_source": request_source,  
86 - "timestamp": time.time()  
87 - }  
88 - }  
89 -  
90 - # 获取该会话的所有WebSocket连接  
91 - connections = list(websocket_connections[sessionid])  
92 -  
93 - # 向所有连接发送消息  
94 - logger.info(f'[SessionID:{sessionid}] 准备向{len(connections)}个连接发送消息')  
95 - for i, ws in enumerate(connections):  
96 - try:  
97 - logger.info(f'[SessionID:{sessionid}] 检查连接{i+1}: closed={ws.closed}')  
98 - if not ws.closed:  
99 - logger.info(f'[SessionID:{sessionid}] 向连接{i+1}发送消息: {json.dumps(message)}')  
100 - await ws.send_str(json.dumps(message))  
101 - logger.info(f'[SessionID:{sessionid}] 连接{i+1}消息发送成功: {message_type} from {request_source}')  
102 - else:  
103 - logger.warning(f'[SessionID:{sessionid}] 连接{i+1}已关闭,跳过发送')  
104 - except Exception as e:  
105 - logger.error(f'[SessionID:{sessionid}] 连接{i+1}发送失败: {e}')  
106 -  
107 -# WebSocket处理器  
108 -async def websocket_handler(request):  
109 - """处理WebSocket连接"""  
110 - ws = web.WebSocketResponse()  
111 - await ws.prepare(request)  
112 -  
113 - sessionid = None  
114 - logger.info('New WebSocket connection established')  
115 -  
116 - try:  
117 - async for msg in ws:  
118 - if msg.type == WSMsgType.TEXT:  
119 - try:  
120 - data = json.loads(msg.data)  
121 -  
122 - if data.get('type') == 'login':  
123 - sessionid = data.get('sessionid', 0)  
124 - logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(websocket_connections.keys())}')  
125 -  
126 - # 初始化该会话的WebSocket连接集合  
127 - if sessionid not in websocket_connections:  
128 - websocket_connections[sessionid] = weakref.WeakSet()  
129 - logger.info(f'[SessionID:{sessionid}] 创建新的连接集合')  
130 -  
131 - # 添加当前连接到会话  
132 - websocket_connections[sessionid].add(ws)  
133 - logger.info(f'[SessionID:{sessionid}] 连接已添加,当前会话连接数: {len(websocket_connections[sessionid])}')  
134 -  
135 - logger.info(f'[SessionID:{sessionid}] WebSocket client logged in')  
136 -  
137 - # 发送登录确认  
138 - await ws.send_str(json.dumps({  
139 - "type": "login_success",  
140 - "sessionid": sessionid,  
141 - "message": "WebSocket连接成功"  
142 - }))  
143 -  
144 - elif data.get('type') == 'ping':  
145 - # 心跳检测  
146 - await ws.send_str(json.dumps({"type": "pong"}))  
147 -  
148 - except json.JSONDecodeError:  
149 - logger.error('Invalid JSON received from WebSocket')  
150 - except Exception as e:  
151 - logger.error(f'Error processing WebSocket message: {e}')  
152 -  
153 - elif msg.type == WSMsgType.ERROR:  
154 - logger.error(f'WebSocket error: {ws.exception()}')  
155 - break  
156 -  
157 - except Exception as e:  
158 - logger.error(f'WebSocket connection error: {e}')  
159 - finally:  
160 - if sessionid is not None:  
161 - logger.info(f'[SessionID:{sessionid}] WebSocket connection closed')  
162 - else:  
163 - logger.info('WebSocket connection closed') 83 +# WebSocket消息推送函数已迁移到统一架构
  84 +# 通过 core.app_websocket_migration 模块提供兼容性接口
164 85
165 - return ws 86 +# WebSocket处理器已迁移到统一架构
  87 +# 通过 core.app_websocket_migration 模块提供
166 88
167 def randN(N)->int: 89 def randN(N)->int:
168 '''生成长度为 N的随机数 ''' 90 '''生成长度为 N的随机数 '''
@@ -456,41 +378,187 @@ async def interrupt_talk(request): @@ -456,41 +378,187 @@ async def interrupt_talk(request):
456 ) 378 )
457 from pydub import AudioSegment 379 from pydub import AudioSegment
458 from io import BytesIO 380 from io import BytesIO
459 -async def humanaudio(request): 381 +
  382 +async def ensure_asr_connection(sessionid: int) -> bool:
  383 + """确保ASR连接可用"""
  384 + # 通过迁移实例获取ASR连接
  385 + migration = get_app_websocket_migration()
  386 + if sessionid not in migration.asr_connections:
  387 + return await create_asr_connection(sessionid)
  388 +
  389 + asr_client = migration.asr_connections[sessionid]
  390 +
  391 + # 检查连接状态
  392 + if not asr_client.is_connected():
  393 + logger.warning(f"[SessionID:{sessionid}] ASR连接已断开,尝试重连")
460 try: 394 try:
461 - params = await request.json()  
462 - sessionid = int(params.get('sessionid', 0))  
463 - fileobj = params.get('file_url') 395 + # 重新连接
  396 + success = await asyncio.get_event_loop().run_in_executor(
  397 + None, asr_client.connect
  398 + )
  399 + if success:
  400 + logger.info(f"[SessionID:{sessionid}] ASR重连成功")
  401 + return True
  402 + else:
  403 + logger.error(f"[SessionID:{sessionid}] ASR重连失败")
  404 + # 清理失效连接
  405 + migration = get_app_websocket_migration()
  406 + if sessionid in migration.asr_connections:
  407 + del migration.asr_connections[sessionid]
  408 + return False
  409 + except Exception as e:
  410 + logger.error(f"[SessionID:{sessionid}] ASR重连异常: {e}")
  411 + del asr_connections[sessionid]
  412 + return False
464 413
465 - # 获取音频文件数据  
466 - if isinstance(fileobj, str) and fileobj.startswith("http"):  
467 - async with aiohttp.ClientSession() as session:  
468 - async with session.get(fileobj) as response:  
469 - if response.status == 200:  
470 - filebytes = await response.read() 414 + return True
  415 +
  416 +async def create_asr_connection(sessionid: int) -> bool:
  417 + """创建新的ASR连接"""
  418 + try:
  419 + from funasr_asr_sync import FunASRSync
  420 + username = f'User_{sessionid}' # 修复大小写不一致:user_ -> User_
  421 + asr_client = FunASRSync(username)
  422 +
  423 + # 设置结果回调
  424 + def on_asr_result(result):
  425 + if isinstance(result, str):
  426 + result_data = {
  427 + 'text': result,
  428 + 'is_final': True,
  429 + 'confidence': 1.0
  430 + }
  431 + else:
  432 + result_data = result
  433 +
  434 + # 线程安全地调度异步任务
  435 + try:
  436 + # 优先使用全局事件循环引用
  437 + if main_event_loop is not None and not main_event_loop.is_closed():
  438 + # 使用全局事件循环进行跨线程调用
  439 + asyncio.run_coroutine_threadsafe(
  440 + # send_asr_result(sessionid, result_data), main_event_loop
  441 + send_normal_asr_result(sessionid, result_data), main_event_loop
  442 + )
  443 + logger.debug(f"[SessionID:{sessionid}] 使用全局事件循环发送ASR结果")
  444 + else:
  445 + # 降级处理:尝试获取当前线程的事件循环
  446 + try:
  447 + loop = asyncio.get_event_loop()
  448 + if loop.is_running():
  449 + loop.call_soon_threadsafe(
  450 + lambda: asyncio.create_task(send_normal_asr_result(sessionid, result_data))
  451 + )
  452 + else:
  453 + asyncio.create_task(send_normal_asr_result(sessionid, result_data))
  454 + except RuntimeError:
  455 + # 最终降级:仅记录日志
  456 + logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}")
  457 + logger.warning(f"[SessionID:{sessionid}] 无法发送ASR结果到客户端,事件循环不可用")
  458 + except Exception as e:
  459 + logger.error(f"[SessionID:{sessionid}] ASR结果处理异常: {e}")
  460 + # 至少记录识别结果
  461 + logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}")
  462 +
  463 + asr_client.set_result_callback(on_asr_result)
  464 +
  465 + # 异步连接
  466 + success = await asyncio.get_event_loop().run_in_executor(
  467 + None, asr_client.connect
  468 + )
  469 +
  470 + if success:
  471 + # 通过迁移实例存储ASR连接
  472 + migration = get_app_websocket_migration()
  473 + migration.asr_connections[sessionid] = asr_client
  474 + logger.info(f"[SessionID:{sessionid}] ASR连接创建成功")
  475 + return True
471 else: 476 else:
  477 + logger.error(f"[SessionID:{sessionid}] ASR连接创建失败")
  478 + return False
  479 +
  480 + except Exception as e:
  481 + logger.error(f"[SessionID:{sessionid}] 创建ASR连接异常: {e}")
  482 + return False
  483 +
  484 +async def humanaudio(request):
  485 + try:
  486 + # 检查请求内容类型,支持FormData和JSON两种格式
  487 + content_type = request.headers.get('content-type', '')
  488 +
  489 +
  490 + # 处理FormData格式(文件上传)
  491 + reader = await request.multipart()
  492 + sessionid = 0
  493 + fileobj = None
  494 + # 默认启用语音本地服务
  495 + asr_service = "funasr"
  496 +
  497 + # 读取FormData字段
  498 + async for field in reader:
  499 + if field.name == 'sessionid':
  500 + sessionid = int(await field.text())
  501 + logger.info(f'Parsed sessionid: {sessionid}')
  502 + elif field.name == 'audio':
  503 + fileobj = field
  504 + filename = field.filename
  505 + filebytes = await field.read()
  506 + # 输出文件大小信息
  507 + logger.info(f'Audio file content size: {len(filebytes)} bytes')
  508 + if not fileobj:
472 return web.Response( 509 return web.Response(
473 content_type="application/json", 510 content_type="application/json",
474 - text=json.dumps({"code": -1, "msg": "Error downloading file"}) 511 + text=json.dumps({"code": -1, "msg": "No audio file provided"})
475 ) 512 )
476 - # 根据 URL 后缀判断是否为 MP3 文件  
477 - is_mp3 = fileobj.lower().endswith('.mp3')  
478 - else:  
479 - filename = fileobj.filename  
480 - filebytes = fileobj.file.read()  
481 - is_mp3 = filename.lower().endswith('.mp3') 513 + elif field.name == 'asr_service':
  514 + asr_service = (await field.text()).strip().lower()
482 515
  516 + # 根据文件名判断是否为 MP3 文件
  517 + is_mp3 = filename.lower().endswith('.mp3') if filename else False
  518 +
  519 + # 处理MP3转WAV
483 if is_mp3: 520 if is_mp3:
484 - audio = AudioSegment.from_file(BytesIO(filebytes), format="mp3") 521 + try:
  522 + with BytesIO(filebytes) as audio_buffer:
  523 + audio = AudioSegment.from_file(audio_buffer, format="mp3")
485 out_io = BytesIO() 524 out_io = BytesIO()
486 audio.export(out_io, format="wav") 525 audio.export(out_io, format="wav")
487 filebytes = out_io.getvalue() 526 filebytes = out_io.getvalue()
  527 + except Exception as e:
  528 + logger.error(f"[SessionID:{sessionid}] 音频处理失败: {e}")
  529 + raise
  530 +
  531 + # 获取WebSocket迁移实例来访问连接信息
  532 + migration = get_app_websocket_migration()
  533 + active_sessions = migration.get_websocket_connections()
  534 + logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(active_sessions.keys())}')
  535 + # 验证sessionid是否存在
  536 + if sessionid not in nerfreals:
  537 + return web.Response(
  538 + content_type="application/json",
  539 + text=json.dumps({"code": -1, "msg": f"Session {sessionid} not found. Please establish WebRTC connection first."})
  540 + )
488 541
  542 + # 发送音频数据进行处理 数字人播报
489 nerfreals[sessionid].put_audio_file(filebytes) 543 nerfreals[sessionid].put_audio_file(filebytes)
490 544
  545 +
  546 + # ---------- ASR 分流 ----------
  547 + if asr_service == 'funasr':
  548 + await handle_funasr(sessionid, filebytes)
  549 + elif asr_service == 'doubao':
  550 + await handle_doubao(sessionid, filebytes)
  551 + else:
  552 + logger.warning(f'[SessionID:{sessionid}] 未指定或未知 asr_service,跳过 ASR')
  553 +
  554 +
  555 +
  556 + # 通过迁移实例检查ASR连接状态
  557 + migration = get_app_websocket_migration()
  558 + asr_enabled = sessionid in migration.asr_connections
491 return web.Response( 559 return web.Response(
492 content_type="application/json", 560 content_type="application/json",
493 - text=json.dumps({"code": 0, "msg": "ok"}) 561 + text=json.dumps({"code": 0, "msg": "ok", "asr_enabled": asr_enabled})
494 ) 562 )
495 563
496 except Exception as e: 564 except Exception as e:
@@ -500,6 +568,66 @@ async def humanaudio(request): @@ -500,6 +568,66 @@ async def humanaudio(request):
500 text=json.dumps( {"code": -1, "msg": str(e)}) 568 text=json.dumps( {"code": -1, "msg": str(e)})
501 ) 569 )
502 570
  571 +async def handle_funasr(sessionid: int, audio_bytes: bytes):
  572 + # ASR识别处理 - 使用新的连接管理机制
  573 + try:
  574 + # 确保ASR连接可用
  575 + asr_available = await ensure_asr_connection(sessionid)
  576 +
  577 + if asr_available:
  578 + # 发送音频数据到ASR服务进行识别
  579 + # 通过迁移实例获取ASR连接
  580 + migration = get_app_websocket_migration()
  581 + asr_client = migration.asr_connections[sessionid]
  582 + if hasattr(asr_client, 'send_audio_data'):
  583 + asr_client.send_audio_data(audio_bytes)
  584 + logger.info(f'[SessionID:{sessionid}] 音频数据已发送到ASR服务进行识别')
  585 + else:
  586 + logger.warning(f'[SessionID:{sessionid}] ASR客户端不支持send_audio_data方法')
  587 + else:
  588 + logger.warning(f'[SessionID:{sessionid}] ASR连接不可用,跳过语音识别')
  589 +
  590 + except Exception as asr_error:
  591 + logger.error(f'[SessionID:{sessionid}] ASR处理错误: {asr_error}')
  592 + # ASR错误不影响主要功能,继续返回成功
  593 +
  594 +# 导入 Doubao ASR 服务
  595 +from asr.doubao.service_factory import recognize_audio_data
  596 +import os
  597 +import json
  598 +
  599 +async def handle_doubao(sessionid: int, audio_bytes: bytes):
  600 + """云端 Doubao 调用"""
  601 + try:
  602 + logger.info(f"[SessionID:{sessionid}] 使用云端 Doubao 识别")
  603 +
  604 + # 读取豆包ASR配置文件
  605 + config_path = os.path.join(os.path.dirname(__file__), 'asr', 'doubao', 'config.json')
  606 + with open(config_path, 'r', encoding='utf-8') as f:
  607 + config = json.load(f)
  608 +
  609 + # 获取认证配置
  610 + auth_config = config.get('auth_config', {})
  611 + app_key = auth_config.get('app_key')
  612 + access_key = auth_config.get('access_key')
  613 +
  614 + if not app_key or not access_key:
  615 + raise ValueError("豆包ASR认证配置缺失:app_key 或 access_key 未配置")
  616 +
  617 + result = await recognize_audio_data(
  618 + audio_data=audio_bytes,
  619 + app_key=app_key,
  620 + access_key=access_key,
  621 + streaming=True,
  622 + result_callback=lambda res: logger.info(f"[SessionID:{sessionid}] Doubao 识别结果: {res}")
  623 + )
  624 + return result
  625 + except Exception as e:
  626 + logger.error(f"[SessionID:{sessionid}] Doubao 错误: {e}")
  627 + raise
  628 +
  629 +
  630 +
503 async def set_audiotype(request): 631 async def set_audiotype(request):
504 try: 632 try:
505 params = await request.json() 633 params = await request.json()
@@ -787,6 +915,10 @@ if __name__ == '__main__': @@ -787,6 +915,10 @@ if __name__ == '__main__':
787 rendthrd.start() 915 rendthrd.start()
788 916
789 ############################################################################# 917 #############################################################################
  918 + # ASR处理函数已迁移到统一架构
  919 + # 通过 core.app_websocket_migration 模块提供
  920 +
  921 + #############################################################################
790 appasync = web.Application() 922 appasync = web.Application()
791 appasync.on_shutdown.append(on_shutdown) 923 appasync.on_shutdown.append(on_shutdown)
792 appasync.router.add_post("/offer", offer) 924 appasync.router.add_post("/offer", offer)
@@ -796,9 +928,27 @@ if __name__ == '__main__': @@ -796,9 +928,27 @@ if __name__ == '__main__':
796 appasync.router.add_post("/record", record) 928 appasync.router.add_post("/record", record)
797 appasync.router.add_post("/interrupt_talk", interrupt_talk) 929 appasync.router.add_post("/interrupt_talk", interrupt_talk)
798 appasync.router.add_post("/is_speaking", is_speaking) 930 appasync.router.add_post("/is_speaking", is_speaking)
799 - appasync.router.add_get("/ws", websocket_handler) 931 +
  932 + # 初始化统一WebSocket管理架构
  933 + websocket_migration = get_app_websocket_migration()
  934 +
  935 + # 注册WebSocket接口 - 使用新的统一架构
  936 + setup_app_websocket_routes(appasync)
  937 +
  938 + # 异步初始化将在服务器启动时进行
  939 + async def init_websocket_migration():
  940 + await initialize_app_websocket_migration()
  941 + logger.info("WebSocket迁移架构初始化完成")
  942 +
  943 + # 添加启动时初始化
  944 + appasync.on_startup.append(lambda app: init_websocket_migration())
  945 +
800 appasync.router.add_static('/',path='web') 946 appasync.router.add_static('/',path='web')
801 947
  948 + # 服务端录音WebSocket接口已集成到统一架构中
  949 + # 通过 /ws 路由和消息类型区分访问:wsa_register_web, wsa_register_human 等
  950 + logger.info("主应用路由配置完成,WebSocket接口已统一到 /ws 路由")
  951 +
802 # Configure default CORS settings. 952 # Configure default CORS settings.
803 cors = aiohttp_cors.setup(appasync, defaults={ 953 cors = aiohttp_cors.setup(appasync, defaults={
804 "*": aiohttp_cors.ResourceOptions( 954 "*": aiohttp_cors.ResourceOptions(
@@ -819,8 +969,13 @@ if __name__ == '__main__': @@ -819,8 +969,13 @@ if __name__ == '__main__':
819 logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename) 969 logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
820 logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html') 970 logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
821 def run_server(runner): 971 def run_server(runner):
  972 + global main_event_loop
822 loop = asyncio.new_event_loop() 973 loop = asyncio.new_event_loop()
823 asyncio.set_event_loop(loop) 974 asyncio.set_event_loop(loop)
  975 + # 设置全局事件循环引用,用于跨线程异步调用
  976 + main_event_loop = loop
  977 + logger.info("全局事件循环引用已设置")
  978 +
824 loop.run_until_complete(runner.setup()) 979 loop.run_until_complete(runner.setup())
825 site = web.TCPSite(runner, '0.0.0.0', opt.listenport) 980 site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
826 loop.run_until_complete(site.start()) 981 loop.run_until_complete(site.start())
1 # -*- coding: utf-8 -*- 1 # -*- coding: utf-8 -*-
2 """ 2 """
3 -AIfeng/2025-01-27 3 +AIfeng/2025-06-30
4 配置管理工具模块 4 配置管理工具模块
5 统一管理项目配置参数 5 统一管理项目配置参数
6 """ 6 """
@@ -24,6 +24,9 @@ class ConfigManager: @@ -24,6 +24,9 @@ class ConfigManager:
24 'local_asr_ip': '127.0.0.1', 24 'local_asr_ip': '127.0.0.1',
25 'local_asr_port': 10197, 25 'local_asr_port': 10197,
26 26
  27 + # 音频设备配置
  28 + 'local_audio_ip': '127.0.0.1',
  29 +
27 # 阿里云NLS配置 30 # 阿里云NLS配置
28 'key_ali_nls_key_id': '', 31 'key_ali_nls_key_id': '',
29 'key_ali_nls_key_secret': '', 32 'key_ali_nls_key_secret': '',
@@ -82,6 +85,7 @@ _config_manager = ConfigManager() @@ -82,6 +85,7 @@ _config_manager = ConfigManager()
82 # 兼容原有的属性访问方式 85 # 兼容原有的属性访问方式
83 local_asr_ip = _config_manager.local_asr_ip 86 local_asr_ip = _config_manager.local_asr_ip
84 local_asr_port = _config_manager.local_asr_port 87 local_asr_port = _config_manager.local_asr_port
  88 +local_audio_ip = _config_manager.local_audio_ip
85 key_ali_nls_key_id = _config_manager.key_ali_nls_key_id 89 key_ali_nls_key_id = _config_manager.key_ali_nls_key_id
86 key_ali_nls_key_secret = _config_manager.key_ali_nls_key_secret 90 key_ali_nls_key_secret = _config_manager.key_ali_nls_key_secret
87 key_ali_nls_app_key = _config_manager.key_ali_nls_app_key 91 key_ali_nls_app_key = _config_manager.key_ali_nls_app_key
1 # -*- coding: utf-8 -*- 1 # -*- coding: utf-8 -*-
2 """ 2 """
3 -AIfeng/2025-01-27  
4 -Core模块初始化文件 3 +AIfeng/2025-07-17 14:15:27
  4 +Core模块初始化
  5 +已迁移到统一WebSocket架构,移除旧的wsa_server依赖
5 """ 6 """
6 7
7 -from .wsa_server import get_web_instance, get_instance 8 +# 统一WebSocket架构导入
  9 +from .wsa_websocket_service import get_web_instance, get_instance
8 10
9 __all__ = ['get_web_instance', 'get_instance'] 11 __all__ = ['get_web_instance', 'get_instance']
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +app.py WebSocket功能迁移脚本
  5 +将app.py中的WebSocket功能迁移到统一架构
  6 +"""
  7 +
  8 +import asyncio
  9 +import json
  10 +import weakref
  11 +from typing import Dict, Any, Optional
  12 +from aiohttp import web
  13 +from logger import logger
  14 +from .websocket_router import get_websocket_router, get_websocket_compatibility_api
  15 +from .asr_websocket_service import get_asr_service
  16 +from .digital_human_websocket_service import get_digital_human_service
  17 +
  18 +
  19 +class AppWebSocketMigration:
  20 + """app.py WebSocket功能迁移类"""
  21 +
  22 + def __init__(self):
  23 + self.router = get_websocket_router()
  24 + self.compatibility_api = get_websocket_compatibility_api()
  25 + self.asr_service = get_asr_service()
  26 + self.digital_human_service = get_digital_human_service()
  27 +
  28 + # 兼容性变量(保持与原app.py的接口一致)
  29 + self.websocket_connections = {}
  30 + self.asr_connections = {}
  31 +
  32 + async def initialize(self):
  33 + """初始化迁移组件"""
  34 + await self.router.initialize()
  35 + logger.info('WebSocket迁移组件初始化完成')
  36 +
  37 + async def shutdown(self):
  38 + """关闭迁移组件"""
  39 + await self.router.shutdown()
  40 + logger.info('WebSocket迁移组件已关闭')
  41 +
  42 + def setup_routes(self, app: web.Application):
  43 + """设置路由(替换原app.py中的WebSocket路由)"""
  44 + # 使用新的统一WebSocket处理器
  45 + self.router.setup_routes(app, '/ws')
  46 +
  47 + # 添加兼容性路由(如果需要)
  48 + app.router.add_get('/ws_legacy', self._legacy_websocket_handler)
  49 +
  50 + async def _legacy_websocket_handler(self, request: web.Request):
  51 + """兼容性WebSocket处理器(保持原有接口)"""
  52 + # 直接转发到新的统一处理器
  53 + return await self.router.websocket_handler(request)
  54 +
  55 + # 兼容性接口方法
  56 + async def broadcast_message_to_session(self, sessionid: int, message_type: str,
  57 + content: str, source: str = "数字人回复",
  58 + model_info: str = None, request_source: str = "页面"):
  59 + """兼容原app.py的消息推送接口"""
  60 + message_data = {
  61 + "sessionid": sessionid,
  62 + "message_type": message_type,
  63 + "content": content,
  64 + "source": source,
  65 + "model_info": model_info,
  66 + "request_source": request_source,
  67 + "timestamp": asyncio.get_event_loop().time()
  68 + }
  69 +
  70 + return await self.router.send_to_session(str(sessionid), 'chat_message', message_data)
  71 +
  72 + async def handle_asr_audio_data(self, data: Dict[str, Any], sessionid: int, ws):
  73 + """兼容原app.py的ASR音频数据处理"""
  74 + # 转换为新架构的消息格式
  75 + message_data = {
  76 + 'audio_data': data.get('audio_data'),
  77 + 'sessionid': sessionid
  78 + }
  79 +
  80 + # 通过新的ASR服务处理
  81 + session = self.router.manager.get_session(ws)
  82 + if session:
  83 + await self.asr_service._handle_asr_audio_data(ws, message_data)
  84 +
  85 + async def handle_start_asr_recognition(self, sessionid: int, ws):
  86 + """兼容原app.py的开始ASR识别"""
  87 + session = self.router.manager.get_session(ws)
  88 + if session:
  89 + await self.asr_service._handle_start_asr_recognition(ws, {'sessionid': sessionid})
  90 +
  91 + async def handle_stop_asr_recognition(self, sessionid: int, ws):
  92 + """兼容原app.py的停止ASR识别"""
  93 + session = self.router.manager.get_session(ws)
  94 + if session:
  95 + await self.asr_service._handle_stop_asr_recognition(ws, {'sessionid': sessionid})
  96 +
  97 + async def send_asr_result(self, sessionid: int, result: Dict[str, Any]):
  98 + """兼容原app.py的ASR结果发送"""
  99 + return await self.router.send_to_session(str(sessionid), 'asr_result', {
  100 + "text": result.get('text', ''),
  101 + "is_final": result.get('is_final', False),
  102 + "confidence": result.get('confidence', 0.0)
  103 + })
  104 +
  105 + async def send_normal_asr_result(self, sessionid: int, result: Dict[str, Any]):
  106 + """业务层决定传输内容以及结构"""
  107 + return await self.router.send_raw_to_session(str(sessionid), result)
  108 +
  109 +
  110 + def get_websocket_connections(self):
  111 + """获取WebSocket连接(兼容性接口)"""
  112 + # 返回兼容性字典格式,键为会话ID,值为WebSocket对象
  113 + sessions_dict = self.router.manager._sessions
  114 + result = {}
  115 + for session_id, session_set in sessions_dict.items():
  116 + # 取集合中的第一个WebSocket连接(通常每个session_id只有一个连接)
  117 + if session_set:
  118 + session = next(iter(session_set))
  119 + result[session_id] = session.websocket
  120 + return result
  121 +
  122 + def get_session_count(self):
  123 + """获取会话数量(兼容性接口)"""
  124 + return self.compatibility_api.get_session_count()
  125 +
  126 + async def cleanup_session(self, sessionid: int):
  127 + """清理会话(兼容性接口)"""
  128 + # 清理ASR连接
  129 + if sessionid in self.asr_connections:
  130 + del self.asr_connections[sessionid]
  131 +
  132 + # 通过新架构清理会话
  133 + sessions = self.router.manager._sessions
  134 + session_id_str = str(sessionid)
  135 +
  136 + for ws, session in list(sessions.items()):
  137 + if session.session_id == session_id_str:
  138 + await self.router.manager.remove_session(ws)
  139 + break
  140 +
  141 + def get_migration_stats(self) -> Dict[str, Any]:
  142 + """获取迁移统计信息"""
  143 + return {
  144 + "router_stats": self.router.get_router_stats(),
  145 + "asr_stats": self.asr_service.get_asr_stats(),
  146 + "digital_human_stats": self.digital_human_service.get_digital_human_stats(),
  147 + "compatibility_sessions": len(self.websocket_connections),
  148 + "compatibility_asr_connections": len(self.asr_connections)
  149 + }
  150 +
  151 +
  152 +# 全局迁移实例
  153 +_migration_instance = None
  154 +
  155 +
  156 +def get_app_websocket_migration() -> AppWebSocketMigration:
  157 + """获取app.py WebSocket迁移实例"""
  158 + global _migration_instance
  159 + if _migration_instance is None:
  160 + _migration_instance = AppWebSocketMigration()
  161 + return _migration_instance
  162 +
  163 +
  164 +async def initialize_app_websocket_migration():
  165 + """初始化app.py WebSocket迁移"""
  166 + migration = get_app_websocket_migration()
  167 + await migration.initialize()
  168 + return migration
  169 +
  170 +
  171 +async def shutdown_app_websocket_migration():
  172 + """关闭app.py WebSocket迁移"""
  173 + global _migration_instance
  174 + if _migration_instance:
  175 + await _migration_instance.shutdown()
  176 + _migration_instance = None
  177 +
  178 +
  179 +def setup_app_websocket_routes(app: web.Application):
  180 + """设置app.py WebSocket路由(便捷函数)"""
  181 + migration = get_app_websocket_migration()
  182 + migration.setup_routes(app)
  183 + return migration
  184 +
  185 +
  186 +# 兼容性函数(保持与原app.py的接口一致)
  187 +async def broadcast_message_to_session(sessionid: int, message_type: str, content: str,
  188 + source: str = "数字人回复", model_info: str = None,
  189 + request_source: str = "页面"):
  190 + """兼容原app.py的消息推送函数"""
  191 + migration = get_app_websocket_migration()
  192 + return await migration.broadcast_message_to_session(
  193 + sessionid, message_type, content, source, model_info, request_source
  194 + )
  195 +
  196 +
  197 +async def handle_asr_audio_data(data: Dict[str, Any], sessionid: int, ws):
  198 + """兼容原app.py的ASR音频数据处理函数"""
  199 + migration = get_app_websocket_migration()
  200 + return await migration.handle_asr_audio_data(data, sessionid, ws)
  201 +
  202 +
  203 +async def handle_start_asr_recognition(sessionid: int, ws):
  204 + """兼容原app.py的开始ASR识别函数"""
  205 + migration = get_app_websocket_migration()
  206 + return await migration.handle_start_asr_recognition(sessionid, ws)
  207 +
  208 +
  209 +async def handle_stop_asr_recognition(sessionid: int, ws):
  210 + """兼容原app.py的停止ASR识别函数"""
  211 + migration = get_app_websocket_migration()
  212 + return await migration.handle_stop_asr_recognition(sessionid, ws)
  213 +
  214 +
  215 +async def send_asr_result(sessionid: int, result: Dict[str, Any]):
  216 + """兼容原app.py的ASR结果发送函数"""
  217 + migration = get_app_websocket_migration()
  218 + return await migration.send_asr_result(sessionid, result)
  219 +
  220 +async def send_normal_asr_result(sessionid: int, result: Dict[str, Any]):
  221 + """兼容原app.py的ASR结果发送函数"""
  222 + migration = get_app_websocket_migration()
  223 + return await migration.send_normal_asr_result(sessionid, result)
  224 +
  225 +
  226 +# 全局变量兼容性接口
  227 +def get_websocket_connections():
  228 + """获取WebSocket连接字典(兼容性接口)"""
  229 + migration = get_app_websocket_migration()
  230 + return migration.websocket_connections
  231 +
  232 +
  233 +def get_asr_connections():
  234 + """获取ASR连接字典(兼容性接口)"""
  235 + migration = get_app_websocket_migration()
  236 + return migration.asr_connections
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +ASR WebSocket服务实现
  5 +从app.py中抽离的ASR相关WebSocket功能
  6 +"""
  7 +
  8 +import asyncio
  9 +import json
  10 +import weakref
  11 +from typing import Dict, Any, Optional
  12 +from aiohttp import web
  13 +from logger import logger
  14 +from .websocket_service_base import WebSocketServiceBase
  15 +from .unified_websocket_manager import WebSocketSession
  16 +
  17 +
  18 +class ASRWebSocketService(WebSocketServiceBase):
  19 + """ASR WebSocket服务"""
  20 +
  21 + def __init__(self):
  22 + super().__init__("asr_service")
  23 + # ASR连接管理
  24 + self.asr_connections: Dict[str, Any] = {} # sessionid -> asr_connection
  25 + self._heartbeat_task = None
  26 +
  27 + async def _register_message_handlers(self):
  28 + """注册ASR相关消息处理器"""
  29 + # 注册消息处理器
  30 + self.manager.register_message_handler('login', self._handle_login)
  31 + self.manager.register_message_handler('heartbeat', self._handle_heartbeat)
  32 + self.manager.register_message_handler('asr_audio_data', self._handle_asr_audio_data)
  33 + self.manager.register_message_handler('start_asr_recognition', self._handle_start_asr_recognition)
  34 + self.manager.register_message_handler('stop_asr_recognition', self._handle_stop_asr_recognition)
  35 +
  36 + async def _start_background_tasks(self):
  37 + """启动心跳检测任务"""
  38 + self._heartbeat_task = self.add_background_task(self._heartbeat_monitor())
  39 +
  40 + async def _cleanup(self):
  41 + """清理ASR连接"""
  42 + # 关闭所有ASR连接
  43 + for session_id, asr_conn in list(self.asr_connections.items()):
  44 + try:
  45 + if hasattr(asr_conn, 'close'):
  46 + await asr_conn.close()
  47 + except Exception as e:
  48 + logger.error(f'关闭ASR连接失败 {session_id}: {e}')
  49 +
  50 + self.asr_connections.clear()
  51 +
  52 + async def _on_session_disconnected(self, session: WebSocketSession):
  53 + """会话断开时清理ASR连接"""
  54 + await super()._on_session_disconnected(session)
  55 +
  56 + # 清理对应的ASR连接
  57 + if session.session_id in self.asr_connections:
  58 + asr_conn = self.asr_connections.pop(session.session_id)
  59 + try:
  60 + if hasattr(asr_conn, 'close'):
  61 + await asr_conn.close()
  62 + logger.info(f'已清理ASR连接: {session.session_id}')
  63 + except Exception as e:
  64 + logger.error(f'清理ASR连接失败 {session.session_id}: {e}')
  65 +
  66 + async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  67 + """处理登录消息"""
  68 + session = self.manager.get_session(websocket)
  69 + if not session:
  70 + return
  71 +
  72 + session_id = data.get('sessionid')
  73 + if session_id:
  74 + # 更新会话ID
  75 + old_session_id = session.session_id
  76 + session.session_id = session_id
  77 +
  78 + # 更新管理器中的会话映射
  79 + self.manager._update_session_id(websocket, old_session_id, session_id)
  80 +
  81 + # 发送登录成功响应
  82 + await session.send_message({
  83 + "type": "login_response",
  84 + "data": {
  85 + "status": "success",
  86 + "sessionid": session_id,
  87 + "message": "登录成功"
  88 + }
  89 + })
  90 +
  91 + logger.info(f'用户登录成功: {session_id}')
  92 + else:
  93 + await session.send_message({
  94 + "type": "login_response",
  95 + "data": {
  96 + "status": "error",
  97 + "message": "缺少sessionid"
  98 + }
  99 + })
  100 +
  101 + async def _handle_heartbeat(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  102 + """处理心跳消息"""
  103 + session = self.manager.get_session(websocket)
  104 + if session:
  105 + session.update_last_heartbeat()
  106 + await session.send_message({
  107 + "type": "heartbeat_response",
  108 + "data": {"status": "ok"}
  109 + })
  110 +
  111 + async def _handle_asr_audio_data(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  112 + """处理ASR音频数据"""
  113 + session = self.manager.get_session(websocket)
  114 + if not session:
  115 + return
  116 +
  117 + session_id = session.session_id
  118 + audio_data = data.get('audio_data')
  119 +
  120 + if not audio_data:
  121 + await session.send_message({
  122 + "type": "error",
  123 + "data": {"message": "缺少音频数据"}
  124 + })
  125 + return
  126 +
  127 + # 获取或创建ASR连接
  128 + asr_conn = self.asr_connections.get(session_id)
  129 + if not asr_conn:
  130 + logger.warning(f'ASR连接不存在: {session_id}')
  131 + await session.send_message({
  132 + "type": "error",
  133 + "data": {"message": "ASR连接未建立"}
  134 + })
  135 + return
  136 +
  137 + try:
  138 + # 转发音频数据到ASR服务
  139 + await self._forward_audio_to_asr(asr_conn, audio_data)
  140 + except Exception as e:
  141 + logger.error(f'转发音频数据失败 {session_id}: {e}')
  142 + await session.send_message({
  143 + "type": "error",
  144 + "data": {"message": f"音频处理失败: {str(e)}"}
  145 + })
  146 +
  147 + async def _handle_start_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  148 + """处理开始ASR识别"""
  149 + session = self.manager.get_session(websocket)
  150 + if not session:
  151 + return
  152 +
  153 + session_id = session.session_id
  154 +
  155 + try:
  156 + # 创建ASR连接
  157 + asr_conn = await self._create_asr_connection(session_id)
  158 + if asr_conn:
  159 + self.asr_connections[session_id] = asr_conn
  160 +
  161 + await session.send_message({
  162 + "type": "asr_recognition_started",
  163 + "data": {
  164 + "status": "success",
  165 + "message": "ASR识别已开始"
  166 + }
  167 + })
  168 +
  169 + logger.info(f'ASR识别已开始: {session_id}')
  170 + else:
  171 + await session.send_message({
  172 + "type": "error",
  173 + "data": {"message": "创建ASR连接失败"}
  174 + })
  175 + except Exception as e:
  176 + logger.error(f'开始ASR识别失败 {session_id}: {e}')
  177 + await session.send_message({
  178 + "type": "error",
  179 + "data": {"message": f"开始识别失败: {str(e)}"}
  180 + })
  181 +
  182 + async def _handle_stop_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  183 + """处理停止ASR识别"""
  184 + session = self.manager.get_session(websocket)
  185 + if not session:
  186 + return
  187 +
  188 + session_id = session.session_id
  189 +
  190 + if session_id in self.asr_connections:
  191 + asr_conn = self.asr_connections.pop(session_id)
  192 + try:
  193 + if hasattr(asr_conn, 'close'):
  194 + await asr_conn.close()
  195 +
  196 + await session.send_message({
  197 + "type": "asr_recognition_stopped",
  198 + "data": {
  199 + "status": "success",
  200 + "message": "ASR识别已停止"
  201 + }
  202 + })
  203 +
  204 + logger.info(f'ASR识别已停止: {session_id}')
  205 + except Exception as e:
  206 + logger.error(f'停止ASR识别失败 {session_id}: {e}')
  207 + await session.send_message({
  208 + "type": "error",
  209 + "data": {"message": f"停止识别失败: {str(e)}"}
  210 + })
  211 + else:
  212 + await session.send_message({
  213 + "type": "asr_recognition_stopped",
  214 + "data": {
  215 + "status": "success",
  216 + "message": "ASR识别未在运行"
  217 + }
  218 + })
  219 +
  220 + async def _create_asr_connection(self, session_id: str):
  221 + """创建ASR连接(需要根据实际ASR服务实现)"""
  222 + # TODO: 这里需要根据实际的ASR服务(如FunASR)来实现连接逻辑
  223 + # 暂时返回一个模拟连接对象
  224 + logger.info(f'创建ASR连接: {session_id}')
  225 +
  226 + # 示例:创建到FunASR的WebSocket连接
  227 + try:
  228 + # 这里应该是实际的ASR连接逻辑
  229 + # 例如:asr_conn = await create_funasr_connection(session_id, self._on_asr_result)
  230 + asr_conn = MockASRConnection(session_id, self._on_asr_result)
  231 + return asr_conn
  232 + except Exception as e:
  233 + logger.error(f'创建ASR连接失败 {session_id}: {e}')
  234 + return None
  235 +
  236 + async def _forward_audio_to_asr(self, asr_conn, audio_data):
  237 + """转发音频数据到ASR服务"""
  238 + if hasattr(asr_conn, 'send_audio'):
  239 + await asr_conn.send_audio(audio_data)
  240 + else:
  241 + logger.warning('ASR连接不支持发送音频数据')
  242 +
  243 + async def _on_asr_result(self, session_id: str, result: Dict[str, Any]):
  244 + """ASR结果回调"""
  245 + try:
  246 + await self.broadcast_to_session(session_id, 'asr_result', result)
  247 + logger.debug(f'ASR结果已发送: {session_id}')
  248 + except Exception as e:
  249 + logger.error(f'发送ASR结果失败 {session_id}: {e}')
  250 +
  251 + async def _heartbeat_monitor(self):
  252 + """心跳监控任务"""
  253 + while True:
  254 + try:
  255 + await asyncio.sleep(40) # 每40秒检查一次
  256 +
  257 + # 检查会话心跳
  258 + expired_sessions = self.manager.get_expired_sessions(timeout=60)
  259 + for session in expired_sessions:
  260 + logger.info(f'会话心跳超时,断开连接: {session.session_id}')
  261 + await session.close()
  262 +
  263 + except asyncio.CancelledError:
  264 + break
  265 + except Exception as e:
  266 + logger.error(f'心跳监控异常: {e}')
  267 + await asyncio.sleep(5)
  268 +
  269 + def get_asr_stats(self) -> Dict[str, Any]:
  270 + """获取ASR统计信息"""
  271 + return {
  272 + "active_asr_connections": len(self.asr_connections),
  273 + "asr_sessions": list(self.asr_connections.keys())
  274 + }
  275 +
  276 +
  277 +class MockASRConnection:
  278 + """模拟ASR连接(用于测试)"""
  279 +
  280 + def __init__(self, session_id: str, result_callback):
  281 + self.session_id = session_id
  282 + self.result_callback = result_callback
  283 + self.is_closed = False
  284 +
  285 + async def send_audio(self, audio_data):
  286 + """发送音频数据"""
  287 + if self.is_closed:
  288 + return
  289 +
  290 + # 模拟ASR处理
  291 + await asyncio.sleep(0.1)
  292 +
  293 + # 模拟返回识别结果
  294 + result = {
  295 + "text": "模拟识别结果",
  296 + "confidence": 0.95,
  297 + "timestamp": asyncio.get_event_loop().time()
  298 + }
  299 +
  300 + if self.result_callback:
  301 + await self.result_callback(self.session_id, result)
  302 +
  303 + async def close(self):
  304 + """关闭连接"""
  305 + self.is_closed = True
  306 + logger.info(f'模拟ASR连接已关闭: {self.session_id}')
  307 +
  308 +
  309 +# 创建ASR服务实例
  310 +asr_service = ASRWebSocketService()
  311 +
  312 +
  313 +def get_asr_service() -> ASRWebSocketService:
  314 + """获取ASR服务实例"""
  315 + return asr_service
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +数字人WebSocket服务实现
  5 +处理数字人相关的WebSocket通信和状态管理
  6 +"""
  7 +
  8 +import asyncio
  9 +import json
  10 +from typing import Dict, Any, Optional, List
  11 +from aiohttp import web
  12 +from logger import logger
  13 +from .websocket_service_base import WebSocketServiceBase
  14 +from .unified_websocket_manager import WebSocketSession
  15 +
  16 +
  17 +class DigitalHumanWebSocketService(WebSocketServiceBase):
  18 + """数字人WebSocket服务"""
  19 +
  20 + def __init__(self):
  21 + super().__init__("digital_human_service")
  22 + # 数字人状态管理
  23 + self.digital_humans: Dict[str, Dict[str, Any]] = {} # human_id -> human_info
  24 + self.human_sessions: Dict[str, str] = {} # session_id -> human_id
  25 + self.session_humans: Dict[str, List[str]] = {} # session_id -> [human_ids]
  26 +
  27 + async def _register_message_handlers(self):
  28 + """注册数字人相关消息处理器"""
  29 + self.manager.register_message_handler('register_digital_human', self._handle_register_digital_human)
  30 + self.manager.register_message_handler('unregister_digital_human', self._handle_unregister_digital_human)
  31 + self.manager.register_message_handler('digital_human_status', self._handle_digital_human_status)
  32 + self.manager.register_message_handler('digital_human_action', self._handle_digital_human_action)
  33 + self.manager.register_message_handler('digital_human_speak', self._handle_digital_human_speak)
  34 + self.manager.register_message_handler('digital_human_emotion', self._handle_digital_human_emotion)
  35 + self.manager.register_message_handler('get_digital_humans', self._handle_get_digital_humans)
  36 +
  37 + async def _on_session_disconnected(self, session: WebSocketSession):
  38 + """会话断开时清理数字人注册"""
  39 + await super()._on_session_disconnected(session)
  40 +
  41 + session_id = session.session_id
  42 +
  43 + # 清理该会话注册的数字人
  44 + if session_id in self.session_humans:
  45 + human_ids = self.session_humans.pop(session_id)
  46 + for human_id in human_ids:
  47 + if human_id in self.digital_humans:
  48 + human_info = self.digital_humans.pop(human_id)
  49 + logger.info(f'数字人已注销: {human_id} (会话断开)')
  50 +
  51 + # 通知其他会话数字人已离线
  52 + await self.broadcast_to_all('digital_human_offline', {
  53 + 'human_id': human_id,
  54 + 'name': human_info.get('name', ''),
  55 + 'reason': 'session_disconnected'
  56 + })
  57 +
  58 + # 清理会话到数字人的映射
  59 + if session_id in self.human_sessions:
  60 + del self.human_sessions[session_id]
  61 +
  62 + async def _handle_register_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  63 + """处理数字人注册"""
  64 + session = self.manager.get_session(websocket)
  65 + if not session:
  66 + return
  67 +
  68 + human_id = data.get('human_id')
  69 + human_name = data.get('name', '')
  70 + human_type = data.get('type', 'default')
  71 + capabilities = data.get('capabilities', [])
  72 +
  73 + if not human_id:
  74 + await session.send_message({
  75 + "type": "error",
  76 + "data": {"message": "缺少human_id"}
  77 + })
  78 + return
  79 +
  80 + # 检查数字人是否已存在
  81 + if human_id in self.digital_humans:
  82 + await session.send_message({
  83 + "type": "register_digital_human_response",
  84 + "data": {
  85 + "status": "error",
  86 + "message": f"数字人已存在: {human_id}"
  87 + }
  88 + })
  89 + return
  90 +
  91 + # 注册数字人
  92 + human_info = {
  93 + 'human_id': human_id,
  94 + 'name': human_name,
  95 + 'type': human_type,
  96 + 'capabilities': capabilities,
  97 + 'session_id': session.session_id,
  98 + 'status': 'online',
  99 + 'registered_at': asyncio.get_event_loop().time(),
  100 + 'last_activity': asyncio.get_event_loop().time()
  101 + }
  102 +
  103 + self.digital_humans[human_id] = human_info
  104 + self.human_sessions[session.session_id] = human_id
  105 +
  106 + # 记录会话注册的数字人
  107 + if session.session_id not in self.session_humans:
  108 + self.session_humans[session.session_id] = []
  109 + self.session_humans[session.session_id].append(human_id)
  110 +
  111 + # 发送注册成功响应
  112 + await session.send_message({
  113 + "type": "register_digital_human_response",
  114 + "data": {
  115 + "status": "success",
  116 + "human_id": human_id,
  117 + "message": "数字人注册成功"
  118 + }
  119 + })
  120 +
  121 + # 通知其他会话新数字人上线
  122 + await self.broadcast_to_all('digital_human_online', {
  123 + 'human_id': human_id,
  124 + 'name': human_name,
  125 + 'type': human_type,
  126 + 'capabilities': capabilities
  127 + }, metadata={'exclude_session': session.session_id})
  128 +
  129 + logger.info(f'数字人已注册: {human_id} ({human_name})')
  130 +
  131 + async def _handle_unregister_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  132 + """处理数字人注销"""
  133 + session = self.manager.get_session(websocket)
  134 + if not session:
  135 + return
  136 +
  137 + human_id = data.get('human_id')
  138 +
  139 + if not human_id:
  140 + await session.send_message({
  141 + "type": "error",
  142 + "data": {"message": "缺少human_id"}
  143 + })
  144 + return
  145 +
  146 + # 检查数字人是否存在且属于当前会话
  147 + if human_id not in self.digital_humans:
  148 + await session.send_message({
  149 + "type": "unregister_digital_human_response",
  150 + "data": {
  151 + "status": "error",
  152 + "message": f"数字人不存在: {human_id}"
  153 + }
  154 + })
  155 + return
  156 +
  157 + human_info = self.digital_humans[human_id]
  158 + if human_info['session_id'] != session.session_id:
  159 + await session.send_message({
  160 + "type": "unregister_digital_human_response",
  161 + "data": {
  162 + "status": "error",
  163 + "message": "无权注销该数字人"
  164 + }
  165 + })
  166 + return
  167 +
  168 + # 注销数字人
  169 + del self.digital_humans[human_id]
  170 + if session.session_id in self.human_sessions:
  171 + del self.human_sessions[session.session_id]
  172 +
  173 + if session.session_id in self.session_humans:
  174 + self.session_humans[session.session_id].remove(human_id)
  175 + if not self.session_humans[session.session_id]:
  176 + del self.session_humans[session.session_id]
  177 +
  178 + # 发送注销成功响应
  179 + await session.send_message({
  180 + "type": "unregister_digital_human_response",
  181 + "data": {
  182 + "status": "success",
  183 + "human_id": human_id,
  184 + "message": "数字人注销成功"
  185 + }
  186 + })
  187 +
  188 + # 通知其他会话数字人已离线
  189 + await self.broadcast_to_all('digital_human_offline', {
  190 + 'human_id': human_id,
  191 + 'name': human_info.get('name', ''),
  192 + 'reason': 'manual_unregister'
  193 + }, metadata={'exclude_session': session.session_id})
  194 +
  195 + logger.info(f'数字人已注销: {human_id}')
  196 +
  197 + async def _handle_digital_human_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  198 + """处理数字人状态更新"""
  199 + session = self.manager.get_session(websocket)
  200 + if not session:
  201 + return
  202 +
  203 + human_id = data.get('human_id')
  204 + status = data.get('status')
  205 +
  206 + if not human_id or not status:
  207 + await session.send_message({
  208 + "type": "error",
  209 + "data": {"message": "缺少human_id或status"}
  210 + })
  211 + return
  212 +
  213 + if human_id not in self.digital_humans:
  214 + await session.send_message({
  215 + "type": "error",
  216 + "data": {"message": f"数字人不存在: {human_id}"}
  217 + })
  218 + return
  219 +
  220 + human_info = self.digital_humans[human_id]
  221 + if human_info['session_id'] != session.session_id:
  222 + await session.send_message({
  223 + "type": "error",
  224 + "data": {"message": "无权更新该数字人状态"}
  225 + })
  226 + return
  227 +
  228 + # 更新状态
  229 + old_status = human_info['status']
  230 + human_info['status'] = status
  231 + human_info['last_activity'] = asyncio.get_event_loop().time()
  232 +
  233 + # 广播状态变化
  234 + await self.broadcast_to_all('digital_human_status_changed', {
  235 + 'human_id': human_id,
  236 + 'old_status': old_status,
  237 + 'new_status': status,
  238 + 'name': human_info.get('name', '')
  239 + })
  240 +
  241 + logger.info(f'数字人状态更新: {human_id} {old_status} -> {status}')
  242 +
  243 + async def _handle_digital_human_action(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  244 + """处理数字人动作指令"""
  245 + session = self.manager.get_session(websocket)
  246 + if not session:
  247 + return
  248 +
  249 + human_id = data.get('human_id')
  250 + action = data.get('action')
  251 + params = data.get('params', {})
  252 +
  253 + if not human_id or not action:
  254 + await session.send_message({
  255 + "type": "error",
  256 + "data": {"message": "缺少human_id或action"}
  257 + })
  258 + return
  259 +
  260 + if human_id not in self.digital_humans:
  261 + await session.send_message({
  262 + "type": "error",
  263 + "data": {"message": f"数字人不存在: {human_id}"}
  264 + })
  265 + return
  266 +
  267 + human_info = self.digital_humans[human_id]
  268 +
  269 + # 更新活动时间
  270 + human_info['last_activity'] = asyncio.get_event_loop().time()
  271 +
  272 + # 转发动作指令到数字人会话
  273 + target_session_id = human_info['session_id']
  274 + await self.broadcast_to_session(target_session_id, 'digital_human_action_command', {
  275 + 'human_id': human_id,
  276 + 'action': action,
  277 + 'params': params,
  278 + 'from_session': session.session_id
  279 + })
  280 +
  281 + logger.info(f'数字人动作指令: {human_id} -> {action}')
  282 +
  283 + async def _handle_digital_human_speak(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  284 + """处理数字人说话指令"""
  285 + session = self.manager.get_session(websocket)
  286 + if not session:
  287 + return
  288 +
  289 + human_id = data.get('human_id')
  290 + text = data.get('text')
  291 + voice_config = data.get('voice_config', {})
  292 +
  293 + if not human_id or not text:
  294 + await session.send_message({
  295 + "type": "error",
  296 + "data": {"message": "缺少human_id或text"}
  297 + })
  298 + return
  299 +
  300 + if human_id not in self.digital_humans:
  301 + await session.send_message({
  302 + "type": "error",
  303 + "data": {"message": f"数字人不存在: {human_id}"}
  304 + })
  305 + return
  306 +
  307 + human_info = self.digital_humans[human_id]
  308 +
  309 + # 更新活动时间
  310 + human_info['last_activity'] = asyncio.get_event_loop().time()
  311 +
  312 + # 转发说话指令到数字人会话
  313 + target_session_id = human_info['session_id']
  314 + await self.broadcast_to_session(target_session_id, 'digital_human_speak_command', {
  315 + 'human_id': human_id,
  316 + 'text': text,
  317 + 'voice_config': voice_config,
  318 + 'from_session': session.session_id
  319 + })
  320 +
  321 + # 广播数字人说话事件
  322 + await self.broadcast_to_all('digital_human_speaking', {
  323 + 'human_id': human_id,
  324 + 'name': human_info.get('name', ''),
  325 + 'text': text
  326 + }, metadata={'exclude_session': target_session_id})
  327 +
  328 + logger.info(f'数字人说话指令: {human_id} -> "{text}"')
  329 +
  330 + async def _handle_digital_human_emotion(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  331 + """处理数字人情感状态"""
  332 + session = self.manager.get_session(websocket)
  333 + if not session:
  334 + return
  335 +
  336 + human_id = data.get('human_id')
  337 + emotion = data.get('emotion')
  338 + intensity = data.get('intensity', 1.0)
  339 +
  340 + if not human_id or not emotion:
  341 + await session.send_message({
  342 + "type": "error",
  343 + "data": {"message": "缺少human_id或emotion"}
  344 + })
  345 + return
  346 +
  347 + if human_id not in self.digital_humans:
  348 + await session.send_message({
  349 + "type": "error",
  350 + "data": {"message": f"数字人不存在: {human_id}"}
  351 + })
  352 + return
  353 +
  354 + human_info = self.digital_humans[human_id]
  355 +
  356 + # 更新情感状态
  357 + human_info['emotion'] = emotion
  358 + human_info['emotion_intensity'] = intensity
  359 + human_info['last_activity'] = asyncio.get_event_loop().time()
  360 +
  361 + # 广播情感变化
  362 + await self.broadcast_to_all('digital_human_emotion_changed', {
  363 + 'human_id': human_id,
  364 + 'name': human_info.get('name', ''),
  365 + 'emotion': emotion,
  366 + 'intensity': intensity
  367 + })
  368 +
  369 + logger.info(f'数字人情感更新: {human_id} -> {emotion} ({intensity})')
  370 +
  371 + async def _handle_get_digital_humans(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  372 + """处理获取数字人列表请求"""
  373 + session = self.manager.get_session(websocket)
  374 + if not session:
  375 + return
  376 +
  377 + # 返回所有在线数字人信息
  378 + humans_list = []
  379 + for human_id, human_info in self.digital_humans.items():
  380 + humans_list.append({
  381 + 'human_id': human_id,
  382 + 'name': human_info.get('name', ''),
  383 + 'type': human_info.get('type', 'default'),
  384 + 'status': human_info.get('status', 'unknown'),
  385 + 'capabilities': human_info.get('capabilities', []),
  386 + 'emotion': human_info.get('emotion', 'neutral'),
  387 + 'emotion_intensity': human_info.get('emotion_intensity', 1.0)
  388 + })
  389 +
  390 + await session.send_message({
  391 + "type": "digital_humans_list",
  392 + "data": {
  393 + "humans": humans_list,
  394 + "total": len(humans_list)
  395 + }
  396 + })
  397 +
  398 + def get_digital_human_stats(self) -> Dict[str, Any]:
  399 + """获取数字人统计信息"""
  400 + online_count = len([h for h in self.digital_humans.values() if h.get('status') == 'online'])
  401 +
  402 + return {
  403 + "total_digital_humans": len(self.digital_humans),
  404 + "online_digital_humans": online_count,
  405 + "active_sessions": len(self.session_humans),
  406 + "human_types": list(set(h.get('type', 'default') for h in self.digital_humans.values()))
  407 + }
  408 +
  409 + async def send_to_digital_human(self, human_id: str, message_type: str, content: Any):
  410 + """向指定数字人发送消息"""
  411 + if human_id not in self.digital_humans:
  412 + logger.warning(f'数字人不存在: {human_id}')
  413 + return False
  414 +
  415 + human_info = self.digital_humans[human_id]
  416 + target_session_id = human_info['session_id']
  417 +
  418 + return await self.broadcast_to_session(target_session_id, message_type, content)
  419 +
  420 + async def broadcast_to_digital_humans(self, message_type: str, content: Any,
  421 + human_filter: Optional[callable] = None):
  422 + """向数字人广播消息"""
  423 + sent_count = 0
  424 +
  425 + for human_id, human_info in self.digital_humans.items():
  426 + if human_filter and not human_filter(human_info):
  427 + continue
  428 +
  429 + target_session_id = human_info['session_id']
  430 + success = await self.broadcast_to_session(target_session_id, message_type, content)
  431 + if success:
  432 + sent_count += 1
  433 +
  434 + return sent_count
  435 +
  436 +
  437 +# 创建数字人服务实例
  438 +digital_human_service = DigitalHumanWebSocketService()
  439 +
  440 +
  441 +def get_digital_human_service() -> DigitalHumanWebSocketService:
  442 + """获取数字人服务实例"""
  443 + return digital_human_service
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +统一WebSocket管理模块
  5 +提供统一的WebSocket连接管理、消息推送和事件处理功能
  6 +"""
  7 +
  8 +import json
  9 +import time
  10 +import asyncio
  11 +import weakref
  12 +from typing import Dict, Any, Optional, Callable, List, Set
  13 +from threading import Lock
  14 +from aiohttp import web, WSMsgType
  15 +from logger import logger
  16 +
  17 +
  18 +class WebSocketSession:
  19 + """WebSocket会话管理类"""
  20 +
  21 + def __init__(self, session_id: str, websocket: web.WebSocketResponse):
  22 + self.session_id = session_id
  23 + self.websocket = websocket
  24 + self.created_at = time.time()
  25 + self.last_ping = time.time()
  26 + self.metadata = {}
  27 +
  28 + def __eq__(self, other):
  29 + """基于websocket对象判断会话是否相等"""
  30 + if not isinstance(other, WebSocketSession):
  31 + return False
  32 + return self.websocket is other.websocket
  33 +
  34 + def __hash__(self):
  35 + """基于websocket对象的id生成哈希值"""
  36 + return hash(id(self.websocket))
  37 +
  38 + def is_alive(self) -> bool:
  39 + """检查连接是否存活"""
  40 + return not self.websocket.closed
  41 +
  42 + async def send_message(self, message: Dict[str, Any]) -> bool:
  43 + """发送消息到WebSocket客户端"""
  44 + try:
  45 + if not self.is_alive():
  46 + return False
  47 + await self.websocket.send_str(json.dumps(message))
  48 + return True
  49 + except ConnectionResetError:
  50 + logger.warning(f'[Session:{self.session_id}] 客户端连接已重置')
  51 + return False
  52 + except ConnectionAbortedError:
  53 + logger.warning(f'[Session:{self.session_id}] 客户端连接已中止')
  54 + return False
  55 + except Exception as e:
  56 + logger.error(f'[Session:{self.session_id}] 发送消息失败: {e}')
  57 + return False
  58 +
  59 + def update_ping(self):
  60 + """更新心跳时间"""
  61 + self.last_ping = time.time()
  62 +
  63 + def set_metadata(self, key: str, value: Any):
  64 + """设置会话元数据"""
  65 + self.metadata[key] = value
  66 +
  67 + def get_metadata(self, key: str, default=None):
  68 + """获取会话元数据"""
  69 + return self.metadata.get(key, default)
  70 +
  71 + async def close(self):
  72 + """关闭WebSocket连接"""
  73 + try:
  74 + if not self.websocket.closed:
  75 + await self.websocket.close()
  76 + logger.info(f'[Session:{self.session_id}] WebSocket连接已关闭')
  77 + except ConnectionResetError:
  78 + logger.warning(f'[Session:{self.session_id}] 连接已被远程主机重置,无需关闭')
  79 + except ConnectionAbortedError:
  80 + logger.warning(f'[Session:{self.session_id}] 连接已被中止,无需关闭')
  81 + except Exception as e:
  82 + logger.error(f'[Session:{self.session_id}] 关闭WebSocket连接失败: {e}')
  83 +
  84 +
  85 +class UnifiedWebSocketManager:
  86 + """统一WebSocket管理器"""
  87 +
  88 + def __init__(self):
  89 + self._sessions: Dict[str, Set[WebSocketSession]] = {} # session_id -> WebSocketSession集合
  90 + self._websockets: Dict[web.WebSocketResponse, WebSocketSession] = {} # websocket -> session映射
  91 + self._message_handlers: Dict[str, Callable] = {} # 消息类型处理器
  92 + self._event_handlers: Dict[str, List[Callable]] = {} # 事件处理器
  93 + self._lock = Lock()
  94 +
  95 + # 注册默认消息处理器
  96 + self.register_message_handler('ping', self._handle_ping)
  97 + self.register_message_handler('login', self._handle_login)
  98 +
  99 + def register_message_handler(self, message_type: str, handler: Callable):
  100 + """注册消息处理器"""
  101 + self._message_handlers[message_type] = handler
  102 + logger.info(f'注册消息处理器: {message_type}')
  103 +
  104 + def register_event_handler(self, event_type: str, handler: Callable):
  105 + """注册事件处理器"""
  106 + if event_type not in self._event_handlers:
  107 + self._event_handlers[event_type] = []
  108 + self._event_handlers[event_type].append(handler)
  109 + logger.info(f'注册事件处理器: {event_type}')
  110 +
  111 + async def _emit_event(self, event_type: str, **kwargs):
  112 + """触发事件"""
  113 + if event_type in self._event_handlers:
  114 + for handler in self._event_handlers[event_type]:
  115 + try:
  116 + if asyncio.iscoroutinefunction(handler):
  117 + await handler(**kwargs)
  118 + else:
  119 + handler(**kwargs)
  120 + except Exception as e:
  121 + logger.error(f'事件处理器执行失败 {event_type}: {e}')
  122 +
  123 + def add_session(self, session_id: str, websocket: web.WebSocketResponse) -> WebSocketSession:
  124 + """添加WebSocket会话"""
  125 + with self._lock:
  126 + # 检查是否已存在相同的websocket连接
  127 + if websocket in self._websockets:
  128 + existing_session = self._websockets[websocket]
  129 + logger.warning(f'[Session:{session_id}] WebSocket连接已存在 (WebSocket={id(websocket)}, 原Session={existing_session.session_id})')
  130 + return existing_session
  131 +
  132 + session = WebSocketSession(session_id, websocket)
  133 +
  134 + # 初始化会话集合
  135 + if session_id not in self._sessions:
  136 + self._sessions[session_id] = set()
  137 +
  138 + # 检查Set添加前后的大小变化
  139 + before_count = len(self._sessions[session_id])
  140 + self._sessions[session_id].add(session)
  141 + after_count = len(self._sessions[session_id])
  142 +
  143 + self._websockets[websocket] = session
  144 +
  145 + logger.info(f'[Session:{session_id}] 添加WebSocket会话 (WebSocket={id(websocket)}), 连接数变化: {before_count} -> {after_count}')
  146 +
  147 + # 如果Set大小没有变化,说明可能存在重复
  148 + if before_count == after_count:
  149 + logger.warning(f'[Session:{session_id}] 检测到可能的重复会话添加!Set大小未变化')
  150 +
  151 + return session
  152 +
  153 + def remove_session(self, websocket: web.WebSocketResponse):
  154 + """移除WebSocket会话"""
  155 + with self._lock:
  156 + if websocket in self._websockets:
  157 + session = self._websockets[websocket]
  158 + session_id = session.session_id
  159 +
  160 + # 从会话集合中移除
  161 + if session_id in self._sessions:
  162 + self._sessions[session_id].discard(session)
  163 + if not self._sessions[session_id]: # 如果集合为空,删除键
  164 + del self._sessions[session_id]
  165 +
  166 + # 从websocket映射中移除
  167 + del self._websockets[websocket]
  168 +
  169 + logger.info(f'[Session:{session_id}] 移除WebSocket会话')
  170 + return session
  171 + return None
  172 +
  173 + def get_session(self, websocket: web.WebSocketResponse) -> Optional[WebSocketSession]:
  174 + """获取WebSocket会话"""
  175 + return self._websockets.get(websocket)
  176 +
  177 + def get_sessions_by_id(self, session_id: str) -> Set[WebSocketSession]:
  178 + """根据会话ID获取所有WebSocket会话"""
  179 + with self._lock:
  180 + # 尝试使用原始session_id查找
  181 + sessions = self._sessions.get(session_id, set())
  182 + if sessions:
  183 + return sessions.copy()
  184 +
  185 + # 如果是字符串类型但存储的是整数类型,尝试转换
  186 + if isinstance(session_id, str) and session_id.isdigit():
  187 + int_session_id = int(session_id)
  188 + sessions = self._sessions.get(int_session_id, set())
  189 + if sessions:
  190 + return sessions.copy()
  191 +
  192 + # 如果是整数类型但存储的是字符串类型,尝试转换
  193 + elif isinstance(session_id, int):
  194 + str_session_id = str(session_id)
  195 + sessions = self._sessions.get(str_session_id, set())
  196 + if sessions:
  197 + return sessions.copy()
  198 +
  199 + return set()
  200 +
  201 + def _update_session_id(self, websocket: web.WebSocketResponse, old_session_id: str, new_session_id: str):
  202 + """更新WebSocket会话的session_id"""
  203 + with self._lock:
  204 + if websocket in self._websockets:
  205 + session = self._websockets[websocket]
  206 +
  207 + # 从旧的session_id集合中移除
  208 + if old_session_id in self._sessions:
  209 + self._sessions[old_session_id].discard(session)
  210 + if not self._sessions[old_session_id]: # 如果集合为空,删除键
  211 + del self._sessions[old_session_id]
  212 +
  213 + # 更新session的session_id
  214 + session.session_id = new_session_id
  215 +
  216 + # 添加到新的session_id集合
  217 + if new_session_id not in self._sessions:
  218 + self._sessions[new_session_id] = set()
  219 + self._sessions[new_session_id].add(session)
  220 +
  221 + logger.info(f'[Session] 更新会话ID: {old_session_id} -> {new_session_id}')
  222 + return True
  223 + return False
  224 +
  225 + async def broadcast_raw_message_to_session(self, session_id: str, message: Dict,source: str = "原数据") -> int:
  226 + """直接广播原始消息到指定会话的所有WebSocket连接"""
  227 + # 确保session_id为字符串类型,保持一致性
  228 + # 确保session_id为字符串类型,保持一致性
  229 + if isinstance(session_id, int):
  230 + session_id = str(session_id)
  231 + elif not isinstance(session_id, str):
  232 + session_id = str(session_id)
  233 +
  234 + sessions = self.get_sessions_by_id(session_id)
  235 + if not sessions:
  236 + logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
  237 + return 0
  238 +
  239 + # 详细调试日志:显示会话详情
  240 + logger.info(f'[Session:{session_id}] 开始广播消息,找到 {len(sessions)} 个连接')
  241 + for i, session in enumerate(sessions):
  242 + logger.info(f'[Session:{session_id}] 连接{i+1}: WebSocket={id(session.websocket)}, 创建时间={session.created_at}, 存活状态={session.is_alive()}')
  243 +
  244 + success_count = 0
  245 + failed_sessions = []
  246 +
  247 + for i, session in enumerate(sessions):
  248 + logger.info(f'[Session:{session_id}] 正在向连接{i+1}发送消息 (WebSocket={id(session.websocket)})')
  249 + if await session.send_message(message):
  250 + success_count += 1
  251 + logger.info(f'[Session:{session_id}] 连接{i+1}发送成功')
  252 + else:
  253 + failed_sessions.append(session)
  254 + logger.warning(f'[Session:{session_id}] 连接{i+1}发送失败')
  255 +
  256 + # 清理失败的连接
  257 + for session in failed_sessions:
  258 + self.remove_session(session.websocket)
  259 +
  260 + logger.info(f'[Session:{session_id}] 广播原始消息完成: 成功{success_count}/总计{len(sessions)}, 失败{len(failed_sessions)}')
  261 + return success_count
  262 +
  263 + async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
  264 + source: str = "系统", metadata: Dict = None) -> int:
  265 + """向指定会话的所有WebSocket连接广播消息"""
  266 + # 确保session_id为字符串类型,保持一致性
  267 + if isinstance(session_id, int):
  268 + session_id = str(session_id)
  269 + elif not isinstance(session_id, str):
  270 + session_id = str(session_id)
  271 +
  272 + sessions = self.get_sessions_by_id(session_id)
  273 + if not sessions:
  274 + logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
  275 + return 0
  276 +
  277 + message = {
  278 + "type": message_type,
  279 + "session_id": session_id,
  280 + "content": content,
  281 + "source": source,
  282 + "timestamp": time.time(),
  283 + **(metadata or {})
  284 + }
  285 +
  286 + success_count = 0
  287 + failed_sessions = []
  288 +
  289 + for session in sessions:
  290 + if await session.send_message(message):
  291 + success_count += 1
  292 + else:
  293 + failed_sessions.append(session)
  294 +
  295 + # 清理失败的连接
  296 + for session in failed_sessions:
  297 + self.remove_session(session.websocket)
  298 +
  299 + logger.info(f'[Session:{session_id}] 广播消息成功: {success_count}/{len(sessions)}')
  300 + return success_count
  301 +
  302 + async def broadcast_to_all(self, message_type: str, content: Any,
  303 + source: str = "系统", metadata: Dict = None) -> int:
  304 + """向所有WebSocket连接广播消息"""
  305 + total_sent = 0
  306 + with self._lock:
  307 + session_ids = list(self._sessions.keys())
  308 +
  309 + for session_id in session_ids:
  310 + sent = await self.broadcast_to_session(session_id, message_type, content, source, metadata)
  311 + total_sent += sent
  312 +
  313 + logger.info(f'全局广播消息完成,总发送数: {total_sent}')
  314 + return total_sent
  315 +
  316 + def get_session_count(self) -> int:
  317 + """获取会话总数"""
  318 + with self._lock:
  319 + return len(self._sessions)
  320 +
  321 + def get_connection_count(self) -> int:
  322 + """获取连接总数"""
  323 + with self._lock:
  324 + return len(self._websockets)
  325 +
  326 + def get_session_stats(self) -> Dict[str, Any]:
  327 + """获取会话统计信息"""
  328 + with self._lock:
  329 + stats = {
  330 + "total_sessions": len(self._sessions),
  331 + "total_connections": len(self._websockets),
  332 + "session_details": {}
  333 + }
  334 +
  335 + for session_id, sessions in self._sessions.items():
  336 + stats["session_details"][session_id] = {
  337 + "connection_count": len(sessions),
  338 + "connections": [
  339 + {
  340 + "created_at": session.created_at,
  341 + "last_ping": session.last_ping,
  342 + "is_alive": session.is_alive(),
  343 + "metadata": session.metadata
  344 + } for session in sessions
  345 + ]
  346 + }
  347 +
  348 + return stats
  349 +
  350 + async def _handle_ping(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  351 + """处理心跳消息"""
  352 + session = self.get_session(websocket)
  353 + if session:
  354 + session.update_ping()
  355 + await session.send_message({"type": "pong", "timestamp": time.time()})
  356 +
  357 + async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  358 + """处理登录消息"""
  359 + session_id = data.get('session_id', data.get('sessionid', str(int(time.time()))))
  360 +
  361 + # 确保session_id为字符串类型,避免类型不一致问题
  362 + if isinstance(session_id, int):
  363 + session_id = str(session_id)
  364 + elif not isinstance(session_id, str):
  365 + session_id = str(session_id)
  366 +
  367 + # 添加会话
  368 + session = self.add_session(session_id, websocket)
  369 +
  370 + # 触发连接事件
  371 + await self._emit_event('session_connected', session=session, data=data)
  372 +
  373 + # 发送登录确认
  374 + await session.send_message({
  375 + "type": "login_success",
  376 + "data": {
  377 + "session_id": session_id,
  378 + "message": "WebSocket连接成功",
  379 + "timestamp": time.time()
  380 + }
  381 + })
  382 +
  383 + async def handle_websocket_message(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  384 + """处理WebSocket消息"""
  385 + message_type = data.get('type')
  386 +
  387 + if message_type in self._message_handlers:
  388 + try:
  389 + await self._message_handlers[message_type](websocket, data)
  390 + except Exception as e:
  391 + logger.error(f'消息处理器执行失败 {message_type}: {e}')
  392 + session = self.get_session(websocket)
  393 + if session:
  394 + await session.send_message({
  395 + "type": "error",
  396 + "data": {
  397 + "message": f"消息处理失败: {str(e)}",
  398 + "original_type": message_type
  399 + }
  400 + })
  401 + else:
  402 + logger.warning(f'未知消息类型: {message_type}')
  403 +
  404 + async def websocket_handler(self, request) -> web.WebSocketResponse:
  405 + """WebSocket连接处理器"""
  406 + ws = web.WebSocketResponse()
  407 + await ws.prepare(request)
  408 +
  409 + logger.info('新的WebSocket连接建立')
  410 +
  411 + try:
  412 + async for msg in ws:
  413 + if msg.type == WSMsgType.TEXT:
  414 + try:
  415 + data = json.loads(msg.data)
  416 + await self.handle_websocket_message(ws, data)
  417 + except json.JSONDecodeError:
  418 + logger.error('收到无效的JSON数据')
  419 + except Exception as e:
  420 + logger.error(f'处理WebSocket消息时出错: {e}')
  421 +
  422 + elif msg.type == WSMsgType.ERROR:
  423 + logger.error(f'WebSocket错误: {ws.exception()}')
  424 + break
  425 + elif msg.type == WSMsgType.CLOSE:
  426 + logger.info('WebSocket连接正常关闭')
  427 + break
  428 +
  429 + except ConnectionResetError:
  430 + logger.warning('WebSocket连接被远程主机重置')
  431 + except ConnectionAbortedError:
  432 + logger.warning('WebSocket连接被中止')
  433 + except Exception as e:
  434 + logger.error(f'WebSocket连接错误: {e}')
  435 + finally:
  436 + # 清理会话
  437 + session = self.remove_session(ws)
  438 + if session:
  439 + await self._emit_event('session_disconnected', session=session)
  440 + logger.info('WebSocket连接已关闭')
  441 +
  442 + return ws
  443 +
  444 + def get_expired_sessions(self, timeout: int = 60) -> List[WebSocketSession]:
  445 + """获取过期的会话列表"""
  446 + current_time = time.time()
  447 + expired_sessions = []
  448 +
  449 + with self._lock:
  450 + for session in self._websockets.values():
  451 + if current_time - session.last_ping > timeout:
  452 + expired_sessions.append(session)
  453 +
  454 + return expired_sessions
  455 +
  456 + async def cleanup_dead_connections(self):
  457 + """清理死连接"""
  458 + dead_websockets = []
  459 +
  460 + with self._lock:
  461 + for websocket, session in self._websockets.items():
  462 + if not session.is_alive():
  463 + dead_websockets.append(websocket)
  464 +
  465 + for websocket in dead_websockets:
  466 + self.remove_session(websocket)
  467 +
  468 + if dead_websockets:
  469 + logger.info(f'清理了 {len(dead_websockets)} 个死连接')
  470 +
  471 + return len(dead_websockets)
  472 +
  473 +
  474 +# 全局统一WebSocket管理器实例
  475 +_unified_manager = UnifiedWebSocketManager()
  476 +
  477 +
  478 +def get_unified_manager() -> UnifiedWebSocketManager:
  479 + """获取统一WebSocket管理器实例"""
  480 + return _unified_manager
  481 +
  482 +
  483 +# 兼容性接口,保持与原有代码的兼容
  484 +async def broadcast_message_to_session(session_id: str, message_type: str, content: str,
  485 + source: str = "数字人回复", model_info: str = None,
  486 + request_source: str = "页面"):
  487 + """兼容性接口:向指定会话广播消息"""
  488 + metadata = {}
  489 + if model_info:
  490 + metadata['model_info'] = model_info
  491 + if request_source:
  492 + metadata['request_source'] = request_source
  493 +
  494 + return await _unified_manager.broadcast_to_session(
  495 + session_id, message_type, content, source, metadata
  496 + )
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +WebSocket路由管理器
  5 +统一管理所有WebSocket服务的路由和初始化
  6 +"""
  7 +
  8 +import asyncio
  9 +import time
  10 +from typing import Dict, Any, Optional
  11 +from aiohttp import web, WSMsgType
  12 +import json
  13 +from logger import logger
  14 +from .unified_websocket_manager import get_unified_manager
  15 +from .websocket_service_base import get_service_registry
  16 +from .asr_websocket_service import get_asr_service
  17 +from .digital_human_websocket_service import get_digital_human_service
  18 +
  19 +
  20 +class WebSocketRouter:
  21 + """WebSocket路由管理器"""
  22 +
  23 + def __init__(self):
  24 + self.manager = get_unified_manager()
  25 + self.service_registry = get_service_registry()
  26 + self.is_initialized = False
  27 +
  28 + async def initialize(self):
  29 + """初始化路由器和所有服务"""
  30 + if self.is_initialized:
  31 + return
  32 +
  33 + logger.info('初始化WebSocket路由器...')
  34 +
  35 + # 注册所有服务
  36 + await self._register_services()
  37 +
  38 + # 初始化所有服务
  39 + await self.service_registry.initialize_all()
  40 +
  41 + self.is_initialized = True
  42 + logger.info('WebSocket路由器初始化完成')
  43 +
  44 + async def shutdown(self):
  45 + """关闭路由器和所有服务"""
  46 + if not self.is_initialized:
  47 + return
  48 +
  49 + logger.info('关闭WebSocket路由器...')
  50 +
  51 + # 关闭所有服务
  52 + await self.service_registry.shutdown_all()
  53 +
  54 + # 关闭管理器
  55 + await self.manager.shutdown()
  56 +
  57 + self.is_initialized = False
  58 + logger.info('WebSocket路由器已关闭')
  59 +
  60 + async def _register_services(self):
  61 + """注册所有WebSocket服务"""
  62 + logger.info('注册WebSocket服务...')
  63 +
  64 + # 注册ASR服务
  65 + asr_service = get_asr_service()
  66 + self.service_registry.register_service(asr_service)
  67 +
  68 + # 注册数字人服务
  69 + digital_human_service = get_digital_human_service()
  70 + self.service_registry.register_service(digital_human_service)
  71 +
  72 + # 注册WSA服务
  73 + from .wsa_websocket_service import WSAWebSocketService, initialize_wsa_service
  74 + wsa_service = WSAWebSocketService(self.manager)
  75 + self.service_registry.register_service(wsa_service)
  76 +
  77 + # 初始化WSA兼容性接口
  78 + initialize_wsa_service(wsa_service)
  79 +
  80 + logger.info(f'已注册 {len(self.service_registry.list_services())} 个WebSocket服务')
  81 +
  82 + async def websocket_handler(self, request: web.Request) -> web.WebSocketResponse:
  83 + """统一的WebSocket处理器"""
  84 + ws = web.WebSocketResponse()
  85 + await ws.prepare(request)
  86 +
  87 + # 创建会话ID
  88 + session_id = request.headers.get('X-Session-ID', str(int(time.time())))
  89 + session = self.manager.add_session(session_id, ws)
  90 + logger.info(f'WebSocket连接建立: {session.session_id}')
  91 +
  92 + try:
  93 + async for msg in ws:
  94 + if msg.type == WSMsgType.TEXT:
  95 + try:
  96 + data = json.loads(msg.data)
  97 + await self._handle_message(ws, data)
  98 + except json.JSONDecodeError as e:
  99 + logger.error(f'JSON解析失败: {e}')
  100 + await session.send_message({
  101 + "type": "error",
  102 + "data": {"message": "消息格式错误"}
  103 + })
  104 + except Exception as e:
  105 + logger.error(f'消息处理失败: {e}')
  106 + await session.send_message({
  107 + "type": "error",
  108 + "data": {"message": f"处理失败: {str(e)}"}
  109 + })
  110 + elif msg.type == WSMsgType.ERROR:
  111 + logger.error(f'WebSocket错误: {ws.exception()}')
  112 + break
  113 + elif msg.type == WSMsgType.CLOSE:
  114 + logger.info(f'WebSocket连接关闭: {session.session_id}')
  115 + break
  116 +
  117 + except ConnectionResetError:
  118 + logger.warning(f'WebSocket连接被远程主机重置: {session.session_id}')
  119 + except ConnectionAbortedError:
  120 + logger.warning(f'WebSocket连接被中止: {session.session_id}')
  121 + except Exception as e:
  122 + logger.error(f'WebSocket处理异常: {e}')
  123 + finally:
  124 + # 清理会话
  125 + self.manager.remove_session(ws)
  126 +
  127 + return ws
  128 +
  129 + async def _handle_message(self, ws: web.WebSocketResponse, data: Dict[str, Any]):
  130 + """处理WebSocket消息"""
  131 + message_type = data.get('type')
  132 + if not message_type:
  133 + session = self.manager.get_session(ws)
  134 + if session:
  135 + await session.send_message({
  136 + "type": "error",
  137 + "data": {"message": "缺少消息类型"}
  138 + })
  139 + return
  140 +
  141 + # 通过管理器处理消息
  142 + await self.manager.handle_websocket_message(ws, data)
  143 +
  144 + def get_router_stats(self) -> Dict[str, Any]:
  145 + """获取路由器统计信息"""
  146 + stats = {
  147 + "initialized": self.is_initialized,
  148 + "manager_stats": self.manager.get_session_stats(),
  149 + "service_stats": self.service_registry.get_all_stats()
  150 + }
  151 +
  152 + # 添加各服务的详细统计
  153 + asr_service = self.service_registry.get_service("asr_service")
  154 + if asr_service:
  155 + stats["asr_stats"] = asr_service.get_asr_stats()
  156 +
  157 + digital_human_service = self.service_registry.get_service("digital_human_service")
  158 + if digital_human_service:
  159 + stats["digital_human_stats"] = digital_human_service.get_digital_human_stats()
  160 +
  161 + return stats
  162 +
  163 + def setup_routes(self, app: web.Application, path: str = '/ws'):
  164 + """设置WebSocket路由"""
  165 + app.router.add_get(path, self.websocket_handler)
  166 + logger.info(f'WebSocket路由已设置: {path}')
  167 +
  168 + async def broadcast_system_message(self, message: str, level: str = 'info'):
  169 + """广播系统消息"""
  170 + await self.manager.broadcast_to_all('system_message', {
  171 + 'message': message,
  172 + 'level': level,
  173 + 'timestamp': asyncio.get_event_loop().time()
  174 + }, source='system')
  175 +
  176 + async def send_to_session(self, session_id: str, message_type: str, content: Any):
  177 + """向指定会话发送消息"""
  178 + return await self.manager.broadcast_to_session(session_id, message_type, content, source='router')
  179 +
  180 + async def send_raw_to_session(self, session_id: str, message: Dict):
  181 + """向指定会话发送消息"""
  182 + return await self.manager.broadcast_raw_message_to_session(str(session_id), message)
  183 +
  184 +
  185 +
  186 + async def send_to_digital_human(self, human_id: str, message_type: str, content: Any):
  187 + """向指定数字人发送消息"""
  188 + digital_human_service = self.service_registry.get_service("digital_human_service")
  189 + if digital_human_service:
  190 + return await digital_human_service.send_to_digital_human(human_id, message_type, content)
  191 + return False
  192 +
  193 + async def get_asr_stats(self) -> Optional[Dict[str, Any]]:
  194 + """获取ASR统计信息"""
  195 + asr_service = self.service_registry.get_service("asr_service")
  196 + if asr_service:
  197 + return asr_service.get_asr_stats()
  198 + return None
  199 +
  200 + async def get_digital_human_stats(self) -> Optional[Dict[str, Any]]:
  201 + """获取数字人统计信息"""
  202 + digital_human_service = self.service_registry.get_service("digital_human_service")
  203 + if digital_human_service:
  204 + return digital_human_service.get_digital_human_stats()
  205 + return None
  206 +
  207 +
  208 +# 全局路由器实例
  209 +_websocket_router = None
  210 +
  211 +
  212 +def get_websocket_router() -> WebSocketRouter:
  213 + """获取WebSocket路由器实例"""
  214 + global _websocket_router
  215 + if _websocket_router is None:
  216 + _websocket_router = WebSocketRouter()
  217 + return _websocket_router
  218 +
  219 +
  220 +async def initialize_websocket_router():
  221 + """初始化WebSocket路由器"""
  222 + router = get_websocket_router()
  223 + await router.initialize()
  224 + return router
  225 +
  226 +
  227 +async def shutdown_websocket_router():
  228 + """关闭WebSocket路由器"""
  229 + global _websocket_router
  230 + if _websocket_router:
  231 + await _websocket_router.shutdown()
  232 + _websocket_router = None
  233 +
  234 +
  235 +def setup_websocket_routes(app: web.Application, path: str = '/ws'):
  236 + """设置WebSocket路由(便捷函数)"""
  237 + router = get_websocket_router()
  238 + router.setup_routes(app, path)
  239 + return router
  240 +
  241 +
  242 +# 兼容性接口
  243 +class WebSocketCompatibilityAPI:
  244 + """WebSocket兼容性API
  245 +
  246 + 为了保持与现有代码的兼容性,提供简化的接口
  247 + """
  248 +
  249 + def __init__(self):
  250 + self.router = get_websocket_router()
  251 +
  252 + async def broadcast_message_to_session(self, session_id: str, message: Dict[str, Any]):
  253 + """向指定会话广播消息(兼容app.py接口)"""
  254 + message_type = message.get('type', 'message')
  255 + content = message.get('data', message)
  256 + return await self.router.send_to_session(session_id, message_type, content)
  257 +
  258 + async def broadcast_to_all_sessions(self, message: Dict[str, Any]):
  259 + """向所有会话广播消息"""
  260 + message_type = message.get('type', 'message')
  261 + content = message.get('data', message)
  262 + return await self.router.manager.broadcast_to_all(message_type, content, source='compatibility')
  263 +
  264 + def get_active_sessions(self):
  265 + """获取活跃会话列表"""
  266 + return list(self.router.manager._sessions.keys())
  267 +
  268 + def get_session_count(self):
  269 + """获取会话数量"""
  270 + return len(self.router.manager._sessions)
  271 +
  272 + async def send_asr_result(self, session_id: str, result: Dict[str, Any]):
  273 + """发送ASR结果(兼容app.py接口)"""
  274 + return await self.router.send_to_session(session_id, 'asr_result', result)
  275 +
  276 +
  277 +# 全局兼容性API实例
  278 +_compatibility_api = None
  279 +
  280 +
  281 +def get_websocket_compatibility_api() -> WebSocketCompatibilityAPI:
  282 + """获取WebSocket兼容性API实例"""
  283 + global _compatibility_api
  284 + if _compatibility_api is None:
  285 + _compatibility_api = WebSocketCompatibilityAPI()
  286 + return _compatibility_api
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-07-15 14:41:21
  4 +WebSocket服务抽象基类
  5 +为不同类型的WebSocket服务提供统一接口和生命周期管理
  6 +"""
  7 +
  8 +import asyncio
  9 +from abc import ABC, abstractmethod
  10 +from typing import Dict, Any, Optional, List
  11 +from aiohttp import web
  12 +from logger import logger
  13 +from .unified_websocket_manager import get_unified_manager, WebSocketSession
  14 +
  15 +
  16 +class WebSocketServiceBase(ABC):
  17 + """WebSocket服务抽象基类"""
  18 +
  19 + def __init__(self, service_name: str):
  20 + self.service_name = service_name
  21 + self.manager = get_unified_manager()
  22 + self.is_initialized = False
  23 + self._background_tasks: List[asyncio.Task] = []
  24 +
  25 + async def initialize(self):
  26 + """初始化服务"""
  27 + if self.is_initialized:
  28 + return
  29 +
  30 + logger.info(f'初始化WebSocket服务: {self.service_name}')
  31 +
  32 + # 注册消息处理器
  33 + await self._register_message_handlers()
  34 +
  35 + # 注册事件处理器
  36 + await self._register_event_handlers()
  37 +
  38 + # 启动后台任务
  39 + await self._start_background_tasks()
  40 +
  41 + self.is_initialized = True
  42 + logger.info(f'WebSocket服务初始化完成: {self.service_name}')
  43 +
  44 + async def shutdown(self):
  45 + """关闭服务"""
  46 + if not self.is_initialized:
  47 + return
  48 +
  49 + logger.info(f'关闭WebSocket服务: {self.service_name}')
  50 +
  51 + # 停止后台任务
  52 + for task in self._background_tasks:
  53 + if not task.done():
  54 + task.cancel()
  55 + try:
  56 + await task
  57 + except asyncio.CancelledError:
  58 + pass
  59 +
  60 + self._background_tasks.clear()
  61 +
  62 + # 执行自定义清理
  63 + await self._cleanup()
  64 +
  65 + self.is_initialized = False
  66 + logger.info(f'WebSocket服务已关闭: {self.service_name}')
  67 +
  68 + @abstractmethod
  69 + async def _register_message_handlers(self):
  70 + """注册消息处理器(子类实现)"""
  71 + pass
  72 +
  73 + async def _register_event_handlers(self):
  74 + """注册事件处理器(可选重写)"""
  75 + # 注册通用事件处理器
  76 + self.manager.register_event_handler('session_connected', self._on_session_connected)
  77 + self.manager.register_event_handler('session_disconnected', self._on_session_disconnected)
  78 +
  79 + async def _start_background_tasks(self):
  80 + """启动后台任务(可选重写)"""
  81 + pass
  82 +
  83 + async def _cleanup(self):
  84 + """清理资源(可选重写)"""
  85 + pass
  86 +
  87 + async def _on_session_connected(self, session: WebSocketSession, data: Dict[str, Any]):
  88 + """会话连接事件处理(可选重写)"""
  89 + logger.info(f'[{self.service_name}] 会话连接: {session.session_id}')
  90 +
  91 + async def _on_session_disconnected(self, session: WebSocketSession):
  92 + """会话断开事件处理(可选重写)"""
  93 + logger.info(f'[{self.service_name}] 会话断开: {session.session_id}')
  94 +
  95 + def add_background_task(self, coro):
  96 + """添加后台任务"""
  97 + task = asyncio.create_task(coro)
  98 + self._background_tasks.append(task)
  99 + return task
  100 +
  101 + async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
  102 + source: str = None, metadata: Dict = None):
  103 + """向指定会话广播消息"""
  104 + if source is None:
  105 + source = self.service_name
  106 + return await self.manager.broadcast_to_session(session_id, message_type, content, source, metadata)
  107 +
  108 + async def broadcast_to_all(self, message_type: str, content: Any,
  109 + source: str = None, metadata: Dict = None):
  110 + """向所有会话广播消息"""
  111 + if source is None:
  112 + source = self.service_name
  113 + return await self.manager.broadcast_to_all(message_type, content, source, metadata)
  114 +
  115 + def get_session_stats(self) -> Dict[str, Any]:
  116 + """获取会话统计信息"""
  117 + return self.manager.get_session_stats()
  118 +
  119 + def create_message_handler(self, message_type: str):
  120 + """装饰器:创建消息处理器"""
  121 + def decorator(func):
  122 + async def wrapper(websocket: web.WebSocketResponse, data: Dict[str, Any]):
  123 + session = self.manager.get_session(websocket)
  124 + if session:
  125 + try:
  126 + await func(session, data)
  127 + except Exception as e:
  128 + logger.error(f'[{self.service_name}] 消息处理器 {message_type} 执行失败: {e}')
  129 + await session.send_message({
  130 + "type": "error",
  131 + "data": {
  132 + "message": f"处理 {message_type} 消息失败: {str(e)}",
  133 + "service": self.service_name
  134 + }
  135 + })
  136 + else:
  137 + logger.warning(f'[{self.service_name}] 未找到会话,无法处理消息: {message_type}')
  138 +
  139 + self.manager.register_message_handler(message_type, wrapper)
  140 + return func
  141 + return decorator
  142 +
  143 + def create_event_handler(self, event_type: str):
  144 + """装饰器:创建事件处理器"""
  145 + def decorator(func):
  146 + self.manager.register_event_handler(event_type, func)
  147 + return func
  148 + return decorator
  149 +
  150 +
  151 +class WebSocketServiceRegistry:
  152 + """WebSocket服务注册表"""
  153 +
  154 + def __init__(self):
  155 + self._services: Dict[str, WebSocketServiceBase] = {}
  156 +
  157 + def register_service(self, service: WebSocketServiceBase):
  158 + """注册服务"""
  159 + if service.service_name in self._services:
  160 + logger.warning(f'服务已存在,将被覆盖: {service.service_name}')
  161 +
  162 + self._services[service.service_name] = service
  163 + logger.info(f'注册WebSocket服务: {service.service_name}')
  164 +
  165 + def get_service(self, service_name: str) -> Optional[WebSocketServiceBase]:
  166 + """获取服务"""
  167 + return self._services.get(service_name)
  168 +
  169 + def list_services(self) -> List[str]:
  170 + """列出所有服务名称"""
  171 + return list(self._services.keys())
  172 +
  173 + async def initialize_all(self):
  174 + """初始化所有服务"""
  175 + logger.info('初始化所有WebSocket服务...')
  176 + for service in self._services.values():
  177 + await service.initialize()
  178 + logger.info('所有WebSocket服务初始化完成')
  179 +
  180 + async def shutdown_all(self):
  181 + """关闭所有服务"""
  182 + logger.info('关闭所有WebSocket服务...')
  183 + for service in self._services.values():
  184 + await service.shutdown()
  185 + logger.info('所有WebSocket服务已关闭')
  186 +
  187 + def get_all_stats(self) -> Dict[str, Any]:
  188 + """获取所有服务的统计信息"""
  189 + stats = {
  190 + "services": {},
  191 + "total_services": len(self._services)
  192 + }
  193 +
  194 + for name, service in self._services.items():
  195 + stats["services"][name] = {
  196 + "initialized": service.is_initialized,
  197 + "background_tasks": len(service._background_tasks)
  198 + }
  199 +
  200 + return stats
  201 +
  202 +
  203 +# 全局服务注册表
  204 +_service_registry = WebSocketServiceRegistry()
  205 +
  206 +
  207 +def get_service_registry() -> WebSocketServiceRegistry:
  208 + """获取服务注册表"""
  209 + return _service_registry
  210 +
  211 +
  212 +def register_websocket_service(service: WebSocketServiceBase):
  213 + """注册WebSocket服务"""
  214 + _service_registry.register_service(service)
  215 +
  216 +
  217 +def get_websocket_service(service_name: str) -> Optional[WebSocketServiceBase]:
  218 + """获取WebSocket服务"""
  219 + return _service_registry.get_service(service_name)
1 -# -*- coding: utf-8 -*-  
2 -"""  
3 -AIfeng/2025-01-27  
4 -WebSocket服务器管理模块  
5 -提供Web和Human连接的管理功能  
6 -"""  
7 -  
8 -import queue  
9 -from typing import Dict, Any, Optional  
10 -from threading import Lock  
11 -  
12 -class WebSocketManager:  
13 - """WebSocket连接管理器"""  
14 -  
15 - def __init__(self):  
16 - self._connections = {}  
17 - self._command_queue = queue.Queue()  
18 - self._lock = Lock()  
19 -  
20 - def is_connected(self, username: str) -> bool:  
21 - """检查用户是否已连接  
22 -  
23 - Args:  
24 - username: 用户名  
25 -  
26 - Returns:  
27 - 是否已连接  
28 - """  
29 - with self._lock:  
30 - return username in self._connections  
31 -  
32 - def is_connected_human(self, username: str) -> bool:  
33 - """检查人类用户是否已连接  
34 -  
35 - Args:  
36 - username: 用户名  
37 -  
38 - Returns:  
39 - 是否已连接  
40 - """  
41 - # 简化实现,与is_connected相同  
42 - return self.is_connected(username)  
43 -  
44 - def add_connection(self, username: str, connection: Any):  
45 - """添加连接  
46 -  
47 - Args:  
48 - username: 用户名  
49 - connection: 连接对象  
50 - """  
51 - with self._lock:  
52 - self._connections[username] = connection  
53 -  
54 - def remove_connection(self, username: str):  
55 - """移除连接  
56 -  
57 - Args:  
58 - username: 用户名  
59 - """  
60 - with self._lock:  
61 - self._connections.pop(username, None)  
62 -  
63 - def add_cmd(self, command: Dict[str, Any]):  
64 - """添加命令到队列  
65 -  
66 - Args:  
67 - command: 命令字典  
68 - """  
69 - try:  
70 - self._command_queue.put(command, timeout=1.0)  
71 - except queue.Full:  
72 - print(f"警告: 命令队列已满,丢弃命令: {command}")  
73 -  
74 - def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:  
75 - """从队列获取命令  
76 -  
77 - Args:  
78 - timeout: 超时时间  
79 -  
80 - Returns:  
81 - 命令字典或None  
82 - """  
83 - try:  
84 - return self._command_queue.get(timeout=timeout)  
85 - except queue.Empty:  
86 - return None  
87 -  
88 - def get_connection_count(self) -> int:  
89 - """获取连接数量"""  
90 - with self._lock:  
91 - return len(self._connections)  
92 -  
93 - def get_usernames(self) -> list:  
94 - """获取所有用户名列表"""  
95 - with self._lock:  
96 - return list(self._connections.keys())  
97 -  
98 -# 全局实例  
99 -_web_instance = WebSocketManager()  
100 -_human_instance = WebSocketManager()  
101 -  
102 -def get_web_instance() -> WebSocketManager:  
103 - """获取Web WebSocket管理器实例"""  
104 - return _web_instance  
105 -  
106 -def get_instance() -> WebSocketManager:  
107 - """获取Human WebSocket管理器实例"""  
108 - return _human_instance  
  1 +# -*- coding: utf-8 -*-
  2 +"""
  3 +AIfeng/2025-01-27 16:30:00
  4 +WSA WebSocket服务
  5 +将原有wsa_server功能集成到统一WebSocket架构中
  6 +"""
  7 +
  8 +import asyncio
  9 +import json
  10 +import queue
  11 +from typing import Dict, Any, Optional, Set
  12 +from threading import Lock
  13 +from aiohttp import web
  14 +
  15 +from .websocket_service_base import WebSocketServiceBase
  16 +from .unified_websocket_manager import WebSocketSession
  17 +
  18 +
  19 +class WSAWebSocketService(WebSocketServiceBase):
  20 + """WSA WebSocket服务
  21 +
  22 + 提供与原wsa_server兼容的功能:
  23 + - Web连接管理
  24 + - Human连接管理
  25 + - 命令队列处理
  26 + - 消息转发
  27 + """
  28 +
  29 + def __init__(self, manager):
  30 + super().__init__("wsa")
  31 +
  32 + # 连接管理
  33 + self._web_connections: Dict[str, Set[WebSocketSession]] = {}
  34 + self._human_connections: Dict[str, Set[WebSocketSession]] = {}
  35 + self._connection_lock = Lock()
  36 +
  37 + # 命令队列
  38 + self._web_command_queue = queue.Queue()
  39 + self._human_command_queue = queue.Queue()
  40 +
  41 + # 后台任务
  42 + self._queue_processor_task: Optional[asyncio.Task] = None
  43 +
  44 + async def _register_message_handlers(self):
  45 + """注册消息处理器"""
  46 + self.manager.register_message_handler("wsa_register_web", self._handle_register_web)
  47 + self.manager.register_message_handler("wsa_register_human", self._handle_register_human)
  48 + self.manager.register_message_handler("wsa_unregister", self._handle_unregister)
  49 + self.manager.register_message_handler("wsa_get_status", self._handle_get_status)
  50 +
  51 + async def _start_background_tasks(self):
  52 + """启动后台任务"""
  53 + self._queue_processor_task = asyncio.create_task(self._process_command_queues())
  54 +
  55 + async def _cleanup(self):
  56 + """清理资源"""
  57 + if self._queue_processor_task:
  58 + self._queue_processor_task.cancel()
  59 + try:
  60 + await self._queue_processor_task
  61 + except asyncio.CancelledError:
  62 + pass
  63 +
  64 + async def _on_session_disconnected(self, session: WebSocketSession):
  65 + """会话断开处理"""
  66 + with self._connection_lock:
  67 + # 从web连接中移除
  68 + for username, sessions in list(self._web_connections.items()):
  69 + if session in sessions:
  70 + sessions.discard(session)
  71 + if not sessions:
  72 + del self._web_connections[username]
  73 +
  74 + # 从human连接中移除
  75 + for username, sessions in list(self._human_connections.items()):
  76 + if session in sessions:
  77 + sessions.discard(session)
  78 + if not sessions:
  79 + del self._human_connections[username]
  80 +
  81 + async def _handle_register_web(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  82 + """注册Web连接"""
  83 + username = data.get('username')
  84 + if not username:
  85 + await websocket.send_str(json.dumps({
  86 + "type": "wsa_error",
  87 + "message": "用户名不能为空"
  88 + }))
  89 + return
  90 +
  91 + session = self.manager.get_session(websocket)
  92 + if not session:
  93 + await websocket.send_str(json.dumps({
  94 + "type": "wsa_error",
  95 + "message": "会话未找到"
  96 + }))
  97 + return
  98 +
  99 + with self._connection_lock:
  100 + if username not in self._web_connections:
  101 + self._web_connections[username] = set()
  102 + self._web_connections[username].add(session)
  103 +
  104 + await websocket.send_str(json.dumps({
  105 + "type": "wsa_registered",
  106 + "connection_type": "web",
  107 + "username": username
  108 + }))
  109 +
  110 + async def _handle_register_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  111 + """注册Human连接"""
  112 + username = data.get('username')
  113 + if not username:
  114 + await websocket.send_str(json.dumps({
  115 + "type": "wsa_error",
  116 + "message": "用户名不能为空"
  117 + }))
  118 + return
  119 +
  120 + session = self.manager.get_session(websocket)
  121 + if not session:
  122 + await websocket.send_str(json.dumps({
  123 + "type": "wsa_error",
  124 + "message": "会话未找到"
  125 + }))
  126 + return
  127 +
  128 + with self._connection_lock:
  129 + if username not in self._human_connections:
  130 + self._human_connections[username] = set()
  131 + self._human_connections[username].add(session)
  132 +
  133 + await websocket.send_str(json.dumps({
  134 + "type": "wsa_registered",
  135 + "connection_type": "human",
  136 + "username": username
  137 + }))
  138 +
  139 + async def _handle_unregister(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  140 + """注销连接"""
  141 + username = data.get('username')
  142 + connection_type = data.get('connection_type', 'both')
  143 +
  144 + session = self.manager.get_session(websocket)
  145 + if not session:
  146 + return
  147 +
  148 + with self._connection_lock:
  149 + if connection_type in ['web', 'both'] and username in self._web_connections:
  150 + self._web_connections[username].discard(session)
  151 + if not self._web_connections[username]:
  152 + del self._web_connections[username]
  153 +
  154 + if connection_type in ['human', 'both'] and username in self._human_connections:
  155 + self._human_connections[username].discard(session)
  156 + if not self._human_connections[username]:
  157 + del self._human_connections[username]
  158 +
  159 + await websocket.send_str(json.dumps({
  160 + "type": "wsa_unregistered",
  161 + "username": username,
  162 + "connection_type": connection_type
  163 + }))
  164 +
  165 + async def _handle_get_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
  166 + """获取连接状态"""
  167 + with self._connection_lock:
  168 + web_users = list(self._web_connections.keys())
  169 + human_users = list(self._human_connections.keys())
  170 +
  171 + await websocket.send_str(json.dumps({
  172 + "type": "wsa_status",
  173 + "data": {
  174 + "web_connections": len(self._web_connections),
  175 + "human_connections": len(self._human_connections),
  176 + "web_users": web_users,
  177 + "human_users": human_users,
  178 + "web_queue_size": self._web_command_queue.qsize(),
  179 + "human_queue_size": self._human_command_queue.qsize()
  180 + }
  181 + }))
  182 +
  183 + async def _process_command_queues(self):
  184 + """处理命令队列"""
  185 + while True:
  186 + try:
  187 + # 处理Web命令队列
  188 + await self._process_web_commands()
  189 +
  190 + # 处理Human命令队列
  191 + await self._process_human_commands()
  192 +
  193 + # 短暂休眠避免CPU占用过高
  194 + await asyncio.sleep(0.01)
  195 +
  196 + except asyncio.CancelledError:
  197 + break
  198 + except Exception as e:
  199 + self.logger.error(f"命令队列处理错误: {e}")
  200 + await asyncio.sleep(0.1)
  201 +
  202 + async def _process_web_commands(self):
  203 + """处理Web命令队列"""
  204 + try:
  205 + while True:
  206 + try:
  207 + command = self._web_command_queue.get_nowait()
  208 + await self._forward_web_command(command)
  209 + except queue.Empty:
  210 + break
  211 + except Exception as e:
  212 + self.logger.error(f"Web命令处理错误: {e}")
  213 +
  214 + async def _process_human_commands(self):
  215 + """处理Human命令队列"""
  216 + try:
  217 + while True:
  218 + try:
  219 + command = self._human_command_queue.get_nowait()
  220 + await self._forward_human_command(command)
  221 + except queue.Empty:
  222 + break
  223 + except Exception as e:
  224 + self.logger.error(f"Human命令处理错误: {e}")
  225 +
  226 + async def _forward_web_command(self, command: Dict[str, Any]):
  227 + """转发Web命令"""
  228 + username = command.get('Username')
  229 + if not username:
  230 + return
  231 +
  232 + with self._connection_lock:
  233 + sessions = self._web_connections.get(username, set())
  234 +
  235 + if sessions:
  236 + message = {
  237 + "type": "wsa_command",
  238 + "source": "web",
  239 + "data": command
  240 + }
  241 +
  242 + for session in list(sessions):
  243 + try:
  244 + await session.send_message(message)
  245 + except Exception as e:
  246 + self.logger.error(f"发送Web命令失败 [{username}]: {e}")
  247 +
  248 + async def _forward_human_command(self, command: Dict[str, Any]):
  249 + """转发Human命令"""
  250 + username = command.get('Username')
  251 + if not username:
  252 + return
  253 +
  254 + with self._connection_lock:
  255 + sessions = self._human_connections.get(username, set())
  256 +
  257 + if sessions:
  258 + message = {
  259 + "type": "wsa_command",
  260 + "source": "human",
  261 + "data": command
  262 + }
  263 +
  264 + for session in list(sessions):
  265 + try:
  266 + await session.send_message(message)
  267 + except Exception as e:
  268 + self.logger.error(f"发送Human命令失败 [{username}]: {e}")
  269 +
  270 + # 兼容性接口
  271 + def is_connected(self, username: str) -> bool:
  272 + """检查Web用户是否已连接"""
  273 + with self._connection_lock:
  274 + return username in self._web_connections and bool(self._web_connections[username])
  275 +
  276 + def is_connected_human(self, username: str) -> bool:
  277 + """检查Human用户是否已连接"""
  278 + with self._connection_lock:
  279 + return username in self._human_connections and bool(self._human_connections[username])
  280 +
  281 + def add_connection(self, username: str, connection: Any):
  282 + """添加连接(兼容性接口,已废弃)"""
  283 + self.logger.warning("add_connection方法已废弃,请使用消息注册机制")
  284 +
  285 + def remove_connection(self, username: str):
  286 + """移除连接(兼容性接口,已废弃)"""
  287 + self.logger.warning("remove_connection方法已废弃,连接会自动清理")
  288 +
  289 + def add_cmd(self, command: Dict[str, Any], target: str = "web"):
  290 + """添加命令到队列"""
  291 + try:
  292 + if target == "web":
  293 + self._web_command_queue.put(command, timeout=1.0)
  294 + elif target == "human":
  295 + self._human_command_queue.put(command, timeout=1.0)
  296 + else:
  297 + self.logger.warning(f"未知的目标类型: {target}")
  298 + except queue.Full:
  299 + self.logger.warning(f"命令队列已满,丢弃命令: {command}")
  300 +
  301 + async def send_direct_message(self, message: Dict[str, Any], target: str = "web"):
  302 + """直接发送消息(不封装为wsa_command)"""
  303 + username = message.get('Username')
  304 + if not username:
  305 + self.logger.warning("消息缺少Username字段")
  306 + return
  307 +
  308 + with self._connection_lock:
  309 + if target == "web":
  310 + sessions = self._web_connections.get(username, set())
  311 + elif target == "human":
  312 + sessions = self._human_connections.get(username, set())
  313 + else:
  314 + self.logger.warning(f"未知的目标类型: {target}")
  315 + return
  316 +
  317 + if sessions:
  318 + for session in list(sessions):
  319 + try:
  320 + await session.send_message(message)
  321 + except Exception as e:
  322 + self.logger.error(f"直接发送消息失败 [{username}]: {e}")
  323 + else:
  324 + self.logger.debug(f"用户 {username} 未连接,无法发送直接消息")
  325 +
  326 + def get_cmd(self, timeout: float = 1.0, target: str = "web") -> Optional[Dict[str, Any]]:
  327 + """从队列获取命令"""
  328 + try:
  329 + if target == "web":
  330 + return self._web_command_queue.get(timeout=timeout)
  331 + elif target == "human":
  332 + return self._human_command_queue.get(timeout=timeout)
  333 + else:
  334 + self.logger.warning(f"未知的目标类型: {target}")
  335 + return None
  336 + except queue.Empty:
  337 + return None
  338 +
  339 + def get_connection_count(self, target: str = "web") -> int:
  340 + """获取连接数量"""
  341 + with self._connection_lock:
  342 + if target == "web":
  343 + return len(self._web_connections)
  344 + elif target == "human":
  345 + return len(self._human_connections)
  346 + else:
  347 + return len(self._web_connections) + len(self._human_connections)
  348 +
  349 + def get_usernames(self, target: str = "web") -> list:
  350 + """获取用户名列表"""
  351 + with self._connection_lock:
  352 + if target == "web":
  353 + return list(self._web_connections.keys())
  354 + elif target == "human":
  355 + return list(self._human_connections.keys())
  356 + else:
  357 + return list(set(list(self._web_connections.keys()) + list(self._human_connections.keys())))
  358 +
  359 +
  360 +# 兼容性包装器
  361 +class WSAWebSocketManager:
  362 + """WSA WebSocket管理器兼容性包装器"""
  363 +
  364 + def __init__(self, service: WSAWebSocketService):
  365 + self.service = service
  366 +
  367 + def is_connected(self, username: str) -> bool:
  368 + return self.service.is_connected(username)
  369 +
  370 + def is_connected_human(self, username: str) -> bool:
  371 + return self.service.is_connected_human(username)
  372 +
  373 + def add_connection(self, username: str, connection: Any):
  374 + self.service.add_connection(username, connection)
  375 +
  376 + def remove_connection(self, username: str):
  377 + self.service.remove_connection(username)
  378 +
  379 + def add_cmd(self, command: Dict[str, Any]):
  380 + self.service.add_cmd(command, "web")
  381 +
  382 + async def send_direct_message(self, message: Dict[str, Any]):
  383 + """直接发送消息(不封装为wsa_command)"""
  384 + await self.service.send_direct_message(message, "web")
  385 +
  386 + def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
  387 + return self.service.get_cmd(timeout, "web")
  388 +
  389 + def get_connection_count(self) -> int:
  390 + return self.service.get_connection_count("web")
  391 +
  392 + def get_usernames(self) -> list:
  393 + return self.service.get_usernames("web")
  394 +
  395 +
  396 +# 全局实例(兼容性)
  397 +_wsa_service: Optional[WSAWebSocketService] = None
  398 +_web_instance: Optional[WSAWebSocketManager] = None
  399 +_human_instance: Optional[WSAWebSocketManager] = None
  400 +
  401 +
  402 +def initialize_wsa_service(service: WSAWebSocketService):
  403 + """初始化WSA服务"""
  404 + global _wsa_service, _web_instance, _human_instance
  405 + _wsa_service = service
  406 + _web_instance = WSAWebSocketManager(service)
  407 + _human_instance = WSAWebSocketManager(service)
  408 +
  409 +
  410 +def get_web_instance() -> WSAWebSocketManager:
  411 + """获取Web WebSocket管理器实例"""
  412 + if _web_instance is None:
  413 + raise RuntimeError("WSA服务未初始化")
  414 + return _web_instance
  415 +
  416 +
  417 +def get_instance() -> WSAWebSocketManager:
  418 + """获取Human WebSocket管理器实例"""
  419 + if _human_instance is None:
  420 + raise RuntimeError("WSA服务未初始化")
  421 + return _human_instance
@@ -29,7 +29,7 @@ def _play_frame(stream, exit_event, queue, chunk): @@ -29,7 +29,7 @@ def _play_frame(stream, exit_event, queue, chunk):
29 print(f'[INFO] play frame thread ends') 29 print(f'[INFO] play frame thread ends')
30 break 30 break
31 frame = queue.get() 31 frame = queue.get()
32 - frame = (frame * 32767).astype(np.int16).tobytes() 32 + frame = bytes((frame * 32767).astype(np.int16).tobytes()) # Fix BufferError: memoryview has 1 exported buffer
33 stream.write(frame, chunk) 33 stream.write(frame, chunk)
34 34
35 class ASR: 35 class ASR:
@@ -71,18 +71,45 @@ class FunASRClient(BaseASR): @@ -71,18 +71,45 @@ class FunASRClient(BaseASR):
71 async def _connect_websocket(self): 71 async def _connect_websocket(self):
72 """连接WebSocket服务器""" 72 """连接WebSocket服务器"""
73 try: 73 try:
74 - self.websocket = await websockets.connect(  
75 - self.server_url,  
76 - timeout=getattr(cfg, 'asr_timeout', 30) 74 + # 修复: websockets新版本不支持timeout参数,使用asyncio.wait_for包装
  75 + timeout_seconds = getattr(cfg, 'asr_timeout', 30)
  76 + self.websocket = await asyncio.wait_for(
  77 + websockets.connect(self.server_url),
  78 + timeout=timeout_seconds
77 ) 79 )
78 self.connected = True 80 self.connected = True
79 util.log(1, f"FunASR WebSocket连接成功: {self.server_url}") 81 util.log(1, f"FunASR WebSocket连接成功: {self.server_url}")
  82 +
  83 + # 发送初始化配置消息(参考funasr_client_api.py)
  84 + await self._send_init_message()
  85 +
80 return True 86 return True
81 except Exception as e: 87 except Exception as e:
82 util.log(3, f"FunASR WebSocket连接失败: {e}") 88 util.log(3, f"FunASR WebSocket连接失败: {e}")
83 self.connected = False 89 self.connected = False
84 return False 90 return False
85 91
  92 + async def _send_init_message(self):
  93 + """发送FunASR初始化配置消息"""
  94 + try:
  95 + # 根据参考项目funasr_client_api.py的格式
  96 + init_message = {
  97 + "mode": "2pass",
  98 + "chunk_size": [0, 10, 5], # [vad_need, chunk_size, chunk_interval]
  99 + "encoder_chunk_look_back": 4,
  100 + "decoder_chunk_look_back": 1,
  101 + "chunk_interval": 10,
  102 + "wav_name": self.username,
  103 + "is_speaking": True
  104 + }
  105 +
  106 + await self.websocket.send(json.dumps(init_message))
  107 + util.log(1, f"发送FunASR初始化消息: {init_message}")
  108 +
  109 + except Exception as e:
  110 + util.log(3, f"发送初始化消息失败: {e}")
  111 + raise e
  112 +
86 async def _disconnect_websocket(self): 113 async def _disconnect_websocket(self):
87 """断开WebSocket连接""" 114 """断开WebSocket连接"""
88 if self.websocket: 115 if self.websocket:
@@ -141,11 +168,12 @@ class FunASRClient(BaseASR): @@ -141,11 +168,12 @@ class FunASRClient(BaseASR):
141 message = self.message_queue.get_nowait() 168 message = self.message_queue.get_nowait()
142 169
143 if isinstance(message, dict): 170 if isinstance(message, dict):
144 - # JSON消息 171 + # JSON消息(配置消息或结束信号)
145 await self.websocket.send(json.dumps(message)) 172 await self.websocket.send(json.dumps(message))
146 util.log(1, f"发送JSON消息: {message}") 173 util.log(1, f"发送JSON消息: {message}")
147 elif isinstance(message, bytes): 174 elif isinstance(message, bytes):
148 - # 二进制音频数据 175 + # 二进制音频数据(参考funasr_client_api.py的feed_chunk方法)
  176 + # 确保音频数据以二进制格式发送
149 await self.websocket.send(message) 177 await self.websocket.send(message)
150 util.log(1, f"发送音频数据: {len(message)} bytes") 178 util.log(1, f"发送音频数据: {len(message)} bytes")
151 else: 179 else:
@@ -203,23 +231,24 @@ class FunASRClient(BaseASR): @@ -203,23 +231,24 @@ class FunASRClient(BaseASR):
203 text: 识别文本 231 text: 识别文本
204 """ 232 """
205 try: 233 try:
206 - from core import wsa_server 234 + from core import get_web_instance, get_instance
207 235
208 # 发送到Web客户端 236 # 发送到Web客户端
209 - if wsa_server.get_web_instance().is_connected(self.username):  
210 - wsa_server.get_web_instance().add_cmd({  
211 - "panelMsg": text,  
212 - "Username": self.username  
213 - })  
214 -  
215 - # 发送到Human客户端  
216 - if wsa_server.get_instance().is_connected_human(self.username):  
217 - content = {  
218 - 'Topic': 'human',  
219 - 'Data': {'Key': 'log', 'Value': text},  
220 - 'Username': self.username 237 + if get_web_instance().is_connected(self.username):
  238 + import asyncio
  239 + # 创建chat_message直接推送
  240 + chat_message = {
  241 + "type": "chat_message",
  242 + "sender": "回音",
  243 + "content": text, # 修复字段名:panelMsg -> content
  244 + "Username": self.username,
  245 + "model_info": "FunASR"
221 } 246 }
222 - wsa_server.get_instance().add_cmd(content) 247 + # 使用直接发送方法,避免wsa_command封装
  248 + asyncio.create_task(get_web_instance().send_direct_message(chat_message))
  249 +
  250 + # Human客户端通知改为日志记录(避免重复通知当前服务)
  251 + util.log(1, f"FunASR识别结果[{self.username}]: {text}")
223 252
224 except Exception as e: 253 except Exception as e:
225 util.log(2, f"发送到Web客户端失败: {e}") 254 util.log(2, f"发送到Web客户端失败: {e}")
@@ -333,6 +362,19 @@ class FunASRClient(BaseASR): @@ -333,6 +362,19 @@ class FunASRClient(BaseASR):
333 self.message_queue.put(audio_data) 362 self.message_queue.put(audio_data)
334 return True 363 return True
335 364
  365 + def send_end_signal(self):
  366 + """发送结束信号"""
  367 + if not self.connected:
  368 + return
  369 +
  370 + try:
  371 + # 发送结束消息(参考funasr_client_api.py的close方法)
  372 + end_message = {"is_speaking": False}
  373 + self.message_queue.put(end_message)
  374 + util.log(1, "发送FunASR结束信号")
  375 + except Exception as e:
  376 + util.log(3, f"发送结束信号失败: {e}")
  377 +
336 def start_recognition(self): 378 def start_recognition(self):
337 """开始语音识别""" 379 """开始语音识别"""
338 if not self.connected: 380 if not self.connected:
@@ -418,6 +460,99 @@ class FunASRClient(BaseASR): @@ -418,6 +460,99 @@ class FunASRClient(BaseASR):
418 # 简化实现,返回空特征 460 # 简化实现,返回空特征
419 return np.zeros((1, 50), dtype=np.float32) 461 return np.zeros((1, 50), dtype=np.float32)
420 462
  463 + async def connect(self):
  464 + """异步连接到FunASR服务器"""
  465 + if self.connected:
  466 + util.log(1, "FunASR客户端已连接")
  467 + return True
  468 +
  469 + try:
  470 + success = await self._connect_websocket()
  471 + if success:
  472 + # 启动消息处理任务
  473 + self.receive_task = asyncio.create_task(self._receive_messages())
  474 + self.send_task = asyncio.create_task(self._send_message_loop())
  475 + util.log(1, "FunASR异步连接建立成功")
  476 + return success
  477 + except Exception as e:
  478 + util.log(3, f"FunASR异步连接失败: {e}")
  479 + return False
  480 +
  481 + async def disconnect(self):
  482 + """异步断开连接"""
  483 + try:
  484 + # 取消任务
  485 + if hasattr(self, 'receive_task'):
  486 + self.receive_task.cancel()
  487 + if hasattr(self, 'send_task'):
  488 + self.send_task.cancel()
  489 +
  490 + # 断开WebSocket连接
  491 + await self._disconnect_websocket()
  492 + util.log(1, "FunASR异步连接已断开")
  493 + except Exception as e:
  494 + util.log(2, f"断开FunASR连接时出错: {e}")
  495 +
  496 + async def send_audio_data(self, audio_data):
  497 + """异步发送音频数据"""
  498 + try:
  499 + if isinstance(audio_data, str):
  500 + # Base64编码的音频数据,需要解码
  501 + import base64
  502 + audio_bytes = base64.b64decode(audio_data)
  503 + util.log(1, f"解码Base64音频数据: {len(audio_bytes)} bytes")
  504 + elif isinstance(audio_data, bytes):
  505 + audio_bytes = audio_data
  506 + util.log(1, f"接收字节音频数据: {len(audio_bytes)} bytes")
  507 + elif isinstance(audio_data, np.ndarray):
  508 + # NumPy数组转换为字节
  509 + if audio_data.dtype != np.int16:
  510 + audio_data = audio_data.astype(np.int16)
  511 + audio_bytes = bytes(audio_data.tobytes()) # Fix BufferError: memoryview has 1 exported buffer
  512 + util.log(1, f"转换NumPy数组为字节: {len(audio_bytes)} bytes")
  513 + else:
  514 + util.log(3, f"不支持的音频数据类型: {type(audio_data)},尝试转换为字节")
  515 + # 尝试强制转换
  516 + try:
  517 + audio_bytes = bytes(audio_data)
  518 + except Exception as convert_error:
  519 + util.log(3, f"音频数据类型转换失败: {convert_error}")
  520 + return False
  521 +
  522 + # 验证音频数据有效性
  523 + if len(audio_bytes) == 0:
  524 + util.log(2, "音频数据为空,跳过发送")
  525 + return False
  526 +
  527 + # 确保音频数据长度为偶数(16位采样)
  528 + if len(audio_bytes) % 2 != 0:
  529 + audio_bytes = audio_bytes[:-1] # 去掉最后一个字节
  530 + util.log(2, f"调整音频数据长度为偶数: {len(audio_bytes)} bytes")
  531 +
  532 + # 参考funasr_client_api.py,音频数据需要按chunk发送
  533 + # 计算stride(参考项目中的计算方式)
  534 + chunk_interval = 10 # ms
  535 + chunk_size = 10 # ms
  536 + stride = int(60 * chunk_size / chunk_interval / 1000 * 16000 * 2)
  537 +
  538 + # 如果音频数据较大,分块发送
  539 + if len(audio_bytes) > stride:
  540 + chunk_num = (len(audio_bytes) - 1) // stride + 1
  541 + for i in range(chunk_num):
  542 + beg = i * stride
  543 + chunk_data = audio_bytes[beg:beg + stride]
  544 + self.message_queue.put(chunk_data)
  545 + util.log(1, f"发送音频块 {i+1}/{chunk_num}: {len(chunk_data)} bytes")
  546 + else:
  547 + # 小数据直接发送
  548 + self.message_queue.put(audio_bytes)
  549 + util.log(1, f"发送音频数据: {len(audio_bytes)} bytes")
  550 +
  551 + return True
  552 + except Exception as e:
  553 + util.log(3, f"发送音频数据失败: {e}")
  554 + return False
  555 +
421 def __del__(self): 556 def __del__(self):
422 """析构函数""" 557 """析构函数"""
423 self.stop() 558 self.stop()
@@ -293,7 +293,10 @@ class LightReal(BaseReal): @@ -293,7 +293,10 @@ class LightReal(BaseReal):
293 frame,type_,eventpoint = audio_frame 293 frame,type_,eventpoint = audio_frame
294 frame = (frame * 32767).astype(np.int16) 294 frame = (frame * 32767).astype(np.int16)
295 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) 295 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
296 - new_frame.planes[0].update(frame.tobytes()) 296 + # 修复 BufferError: memoryview has 1 exported buffer
  297 + # 创建数据副本避免内存视图冲突
  298 + frame_bytes = bytes(frame.tobytes())
  299 + new_frame.planes[0].update(frame_bytes)
297 new_frame.sample_rate=16000 300 new_frame.sample_rate=16000
298 # if audio_track._queue.qsize()>10: 301 # if audio_track._queue.qsize()>10:
299 # time.sleep(0.1) 302 # time.sleep(0.1)
@@ -263,7 +263,41 @@ class LipReal(BaseReal): @@ -263,7 +263,41 @@ class LipReal(BaseReal):
263 #print('blending time:',time.perf_counter()-t) 263 #print('blending time:',time.perf_counter()-t)
264 264
265 image = combine_frame #(outputs['image'] * 255).astype(np.uint8) 265 image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
  266 +
  267 + # Fix MemoryError: 优化内存使用和错误处理
  268 + try:
  269 + # 检查图像尺寸,如果过大则压缩
  270 + h, w = image.shape[:2]
  271 + max_dimension = 1920 # 最大尺寸限制
  272 + if h > max_dimension or w > max_dimension:
  273 + scale = max_dimension / max(h, w)
  274 + new_h, new_w = int(h * scale), int(w * scale)
  275 + image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
  276 + logger.warning(f"Image resized from {w}x{h} to {new_w}x{new_h} to prevent MemoryError")
  277 +
  278 + # 确保数据类型正确并创建连续内存布局
  279 + if not image.flags['C_CONTIGUOUS']:
  280 + image = np.ascontiguousarray(image)
  281 + image = image.astype(np.uint8)
  282 +
266 new_frame = VideoFrame.from_ndarray(image, format="bgr24") 283 new_frame = VideoFrame.from_ndarray(image, format="bgr24")
  284 + except MemoryError as e:
  285 + logger.error(f"MemoryError in VideoFrame creation: {e}, image shape: {image.shape}")
  286 + # 进一步压缩图像作为备用方案
  287 + try:
  288 + h, w = image.shape[:2]
  289 + backup_scale = 0.5
  290 + backup_h, backup_w = int(h * backup_scale), int(w * backup_scale)
  291 + image = cv2.resize(image, (backup_w, backup_h), interpolation=cv2.INTER_AREA)
  292 + image = np.ascontiguousarray(image.astype(np.uint8))
  293 + new_frame = VideoFrame.from_ndarray(image, format="bgr24")
  294 + logger.info(f"Backup resize successful: {backup_w}x{backup_h}")
  295 + except Exception as backup_e:
  296 + logger.error(f"Backup resize failed: {backup_e}")
  297 + continue
  298 + except Exception as e:
  299 + logger.error(f"Unexpected error in VideoFrame creation: {e}")
  300 + continue
267 asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop) 301 asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
268 self.record_video_data(image) 302 self.record_video_data(image)
269 303
@@ -271,11 +305,19 @@ class LipReal(BaseReal): @@ -271,11 +305,19 @@ class LipReal(BaseReal):
271 frame,type,eventpoint = audio_frame 305 frame,type,eventpoint = audio_frame
272 frame = (frame * 32767).astype(np.int16) 306 frame = (frame * 32767).astype(np.int16)
273 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) 307 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
274 - new_frame.planes[0].update(frame.tobytes())  
275 - new_frame.sample_rate=16000  
276 - # if audio_track._queue.qsize()>10:  
277 - # time.sleep(0.1)  
278 - asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) 308 +
  309 + # 修复 BufferError: 强制复制数据避免内存视图冲突
  310 + frame_copy = frame.astype(np.int16).copy()
  311 + frame_bytes = bytes(frame_copy.tobytes()) # Fix BufferError: memoryview has 1 exported buffer
  312 + new_frame.planes[0].update(frame_bytes)
  313 +
  314 + new_frame.sample_rate = 16000
  315 +
  316 + # 使用线程安全的方式提交到队列,避免闭包问题
  317 + def put_audio_frame(frame_obj, event_point):
  318 + audio_track._queue.put_nowait((frame_obj, event_point))
  319 +
  320 + loop.call_soon_threadsafe(put_audio_frame, new_frame, eventpoint)
279 self.record_audio_data(frame) 321 self.record_audio_data(frame)
280 #self.notify(eventpoint) 322 #self.notify(eventpoint)
281 logger.info('lipreal process_frames thread stop') 323 logger.info('lipreal process_frames thread stop')
1 import logging 1 import logging
  2 +import os
  3 +
  4 +# 确保日志目录存在
  5 +log_dir = "logs"
  6 +if not os.path.exists(log_dir):
  7 + os.makedirs(log_dir)
2 8
3 # 配置日志器 9 # 配置日志器
4 logger = logging.getLogger(__name__) 10 logger = logging.getLogger(__name__)
@@ -14,3 +20,46 @@ logger.addHandler(fhandler) @@ -14,3 +20,46 @@ logger.addHandler(fhandler)
14 # sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 20 # sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
15 # handler.setFormatter(sformatter) 21 # handler.setFormatter(sformatter)
16 # logger.addHandler(handler) 22 # logger.addHandler(handler)
  23 +
  24 +def get_logger(name: str, level: str = "INFO") -> logging.Logger:
  25 + """获取指定名称的日志器
  26 +
  27 + Args:
  28 + name: 日志器名称
  29 + level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
  30 +
  31 + Returns:
  32 + 配置好的日志器实例
  33 + """
  34 + # 创建日志器
  35 + logger_instance = logging.getLogger(name)
  36 +
  37 + # 避免重复添加处理器
  38 + if logger_instance.handlers:
  39 + return logger_instance
  40 +
  41 + # 设置日志级别
  42 + log_level = getattr(logging, level.upper(), logging.INFO)
  43 + logger_instance.setLevel(log_level)
  44 +
  45 + # 创建格式器
  46 + formatter = logging.Formatter(
  47 + '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  48 + )
  49 +
  50 + # 文件处理器
  51 + log_file = os.path.join(log_dir, f"{name.lower()}.log")
  52 + file_handler = logging.FileHandler(log_file, encoding='utf-8')
  53 + file_handler.setFormatter(formatter)
  54 + file_handler.setLevel(log_level)
  55 + logger_instance.addHandler(file_handler)
  56 +
  57 + # 控制台处理器(可选)
  58 + console_handler = logging.StreamHandler()
  59 + console_handler.setFormatter(logging.Formatter(
  60 + '%(asctime)s - %(levelname)s - %(message)s'
  61 + ))
  62 + console_handler.setLevel(logging.WARNING) # 控制台只显示警告及以上级别
  63 + logger_instance.addHandler(console_handler)
  64 +
  65 + return logger_instance
@@ -346,7 +346,10 @@ class MuseReal(BaseReal): @@ -346,7 +346,10 @@ class MuseReal(BaseReal):
346 frame,type,eventpoint = audio_frame 346 frame,type,eventpoint = audio_frame
347 frame = (frame * 32767).astype(np.int16) 347 frame = (frame * 32767).astype(np.int16)
348 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) 348 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
349 - new_frame.planes[0].update(frame.tobytes()) 349 + # 修复 BufferError: memoryview has 1 exported buffer
  350 + # 创建数据副本避免内存视图冲突
  351 + frame_bytes = bytes(frame.tobytes())
  352 + new_frame.planes[0].update(frame_bytes)
350 new_frame.sample_rate=16000 353 new_frame.sample_rate=16000
351 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) 354 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
352 self.record_audio_data(frame) 355 self.record_audio_data(frame)
@@ -247,7 +247,10 @@ class NeRFReal(BaseReal): @@ -247,7 +247,10 @@ class NeRFReal(BaseReal):
247 else: #webrtc 247 else: #webrtc
248 frame = (frame * 32767).astype(np.int16) 248 frame = (frame * 32767).astype(np.int16)
249 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) 249 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
250 - new_frame.planes[0].update(frame.tobytes()) 250 + # 修复 BufferError: memoryview has 1 exported buffer
  251 + # 创建数据副本避免内存视图冲突
  252 + frame_bytes = bytes(frame.tobytes())
  253 + new_frame.planes[0].update(frame_bytes)
251 new_frame.sample_rate=16000 254 new_frame.sample_rate=16000
252 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) 255 asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
253 256
@@ -18,7 +18,7 @@ face_alignment @@ -18,7 +18,7 @@ face_alignment
18 python_speech_features 18 python_speech_features
19 numba 19 numba
20 resampy 20 resampy
21 -#pyaudio 21 +pyaudio
22 soundfile==0.12.1 22 soundfile==0.12.1
23 einops 23 einops
24 configargparse 24 configargparse
@@ -30,6 +30,9 @@ transformers @@ -30,6 +30,9 @@ transformers
30 edge_tts 30 edge_tts
31 flask 31 flask
32 flask_sockets 32 flask_sockets
  33 +flask-socketio
  34 +websockets
  35 +websocket-client
33 opencv-python-headless 36 opencv-python-headless
34 aiortc 37 aiortc
35 aiohttp_cors 38 aiohttp_cors
@@ -41,5 +44,6 @@ accelerate @@ -41,5 +44,6 @@ accelerate
41 44
42 librosa 45 librosa
43 openai 46 openai
  47 +aiofiles
44 #判断音频类型的支持 48 #判断音频类型的支持
45 AudioSegment 49 AudioSegment
1 -#!/usr/bin/env python3  
2 -# AIfeng/2024-12-19  
3 -# 豆包模型集成测试脚本  
4 -  
5 -import os  
6 -import sys  
7 -import json  
8 -from pathlib import Path  
9 -  
10 -# 添加项目根目录到Python路径  
11 -project_root = Path(__file__).parent  
12 -sys.path.insert(0, str(project_root))  
13 -  
14 -def test_config_files():  
15 - """测试配置文件是否存在和格式正确"""  
16 - print("=== 配置文件测试 ===")  
17 -  
18 - # 测试LLM配置文件  
19 - llm_config_path = project_root / "config" / "llm_config.json"  
20 - if llm_config_path.exists():  
21 - try:  
22 - with open(llm_config_path, 'r', encoding='utf-8') as f:  
23 - llm_config = json.load(f)  
24 - print(f"✓ LLM配置文件加载成功: {llm_config_path}")  
25 - print(f" 当前模型类型: {llm_config.get('model_type', 'unknown')}")  
26 - except Exception as e:  
27 - print(f"✗ LLM配置文件格式错误: {e}")  
28 - else:  
29 - print(f"✗ LLM配置文件不存在: {llm_config_path}")  
30 -  
31 - # 测试豆包配置文件  
32 - doubao_config_path = project_root / "config" / "doubao_config.json"  
33 - if doubao_config_path.exists():  
34 - try:  
35 - with open(doubao_config_path, 'r', encoding='utf-8') as f:  
36 - doubao_config = json.load(f)  
37 - print(f"✓ 豆包配置文件加载成功: {doubao_config_path}")  
38 - print(f" 模型名称: {doubao_config.get('model', 'unknown')}")  
39 - print(f" 人物设定: {doubao_config.get('character', {}).get('name', 'unknown')}")  
40 - except Exception as e:  
41 - print(f"✗ 豆包配置文件格式错误: {e}")  
42 - else:  
43 - print(f"✗ 豆包配置文件不存在: {doubao_config_path}")  
44 -  
45 -def test_module_import():  
46 - """测试模块导入"""  
47 - print("\n=== 模块导入测试 ===")  
48 -  
49 - try:  
50 - from llm.Doubao import Doubao  
51 - print("✓ 豆包模块导入成功")  
52 - except ImportError as e:  
53 - print(f"✗ 豆包模块导入失败: {e}")  
54 - return False  
55 -  
56 - try:  
57 - import llm  
58 - print(f"✓ LLM包导入成功,可用模型: {llm.AVAILABLE_MODELS}")  
59 - except ImportError as e:  
60 - print(f"✗ LLM包导入失败: {e}")  
61 -  
62 - return True  
63 -  
64 -def test_llm_config_loading():  
65 - """测试LLM配置加载函数"""  
66 - print("\n=== LLM配置加载测试 ===")  
67 -  
68 - try:  
69 - # 模拟llm.py中的配置加载函数  
70 - config_path = project_root / "config" / "llm_config.json"  
71 - if config_path.exists():  
72 - with open(config_path, 'r', encoding='utf-8') as f:  
73 - config = json.load(f)  
74 - print(f"✓ 配置加载成功")  
75 - print(f" 模型类型: {config.get('model_type')}")  
76 - print(f" 配置项: {list(config.keys())}")  
77 - return config  
78 - else:  
79 - print("✗ 配置文件不存在,使用默认配置")  
80 - return {"model_type": "qwen"}  
81 - except Exception as e:  
82 - print(f"✗ 配置加载失败: {e}")  
83 - return {"model_type": "qwen"}  
84 -  
85 -def test_doubao_instantiation():  
86 - """测试豆包模型实例化(不需要真实API密钥)"""  
87 - print("\n=== 豆包实例化测试 ===")  
88 -  
89 - try:  
90 - from llm.Doubao import Doubao  
91 -  
92 - # 设置测试API密钥  
93 - os.environ['DOUBAO_API_KEY'] = 'test_key_for_validation'  
94 -  
95 - doubao = Doubao()  
96 - print("✓ 豆包实例化成功")  
97 - print(f" 配置文件路径: {doubao.config_file}")  
98 - print(f" API基础URL: {doubao.base_url}")  
99 - print(f" 模型名称: {doubao.model}")  
100 -  
101 - # 清理测试环境变量  
102 - if 'DOUBAO_API_KEY' in os.environ:  
103 - del os.environ['DOUBAO_API_KEY']  
104 -  
105 - return True  
106 - except Exception as e:  
107 - print(f"✗ 豆包实例化失败: {e}")  
108 - return False  
109 -  
110 -def test_integration_flow():  
111 - """测试完整集成流程"""  
112 - print("\n=== 集成流程测试 ===")  
113 -  
114 - try:  
115 - # 模拟llm.py中的流程  
116 - config = test_llm_config_loading()  
117 - model_type = config.get("model_type", "qwen")  
118 -  
119 - print(f"根据配置选择模型: {model_type}")  
120 -  
121 - if model_type == "doubao":  
122 - print("✓ 将使用豆包模型处理请求")  
123 - elif model_type == "qwen":  
124 - print("✓ 将使用通义千问模型处理请求")  
125 - else:  
126 - print(f"⚠ 未知模型类型: {model_type}")  
127 -  
128 - return True  
129 - except Exception as e:  
130 - print(f"✗ 集成流程测试失败: {e}")  
131 - return False  
132 -  
133 -def main():  
134 - """主测试函数"""  
135 - print("豆包模型集成测试")  
136 - print("=" * 50)  
137 -  
138 - # 运行所有测试  
139 - test_config_files()  
140 -  
141 - if not test_module_import():  
142 - print("\n模块导入失败,停止测试")  
143 - return  
144 -  
145 - test_llm_config_loading()  
146 - test_doubao_instantiation()  
147 - test_integration_flow()  
148 -  
149 - print("\n=== 测试总结 ===")  
150 - print("✓ 豆包模型已成功集成到项目中")  
151 - print("✓ 配置文件结构正确")  
152 - print("✓ 模块导入正常")  
153 - print("\n使用说明:")  
154 - print("1. 设置环境变量 DOUBAO_API_KEY 为您的豆包API密钥")  
155 - print("2. 在 config/llm_config.json 中设置 model_type 为 'doubao'")  
156 - print("3. 根据需要修改 config/doubao_config.json 中的人物设定")  
157 - print("4. 重启应用即可使用豆包模型")  
158 -  
159 -if __name__ == "__main__":  
160 - main()  
1 -# AIfeng/2025-01-27  
2 -"""  
3 -FunASR服务连接测试脚本  
4 -用于验证本地FunASR WebSocket服务是否可以正常连接  
5 -  
6 -使用方法:  
7 -1. 先启动FunASR服务:python -u web/asr/funasr/ASR_server.py --host "127.0.0.1" --port 10197 --ngpu 0  
8 -2. 运行此测试脚本:python test_funasr_connection.py  
9 -"""  
10 -  
11 -import asyncio  
12 -import websockets  
13 -import json  
14 -import os  
15 -import wave  
16 -import numpy as np  
17 -from pathlib import Path  
18 -  
19 -class FunASRConnectionTest:  
20 - def __init__(self, host="127.0.0.1", port=10197):  
21 - self.host = host  
22 - self.port = port  
23 - self.uri = f"ws://{host}:{port}"  
24 -  
25 - async def test_basic_connection(self):  
26 - """测试基本WebSocket连接"""  
27 - print(f"🔍 测试连接到 {self.uri}")  
28 - try:  
29 - async with websockets.connect(self.uri) as websocket:  
30 - print("✅ FunASR WebSocket服务连接成功")  
31 - return True  
32 - except ConnectionRefusedError:  
33 - print("❌ 连接被拒绝,请确认FunASR服务已启动")  
34 - print(" 启动命令: python -u web/asr/funasr/ASR_server.py --host \"127.0.0.1\" --port 10197 --ngpu 0")  
35 - return False  
36 - except Exception as e:  
37 - print(f"❌ 连接失败: {e}")  
38 - return False  
39 -  
40 - def create_test_wav(self, filename="test_audio.wav", duration=2, sample_rate=16000):  
41 - """创建测试用的WAV文件"""  
42 - # 生成简单的正弦波音频  
43 - t = np.linspace(0, duration, int(sample_rate * duration), False)  
44 - frequency = 440 # A4音符  
45 - audio_data = np.sin(2 * np.pi * frequency * t) * 0.3  
46 -  
47 - # 转换为16位整数  
48 - audio_data = (audio_data * 32767).astype(np.int16)  
49 -  
50 - # 保存为WAV文件  
51 - with wave.open(filename, 'wb') as wav_file:  
52 - wav_file.setnchannels(1) # 单声道  
53 - wav_file.setsampwidth(2) # 16位  
54 - wav_file.setframerate(sample_rate)  
55 - wav_file.writeframes(audio_data.tobytes())  
56 -  
57 - print(f"📁 创建测试音频文件: {filename}")  
58 - return filename  
59 -  
60 - async def test_audio_recognition(self):  
61 - """测试音频识别功能"""  
62 - print("\n🎵 测试音频识别功能")  
63 -  
64 - # 创建测试音频文件  
65 - test_file = self.create_test_wav()  
66 - test_file_path = os.path.abspath(test_file)  
67 -  
68 - try:  
69 - async with websockets.connect(self.uri) as websocket:  
70 - print("✅ 连接成功,发送音频文件路径")  
71 -  
72 - # 发送音频文件路径  
73 - message = {"url": test_file_path}  
74 - await websocket.send(json.dumps(message))  
75 - print(f"📤 发送消息: {message}")  
76 -  
77 - # 等待识别结果  
78 - try:  
79 - response = await asyncio.wait_for(websocket.recv(), timeout=10)  
80 - print(f"📥 收到识别结果: {response}")  
81 - return True  
82 - except asyncio.TimeoutError:  
83 - print("⏰ 等待响应超时(10秒)")  
84 - print(" 这可能是正常的,因为测试音频是纯音调,无法识别为文字")  
85 - return True # 超时也算连接成功  
86 -  
87 - except Exception as e:  
88 - print(f"❌ 音频识别测试失败: {e}")  
89 - return False  
90 - finally:  
91 - # 清理测试文件  
92 - if os.path.exists(test_file):  
93 - os.remove(test_file)  
94 - print(f"🗑️ 清理测试文件: {test_file}")  
95 -  
96 - async def test_real_audio_files(self):  
97 - """测试实际音频文件的识别效果"""  
98 - print("\n🎤 测试实际音频文件识别")  
99 -  
100 - # 实际音频文件列表  
101 - audio_files = [  
102 - "yunxi.mp3",  
103 - "yunxia.mp3",  
104 - "yunyang.mp3"  
105 - ]  
106 -  
107 - results = []  
108 -  
109 - for audio_file in audio_files:  
110 - file_path = os.path.abspath(audio_file)  
111 -  
112 - # 检查文件是否存在  
113 - if not os.path.exists(file_path):  
114 - print(f"⚠️ 音频文件不存在: {file_path}")  
115 - continue  
116 -  
117 - print(f"\n🎵 测试音频文件: {audio_file}")  
118 -  
119 - try:  
120 - async with websockets.connect(self.uri) as websocket:  
121 - print(f"✅ 连接成功,发送音频文件: {audio_file}")  
122 -  
123 - # 发送音频文件路径  
124 - message = {"url": file_path}  
125 - await websocket.send(json.dumps(message))  
126 - print(f"📤 发送消息: {message}")  
127 -  
128 - # 等待识别结果  
129 - try:  
130 - response = await asyncio.wait_for(websocket.recv(), timeout=30)  
131 - print(f"📥 识别结果: {response}")  
132 -  
133 - # 解析响应  
134 - try:  
135 - result_data = json.loads(response)  
136 - if isinstance(result_data, dict) and 'text' in result_data:  
137 - recognized_text = result_data['text']  
138 - print(f"🎯 识别文本: {recognized_text}")  
139 - results.append({  
140 - 'file': audio_file,  
141 - 'text': recognized_text,  
142 - 'status': 'success'  
143 - })  
144 - else:  
145 - print(f"📄 原始响应: {response}")  
146 - results.append({  
147 - 'file': audio_file,  
148 - 'response': response,  
149 - 'status': 'received'  
150 - })  
151 - except json.JSONDecodeError:  
152 - print(f"📄 非JSON响应: {response}")  
153 - results.append({  
154 - 'file': audio_file,  
155 - 'response': response,  
156 - 'status': 'received'  
157 - })  
158 -  
159 - except asyncio.TimeoutError:  
160 - print(f"⏰ 等待响应超时(30秒)- {audio_file}")  
161 - results.append({  
162 - 'file': audio_file,  
163 - 'status': 'timeout'  
164 - })  
165 -  
166 - except Exception as e:  
167 - print(f"❌ 测试 {audio_file} 失败: {e}")  
168 - results.append({  
169 - 'file': audio_file,  
170 - 'error': str(e),  
171 - 'status': 'error'  
172 - })  
173 -  
174 - # 文件间等待,避免服务器压力  
175 - await asyncio.sleep(1)  
176 -  
177 - # 输出测试总结  
178 - print("\n" + "="*50)  
179 - print("📊 实际音频文件测试总结:")  
180 - for i, result in enumerate(results, 1):  
181 - print(f"\n{i}. 文件: {result['file']}")  
182 - if result['status'] == 'success':  
183 - print(f" ✅ 识别成功: {result['text']}")  
184 - elif result['status'] == 'received':  
185 - print(f" 📥 收到响应: {result.get('response', 'N/A')}")  
186 - elif result['status'] == 'timeout':  
187 - print(f" ⏰ 响应超时")  
188 - elif result['status'] == 'error':  
189 - print(f" ❌ 测试失败: {result.get('error', 'N/A')}")  
190 -  
191 - return len(results) > 0  
192 -  
193 - async def test_message_format(self):  
194 - """测试消息格式兼容性"""  
195 - print("\n📋 测试消息格式兼容性")  
196 -  
197 - try:  
198 - async with websockets.connect(self.uri) as websocket:  
199 - # 测试不同的消息格式  
200 - test_messages = [  
201 - {"url": "nonexistent.wav"},  
202 - {"test": "message"},  
203 - "invalid_json"  
204 - ]  
205 -  
206 - for i, msg in enumerate(test_messages, 1):  
207 - try:  
208 - if isinstance(msg, dict):  
209 - await websocket.send(json.dumps(msg))  
210 - print(f"✅ 消息 {i} 发送成功: {msg}")  
211 - else:  
212 - await websocket.send(msg)  
213 - print(f"✅ 消息 {i} 发送成功: {msg}")  
214 -  
215 - # 短暂等待,避免消息堆积  
216 - await asyncio.sleep(0.5)  
217 -  
218 - except Exception as e:  
219 - print(f"⚠️ 消息 {i} 发送失败: {e}")  
220 -  
221 - return True  
222 -  
223 - except Exception as e:  
224 - print(f"❌ 消息格式测试失败: {e}")  
225 - return False  
226 -  
227 - def check_dependencies(self):  
228 - """检查依赖项"""  
229 - print("🔍 检查依赖项...")  
230 -  
231 - required_modules = [  
232 - 'websockets',  
233 - 'asyncio',  
234 - 'json',  
235 - 'wave',  
236 - 'numpy'  
237 - ]  
238 -  
239 - missing_modules = []  
240 - for module in required_modules:  
241 - try:  
242 - __import__(module)  
243 - print(f"✅ {module}")  
244 - except ImportError:  
245 - print(f"❌ {module} (缺失)")  
246 - missing_modules.append(module)  
247 -  
248 - if missing_modules:  
249 - print(f"\n⚠️ 缺失依赖项: {', '.join(missing_modules)}")  
250 - print("安装命令: pip install " + ' '.join(missing_modules))  
251 - return False  
252 -  
253 - print("✅ 所有依赖项检查通过")  
254 - return True  
255 -  
256 - def check_funasr_server_file(self):  
257 - """检查FunASR服务器文件是否存在"""  
258 - print("\n📁 检查FunASR服务器文件...")  
259 -  
260 - server_path = Path("web/asr/funasr/ASR_server.py")  
261 - if server_path.exists():  
262 - print(f"✅ 找到服务器文件: {server_path.absolute()}")  
263 - return True  
264 - else:  
265 - print(f"❌ 未找到服务器文件: {server_path.absolute()}")  
266 - print(" 请确认文件路径是否正确")  
267 - return False  
268 -  
269 - async def run_all_tests(self):  
270 - """运行所有测试"""  
271 - print("🚀 开始FunASR连接测试\n")  
272 -  
273 - # 检查依赖  
274 - if not self.check_dependencies():  
275 - return False  
276 -  
277 - # 检查服务器文件  
278 - if not self.check_funasr_server_file():  
279 - return False  
280 -  
281 - # 基本连接测试  
282 - print("\n" + "="*50)  
283 - if not await self.test_basic_connection():  
284 - return False  
285 -  
286 - # 音频识别测试  
287 - print("\n" + "="*50)  
288 - if not await self.test_audio_recognition():  
289 - return False  
290 -  
291 - # 实际音频文件测试  
292 - print("\n" + "="*50)  
293 - await self.test_real_audio_files()  
294 -  
295 - # 消息格式测试  
296 - print("\n" + "="*50)  
297 - if not await self.test_message_format():  
298 - return False  
299 -  
300 - print("\n" + "="*50)  
301 - print("🎉 所有测试完成!FunASR服务连接正常")  
302 - print("\n💡 集成建议:")  
303 - print(" 1. 服务使用WebSocket协议,非gRPC")  
304 - print(" 2. 默认监听端口: 10197")  
305 - print(" 3. 消息格式: JSON字符串,包含'url'字段指向音频文件路径")  
306 - print(" 4. 可以集成到现有项目的ASR模块中")  
307 -  
308 - return True  
309 -  
310 -async def main():  
311 - """主函数"""  
312 - tester = FunASRConnectionTest()  
313 - success = await tester.run_all_tests()  
314 -  
315 - if not success:  
316 - print("\n❌ 测试失败,请检查FunASR服务状态")  
317 - return 1  
318 -  
319 - return 0  
320 -  
321 -if __name__ == "__main__":  
322 - try:  
323 - exit_code = asyncio.run(main())  
324 - exit(exit_code)  
325 - except KeyboardInterrupt:  
326 - print("\n⏹️ 测试被用户中断")  
327 - exit(1)  
328 - except Exception as e:  
329 - print(f"\n💥 测试过程中发生错误: {e}")  
330 - exit(1)  
1 -# -*- coding: utf-8 -*-  
2 -"""  
3 -AIfeng/2025-01-27  
4 -FunASR集成测试脚本  
5 -测试新的FunASRClient与项目的集成效果  
6 -"""  
7 -  
8 -import os  
9 -import sys  
10 -import time  
11 -import threading  
12 -from pathlib import Path  
13 -  
14 -# 添加项目路径  
15 -sys.path.append(os.path.dirname(__file__))  
16 -  
17 -from funasr_asr import FunASRClient  
18 -from web.asr.funasr import FunASR  
19 -import util  
20 -  
21 -class TestFunASRIntegration:  
22 - """FunASR集成测试类"""  
23 -  
24 - def __init__(self):  
25 - self.test_results = []  
26 - self.test_audio_files = [  
27 - "yunxi.mp3",  
28 - "yunxia.mp3",  
29 - "yunyang.mp3"  
30 - ]  
31 -  
32 - def log_test_result(self, test_name: str, success: bool, message: str = ""):  
33 - """记录测试结果"""  
34 - status = "✓ 通过" if success else "✗ 失败"  
35 - result = f"[{status}] {test_name}"  
36 - if message:  
37 - result += f" - {message}"  
38 -  
39 - self.test_results.append((test_name, success, message))  
40 - print(result)  
41 -  
42 - def test_funasr_client_creation(self):  
43 - """测试FunASRClient创建"""  
44 - try:  
45 - class SimpleOpt:  
46 - def __init__(self):  
47 - self.username = "test_user"  
48 -  
49 - opt = SimpleOpt()  
50 - client = FunASRClient(opt)  
51 -  
52 - # 检查基本属性  
53 - assert hasattr(client, 'server_url')  
54 - assert hasattr(client, 'connected')  
55 - assert hasattr(client, 'running')  
56 -  
57 - self.log_test_result("FunASRClient创建", True, "客户端创建成功")  
58 - return client  
59 -  
60 - except Exception as e:  
61 - self.log_test_result("FunASRClient创建", False, f"错误: {e}")  
62 - return None  
63 -  
64 - def test_compatibility_wrapper(self):  
65 - """测试兼容性包装器"""  
66 - try:  
67 - funasr = FunASR("test_user")  
68 -  
69 - # 检查兼容性方法  
70 - assert hasattr(funasr, 'start')  
71 - assert hasattr(funasr, 'end')  
72 - assert hasattr(funasr, 'send')  
73 - assert hasattr(funasr, 'add_frame')  
74 - assert hasattr(funasr, 'set_message_callback')  
75 -  
76 - self.log_test_result("兼容性包装器", True, "所有兼容性方法存在")  
77 - return funasr  
78 -  
79 - except Exception as e:  
80 - self.log_test_result("兼容性包装器", False, f"错误: {e}")  
81 - return None  
82 -  
83 - def test_callback_mechanism(self):  
84 - """测试回调机制"""  
85 - try:  
86 - funasr = FunASR("test_user")  
87 - callback_called = threading.Event()  
88 - received_message = []  
89 -  
90 - def test_callback(message):  
91 - received_message.append(message)  
92 - callback_called.set()  
93 -  
94 - funasr.set_message_callback(test_callback)  
95 -  
96 - # 模拟接收消息  
97 - test_message = "测试识别结果"  
98 - funasr._handle_result(test_message)  
99 -  
100 - # 等待回调  
101 - if callback_called.wait(timeout=1.0):  
102 - if received_message and received_message[0] == test_message:  
103 - self.log_test_result("回调机制", True, "回调函数正常工作")  
104 - else:  
105 - self.log_test_result("回调机制", False, "回调消息不匹配")  
106 - else:  
107 - self.log_test_result("回调机制", False, "回调超时")  
108 -  
109 - except Exception as e:  
110 - self.log_test_result("回调机制", False, f"错误: {e}")  
111 -  
112 - def test_audio_file_existence(self):  
113 - """测试音频文件存在性"""  
114 - existing_files = []  
115 - missing_files = []  
116 -  
117 - for audio_file in self.test_audio_files:  
118 - if os.path.exists(audio_file):  
119 - existing_files.append(audio_file)  
120 - else:  
121 - missing_files.append(audio_file)  
122 -  
123 - if existing_files:  
124 - self.log_test_result(  
125 - "音频文件检查",  
126 - True,  
127 - f"找到 {len(existing_files)} 个文件: {', '.join(existing_files)}"  
128 - )  
129 -  
130 - if missing_files:  
131 - self.log_test_result(  
132 - "音频文件缺失",  
133 - False,  
134 - f"缺少 {len(missing_files)} 个文件: {', '.join(missing_files)}"  
135 - )  
136 -  
137 - return existing_files  
138 -  
139 - def test_connection_simulation(self):  
140 - """测试连接模拟"""  
141 - try:  
142 - client = self.test_funasr_client_creation()  
143 - if not client:  
144 - return  
145 -  
146 - # 测试启动和停止  
147 - client.start()  
148 - time.sleep(0.5) # 给连接一些时间  
149 -  
150 - # 检查运行状态  
151 - if client.running:  
152 - self.log_test_result("客户端启动", True, "客户端成功启动")  
153 - else:  
154 - self.log_test_result("客户端启动", False, "客户端启动失败")  
155 -  
156 - # 停止客户端  
157 - client.stop()  
158 - time.sleep(0.5)  
159 -  
160 - if not client.running:  
161 - self.log_test_result("客户端停止", True, "客户端成功停止")  
162 - else:  
163 - self.log_test_result("客户端停止", False, "客户端停止失败")  
164 -  
165 - except Exception as e:  
166 - self.log_test_result("连接模拟", False, f"错误: {e}")  
167 -  
168 - def test_message_queue(self):  
169 - """测试消息队列"""  
170 - try:  
171 - client = self.test_funasr_client_creation()  
172 - if not client:  
173 - return  
174 -  
175 - # 测试消息入队  
176 - test_message = {"test": "message"}  
177 - client.message_queue.put(test_message)  
178 -  
179 - # 检查队列  
180 - if not client.message_queue.empty():  
181 - retrieved_message = client.message_queue.get_nowait()  
182 - if retrieved_message == test_message:  
183 - self.log_test_result("消息队列", True, "消息队列正常工作")  
184 - else:  
185 - self.log_test_result("消息队列", False, "消息内容不匹配")  
186 - else:  
187 - self.log_test_result("消息队列", False, "消息队列为空")  
188 -  
189 - except Exception as e:  
190 - self.log_test_result("消息队列", False, f"错误: {e}")  
191 -  
192 - def test_config_loading(self):  
193 - """测试配置加载"""  
194 - try:  
195 - import config_util as cfg  
196 -  
197 - # 检查关键配置项  
198 - required_configs = [  
199 - 'local_asr_ip',  
200 - 'local_asr_port',  
201 - 'asr_timeout',  
202 - 'asr_reconnect_delay',  
203 - 'asr_max_reconnect_attempts'  
204 - ]  
205 -  
206 - missing_configs = []  
207 - for config_key in required_configs:  
208 - try:  
209 - if hasattr(cfg, 'config'):  
210 - value = cfg.config.get(config_key)  
211 - else:  
212 - value = getattr(cfg, config_key, None)  
213 - if value is None:  
214 - missing_configs.append(config_key)  
215 - except:  
216 - missing_configs.append(config_key)  
217 -  
218 - if not missing_configs:  
219 - self.log_test_result("配置加载", True, "所有必需配置项存在")  
220 - else:  
221 - self.log_test_result(  
222 - "配置加载",  
223 - False,  
224 - f"缺少配置项: {', '.join(missing_configs)}"  
225 - )  
226 -  
227 - except Exception as e:  
228 - self.log_test_result("配置加载", False, f"错误: {e}")  
229 -  
230 - def run_all_tests(self):  
231 - """运行所有测试"""  
232 - print("\n" + "="*60)  
233 - print("FunASR集成测试开始")  
234 - print("="*60)  
235 -  
236 - # 运行各项测试  
237 - self.test_config_loading()  
238 - self.test_funasr_client_creation()  
239 - self.test_compatibility_wrapper()  
240 - self.test_callback_mechanism()  
241 - self.test_message_queue()  
242 - self.test_audio_file_existence()  
243 - self.test_connection_simulation()  
244 -  
245 - # 输出测试总结  
246 - print("\n" + "="*60)  
247 - print("测试总结")  
248 - print("="*60)  
249 -  
250 - passed_tests = sum(1 for _, success, _ in self.test_results if success)  
251 - total_tests = len(self.test_results)  
252 -  
253 - print(f"总测试数: {total_tests}")  
254 - print(f"通过测试: {passed_tests}")  
255 - print(f"失败测试: {total_tests - passed_tests}")  
256 - print(f"成功率: {passed_tests/total_tests*100:.1f}%")  
257 -  
258 - # 显示失败的测试  
259 - failed_tests = [(name, msg) for name, success, msg in self.test_results if not success]  
260 - if failed_tests:  
261 - print("\n失败的测试:")  
262 - for name, msg in failed_tests:  
263 - print(f" - {name}: {msg}")  
264 -  
265 - print("\n" + "="*60)  
266 -  
267 - return passed_tests == total_tests  
268 -  
269 -def main():  
270 - """主函数"""  
271 - tester = TestFunASRIntegration()  
272 - success = tester.run_all_tests()  
273 -  
274 - if success:  
275 - print("\n🎉 所有测试通过!FunASR集成准备就绪。")  
276 - else:  
277 - print("\n⚠️ 部分测试失败,请检查相关配置和依赖。")  
278 -  
279 - return 0 if success else 1  
280 -  
281 -if __name__ == "__main__":  
282 - exit(main())  
1 -#!/usr/bin/env python3  
2 -# AIfeng/2024-12-19  
3 -# WebSocket通信测试服务器  
4 -  
5 -import asyncio  
6 -import json  
7 -import time  
8 -import weakref  
9 -from aiohttp import web, WSMsgType  
10 -import aiohttp_cors  
11 -from typing import Dict  
12 -  
13 -# 全局变量  
14 -websocket_connections: Dict[int, weakref.WeakSet] = {} # sessionid:websocket_connections  
15 -  
16 -# WebSocket消息推送函数  
17 -async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "测试服务器"):  
18 - """向指定会话的所有WebSocket连接推送消息"""  
19 - if sessionid not in websocket_connections:  
20 - print(f'[SessionID:{sessionid}] No WebSocket connections found')  
21 - return  
22 -  
23 - message = {  
24 - "type": "chat_message",  
25 - "data": {  
26 - "sessionid": sessionid,  
27 - "message_type": message_type,  
28 - "content": content,  
29 - "source": source,  
30 - "timestamp": time.time()  
31 - }  
32 - }  
33 -  
34 - # 获取该会话的所有WebSocket连接  
35 - connections = list(websocket_connections[sessionid])  
36 - print(f'[SessionID:{sessionid}] Broadcasting to {len(connections)} connections')  
37 -  
38 - # 向所有连接发送消息  
39 - for ws in connections:  
40 - try:  
41 - if not ws.closed:  
42 - await ws.send_str(json.dumps(message))  
43 - print(f'[SessionID:{sessionid}] Message sent to WebSocket: {message_type}')  
44 - except Exception as e:  
45 - print(f'[SessionID:{sessionid}] Failed to send WebSocket message: {e}')  
46 -  
47 -# WebSocket处理器  
48 -async def websocket_handler(request):  
49 - """处理WebSocket连接"""  
50 - ws = web.WebSocketResponse()  
51 - await ws.prepare(request)  
52 -  
53 - sessionid = None  
54 - print('New WebSocket connection established')  
55 -  
56 - try:  
57 - async for msg in ws:  
58 - if msg.type == WSMsgType.TEXT:  
59 - try:  
60 - data = json.loads(msg.data)  
61 - print(f'Received WebSocket message: {data}')  
62 -  
63 - if data.get('type') == 'login':  
64 - sessionid = data.get('sessionid', 0)  
65 -  
66 - # 初始化该会话的WebSocket连接集合  
67 - if sessionid not in websocket_connections:  
68 - websocket_connections[sessionid] = weakref.WeakSet()  
69 -  
70 - # 添加当前连接到会话  
71 - websocket_connections[sessionid].add(ws)  
72 -  
73 - print(f'[SessionID:{sessionid}] WebSocket client logged in')  
74 -  
75 - # 发送登录确认  
76 - await ws.send_str(json.dumps({  
77 - "type": "login_success",  
78 - "sessionid": sessionid,  
79 - "message": "WebSocket连接成功"  
80 - }))  
81 -  
82 - elif data.get('type') == 'ping':  
83 - # 心跳检测  
84 - await ws.send_str(json.dumps({"type": "pong"}))  
85 - print('Sent pong response')  
86 -  
87 - except json.JSONDecodeError:  
88 - print('Invalid JSON received from WebSocket')  
89 - except Exception as e:  
90 - print(f'Error processing WebSocket message: {e}')  
91 -  
92 - elif msg.type == WSMsgType.ERROR:  
93 - print(f'WebSocket error: {ws.exception()}')  
94 - break  
95 -  
96 - except Exception as e:  
97 - print(f'WebSocket connection error: {e}')  
98 - finally:  
99 - if sessionid is not None:  
100 - print(f'[SessionID:{sessionid}] WebSocket connection closed')  
101 - else:  
102 - print('WebSocket connection closed')  
103 -  
104 - return ws  
105 -  
106 -# 模拟human接口  
107 -async def human(request):  
108 - try:  
109 - params = await request.json()  
110 - sessionid = params.get('sessionid', 0)  
111 - user_message = params.get('text', '')  
112 - message_type = params.get('type', 'echo')  
113 -  
114 - print(f'[SessionID:{sessionid}] Received {message_type} message: {user_message}')  
115 -  
116 - # 推送用户消息到WebSocket  
117 - await broadcast_message_to_session(sessionid, message_type, user_message, "用户")  
118 -  
119 - if message_type == 'echo':  
120 - # 推送回音消息到WebSocket  
121 - await broadcast_message_to_session(sessionid, 'echo', user_message, "回音")  
122 -  
123 - elif message_type == 'chat':  
124 - # 模拟AI回复  
125 - ai_response = f"这是对 '{user_message}' 的AI回复"  
126 - await broadcast_message_to_session(sessionid, 'chat', ai_response, "AI助手")  
127 -  
128 - return web.Response(  
129 - content_type="application/json",  
130 - text=json.dumps(  
131 - {"code": 0, "data": "ok", "message": "消息已处理并推送"}  
132 - ),  
133 - )  
134 - except Exception as e:  
135 - print(f'Error in human endpoint: {e}')  
136 - return web.Response(  
137 - content_type="application/json",  
138 - text=json.dumps(  
139 - {"code": -1, "msg": str(e)}  
140 - ),  
141 - )  
142 -  
143 -# 创建应用  
144 -def create_app():  
145 - app = web.Application()  
146 -  
147 - # 添加路由  
148 - app.router.add_post("/human", human)  
149 - app.router.add_get("/ws", websocket_handler)  
150 - app.router.add_static('/', path='web')  
151 -  
152 - # 配置CORS  
153 - cors = aiohttp_cors.setup(app, defaults={  
154 - "*": aiohttp_cors.ResourceOptions(  
155 - allow_credentials=True,  
156 - expose_headers="*",  
157 - allow_headers="*",  
158 - )  
159 - })  
160 -  
161 - # 为所有路由配置CORS  
162 - for route in list(app.router.routes()):  
163 - cors.add(route)  
164 -  
165 - return app  
166 -  
167 -if __name__ == '__main__':  
168 - app = create_app()  
169 - print('Starting WebSocket test server on http://localhost:8000')  
170 - print('WebSocket endpoint: ws://localhost:8000/ws')  
171 - print('HTTP endpoint: http://localhost:8000/human')  
172 - print('Test page: http://localhost:8000/websocket_test.html')  
173 -  
174 - web.run_app(app, host='0.0.0.0', port=8000)  
@@ -9,7 +9,7 @@ import _thread as thread @@ -9,7 +9,7 @@ import _thread as thread
9 from aliyunsdkcore.client import AcsClient 9 from aliyunsdkcore.client import AcsClient
10 from aliyunsdkcore.request import CommonRequest 10 from aliyunsdkcore.request import CommonRequest
11 11
12 -from core import wsa_server 12 +from core import get_web_instance, get_instance
13 from scheduler.thread_manager import MyThread 13 from scheduler.thread_manager import MyThread
14 from utils import util 14 from utils import util
15 from utils import config_util as cfg 15 from utils import config_util as cfg
@@ -92,19 +92,37 @@ class ALiNls: @@ -92,19 +92,37 @@ class ALiNls:
92 if name == 'SentenceEnd': 92 if name == 'SentenceEnd':
93 self.done = True 93 self.done = True
94 self.finalResults = data['payload']['result'] 94 self.finalResults = data['payload']['result']
95 - if wsa_server.get_web_instance().is_connected(self.username):  
96 - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})  
97 - if wsa_server.get_instance().is_connected_human(self.username):  
98 - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}  
99 - wsa_server.get_instance().add_cmd(content) 95 + if get_web_instance().is_connected(self.username):
  96 + import asyncio
  97 + # 创建chat_message直接推送
  98 + chat_message = {
  99 + "type": "chat_message",
  100 + "sender": "回音",
  101 + "content": self.finalResults, # 修复字段名:panelMsg -> content
  102 + "Username": self.username,
  103 + "model_info": "ALiNls"
  104 + }
  105 + # 使用直接发送方法,避免wsa_command封装
  106 + asyncio.create_task(get_web_instance().send_direct_message(chat_message))
  107 + # Human客户端通知改为日志记录(避免重复通知当前服务)
  108 + util.log(1, f"ALiNls识别结果[{self.username}]: {self.finalResults}")
100 ws.close()#TODO 109 ws.close()#TODO
101 elif name == 'TranscriptionResultChanged': 110 elif name == 'TranscriptionResultChanged':
102 self.finalResults = data['payload']['result'] 111 self.finalResults = data['payload']['result']
103 - if wsa_server.get_web_instance().is_connected(self.username):  
104 - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})  
105 - if wsa_server.get_instance().is_connected_human(self.username):  
106 - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}  
107 - wsa_server.get_instance().add_cmd(content) 112 + if get_web_instance().is_connected(self.username):
  113 + import asyncio
  114 + # 创建chat_message直接推送
  115 + chat_message = {
  116 + "type": "chat_message",
  117 + "sender": "回音",
  118 + "content": self.finalResults, # 修复字段名:panelMsg -> content
  119 + "Username": self.username,
  120 + "model_info": "ALiNls"
  121 + }
  122 + # 使用直接发送方法,避免wsa_command封装
  123 + asyncio.create_task(get_web_instance().send_direct_message(chat_message))
  124 + # Human客户端通知改为日志记录(避免重复通知当前服务)
  125 + util.log(1, f"ALiNls识别变化[{self.username}]: {self.finalResults}")
108 126
109 except Exception as e: 127 except Exception as e:
110 print(e) 128 print(e)
@@ -14,13 +14,13 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) @@ -14,13 +14,13 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
14 from funasr_asr import FunASRClient 14 from funasr_asr import FunASRClient
15 # 修复导入路径 15 # 修复导入路径
16 try: 16 try:
17 - from core import wsa_server 17 + from core import get_web_instance, get_instance
18 except ImportError: 18 except ImportError:
19 - # 如果core模块不存在,创建一个模拟的wsa_server  
20 - class MockWSAServer:  
21 - def get_web_instance(self): 19 + # 如果core模块不存在,创建一个模拟的函数
  20 + def get_web_instance():
22 return MockWebInstance() 21 return MockWebInstance()
23 - def get_instance(self): 22 +
  23 + def get_instance():
24 return MockInstance() 24 return MockInstance()
25 25
26 class MockWebInstance: 26 class MockWebInstance:
@@ -35,8 +35,6 @@ except ImportError: @@ -35,8 +35,6 @@ except ImportError:
35 def add_cmd(self, cmd): 35 def add_cmd(self, cmd):
36 print(f"Mock Human: {cmd}") 36 print(f"Mock Human: {cmd}")
37 37
38 - wsa_server = MockWSAServer()  
39 -  
40 try: 38 try:
41 from utils import config_util as cfg 39 from utils import config_util as cfg
42 except ImportError: 40 except ImportError:
@@ -92,11 +90,20 @@ class FunASR: @@ -92,11 +90,20 @@ class FunASR:
92 if self.on_message_callback: 90 if self.on_message_callback:
93 self.on_message_callback(message) 91 self.on_message_callback(message)
94 92
95 - if wsa_server.get_web_instance().is_connected(self.username):  
96 - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})  
97 - if wsa_server.get_instance().is_connected_human(self.username):  
98 - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}  
99 - wsa_server.get_instance().add_cmd(content) 93 + if get_web_instance().is_connected(self.username):
  94 + import asyncio
  95 + # 创建chat_message直接推送
  96 + chat_message = {
  97 + "type": "chat_message",
  98 + "sender": "回音",
  99 + "content": self.finalResults, # 修复字段名:panelMsg -> content
  100 + "Username": self.username,
  101 + "model_info": "FunASR"
  102 + }
  103 + # 使用直接发送方法,避免wsa_command封装
  104 + asyncio.create_task(get_web_instance().send_direct_message(chat_message))
  105 + # Human客户端通知改为日志记录(避免重复通知当前服务)
  106 + util.log(1, f"FunASR识别结果[{self.username}]: {self.finalResults}")
100 except Exception as e: 107 except Exception as e:
101 print(e) 108 print(e)
102 109
@@ -20,25 +20,42 @@ args = parser.parse_args() @@ -20,25 +20,42 @@ args = parser.parse_args()
20 20
21 # 初始化模型 21 # 初始化模型
22 print("model loading") 22 print("model loading")
23 -asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4", 23 +try:
  24 + asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
24 vad_model="fsmn-vad", vad_model_revision="v2.0.4", 25 vad_model="fsmn-vad", vad_model_revision="v2.0.4",
25 punc_model="ct-punc-c", punc_model_revision="v2.0.4", 26 punc_model="ct-punc-c", punc_model_revision="v2.0.4",
26 device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True) 27 device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True)
27 # ,disable_update=True 28 # ,disable_update=True
28 -print("model loaded") 29 + print("model loaded")
  30 +except Exception as e:
  31 + print(f"模型加载失败: {e}")
  32 + import traceback
  33 + traceback.print_exc()
  34 + exit(1)
29 websocket_users = {} 35 websocket_users = {}
30 task_queue = asyncio.Queue() 36 task_queue = asyncio.Queue()
  37 +# 分块会话管理
  38 +chunk_sessions = {} # {user_id: {filename, chunks, total_chunks, received_chunks, temp_file}}
31 39
32 async def ws_serve(websocket, path): 40 async def ws_serve(websocket, path):
33 - global websocket_users 41 + global websocket_users, chunk_sessions
34 user_id = id(websocket) 42 user_id = id(websocket)
35 websocket_users[user_id] = websocket 43 websocket_users[user_id] = websocket
36 try: 44 try:
37 async for message in websocket: 45 async for message in websocket:
38 if isinstance(message, str): 46 if isinstance(message, str):
39 data = json.loads(message) 47 data = json.loads(message)
40 - if 'url' in data:  
41 - await task_queue.put((websocket, data['url'])) 48 +
  49 + # 处理分块协议
  50 + if 'type' in data:
  51 + await handle_chunked_protocol(websocket, data, user_id)
  52 + # 处理传统协议
  53 + elif 'url' in data:
  54 + # 处理文件URL
  55 + await task_queue.put((websocket, data['url'], 'url'))
  56 + elif 'audio_data' in data:
  57 + # 处理音频数据
  58 + await task_queue.put((websocket, data, 'audio_data'))
42 except websockets.exceptions.ConnectionClosed as e: 59 except websockets.exceptions.ConnectionClosed as e:
43 logger.info(f"Connection closed: {e.reason}") 60 logger.info(f"Connection closed: {e.reason}")
44 except Exception as e: 61 except Exception as e:
@@ -47,14 +64,28 @@ async def ws_serve(websocket, path): @@ -47,14 +64,28 @@ async def ws_serve(websocket, path):
47 logger.info(f"Cleaning up connection for user {user_id}") 64 logger.info(f"Cleaning up connection for user {user_id}")
48 if user_id in websocket_users: 65 if user_id in websocket_users:
49 del websocket_users[user_id] 66 del websocket_users[user_id]
  67 + # 清理分块会话
  68 + if user_id in chunk_sessions:
  69 + await cleanup_chunk_session(user_id)
50 await websocket.close() 70 await websocket.close()
51 logger.info("WebSocket closed") 71 logger.info("WebSocket closed")
52 72
53 async def worker(): 73 async def worker():
54 while True: 74 while True:
55 - websocket, url = await task_queue.get() 75 + task_data = await task_queue.get()
  76 + websocket = task_data[0]
  77 +
56 if websocket.open: 78 if websocket.open:
57 - await process_wav_file(websocket, url) 79 + if len(task_data) == 3: # 新格式: (websocket, data, type)
  80 + data, data_type = task_data[1], task_data[2]
  81 + if data_type == 'url':
  82 + await process_wav_file(websocket, data)
  83 + elif data_type == 'audio_data':
  84 + await process_audio_data(websocket, data)
  85 + elif data_type == 'chunked_audio':
  86 + await process_chunked_audio(websocket, data)
  87 + else: # 兼容旧格式: (websocket, url)
  88 + await process_wav_file(websocket, task_data[1])
58 else: 89 else:
59 logger.info("WebSocket connection is already closed when trying to process file") 90 logger.info("WebSocket connection is already closed when trying to process file")
60 task_queue.task_done() 91 task_queue.task_done()
@@ -77,8 +108,226 @@ async def process_wav_file(websocket, url): @@ -77,8 +108,226 @@ async def process_wav_file(websocket, url):
77 except Exception as e: 108 except Exception as e:
78 print(f"Error during model.generate: {e}") 109 print(f"Error during model.generate: {e}")
79 finally: 110 finally:
80 - if os.path.exists(wav_path):  
81 - os.remove(wav_path) 111 + # 注释掉文件删除操作,保留缓存文件用于测试
  112 + # if os.path.exists(wav_path):
  113 + # os.remove(wav_path)
  114 + print(f"保留音频文件用于测试: {wav_path}")
  115 +
  116 +async def handle_chunked_protocol(websocket, data, user_id):
  117 + """处理分块协议消息"""
  118 + global chunk_sessions
  119 +
  120 + try:
  121 + msg_type = data.get('type')
  122 + filename = data.get('filename', 'unknown.wav')
  123 +
  124 + if msg_type == 'audio_start':
  125 + # 开始新的分块会话
  126 + total_chunks = data.get('total_chunks', 0)
  127 + total_size = data.get('total_size', 0)
  128 +
  129 + print(f"开始接收分块音频: {filename}, 总分块数: {total_chunks}, 总大小: {total_size} bytes")
  130 +
  131 + # 创建临时文件
  132 + import tempfile
  133 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
  134 +
  135 + chunk_sessions[user_id] = {
  136 + 'filename': filename,
  137 + 'total_chunks': total_chunks,
  138 + 'total_size': total_size,
  139 + 'received_chunks': 0,
  140 + 'temp_file': temp_file,
  141 + 'temp_path': temp_file.name,
  142 + 'chunks_data': {} # {chunk_index: chunk_data}
  143 + }
  144 +
  145 + await websocket.send(json.dumps({"status": "ready", "message": f"准备接收 {total_chunks} 个分块"}))
  146 +
  147 + elif msg_type == 'audio_chunk':
  148 + # 接收音频分块
  149 + if user_id not in chunk_sessions:
  150 + await websocket.send(json.dumps({"error": "未找到分块会话,请先发送audio_start"}))
  151 + return
  152 +
  153 + session = chunk_sessions[user_id]
  154 + chunk_index = data.get('chunk_index', -1)
  155 + chunk_data = data.get('chunk_data', '')
  156 + is_last = data.get('is_last', False)
  157 +
  158 + if chunk_index >= 0 and chunk_data:
  159 + # 解码并存储分块数据
  160 + import base64
  161 + chunk_bytes = base64.b64decode(chunk_data)
  162 + session['chunks_data'][chunk_index] = chunk_bytes
  163 + session['received_chunks'] += 1
  164 +
  165 + # 进度反馈
  166 + progress = (session['received_chunks'] / session['total_chunks']) * 100
  167 + if session['received_chunks'] % 10 == 0 or is_last:
  168 + print(f"接收进度: {progress:.1f}% ({session['received_chunks']}/{session['total_chunks']})")
  169 +
  170 + elif msg_type == 'audio_end':
  171 + # 完成分块接收,重组音频
  172 + if user_id not in chunk_sessions:
  173 + await websocket.send(json.dumps({"error": "未找到分块会话"}))
  174 + return
  175 +
  176 + session = chunk_sessions[user_id]
  177 +
  178 + # 检查是否接收完整
  179 + if session['received_chunks'] != session['total_chunks']:
  180 + await websocket.send(json.dumps({
  181 + "error": f"分块不完整: 期望{session['total_chunks']}, 实际{session['received_chunks']}"
  182 + }))
  183 + await cleanup_chunk_session(user_id)
  184 + return
  185 +
  186 + # 按顺序重组音频数据
  187 + print(f"重组音频文件: {session['filename']}")
  188 + with open(session['temp_path'], 'wb') as f:
  189 + for i in range(session['total_chunks']):
  190 + if i in session['chunks_data']:
  191 + f.write(session['chunks_data'][i])
  192 + else:
  193 + print(f"警告: 分块 {i} 缺失")
  194 +
  195 + # 提交到处理队列
  196 + reconstructed_data = {
  197 + 'audio_file_path': session['temp_path'],
  198 + 'filename': session['filename']
  199 + }
  200 + await task_queue.put((websocket, reconstructed_data, 'chunked_audio'))
  201 +
  202 + # 清理会话(保留临时文件给处理函数)
  203 + del chunk_sessions[user_id]
  204 + print(f"分块音频重组完成: {session['filename']}")
  205 +
  206 + except Exception as e:
  207 + print(f"处理分块协议时出错: {e}")
  208 + await websocket.send(json.dumps({"error": f"分块处理错误: {str(e)}"}))
  209 + if user_id in chunk_sessions:
  210 + await cleanup_chunk_session(user_id)
  211 +
  212 +async def cleanup_chunk_session(user_id):
  213 + """清理分块会话"""
  214 + global chunk_sessions
  215 +
  216 + if user_id in chunk_sessions:
  217 + session = chunk_sessions[user_id]
  218 + try:
  219 + # 关闭并删除临时文件
  220 + if 'temp_file' in session:
  221 + session['temp_file'].close()
  222 + if 'temp_path' in session and os.path.exists(session['temp_path']):
  223 + os.remove(session['temp_path'])
  224 + print(f"清理临时文件: {session['temp_path']}")
  225 + except Exception as e:
  226 + print(f"清理分块会话时出错: {e}")
  227 + finally:
  228 + del chunk_sessions[user_id]
  229 +
  230 +async def process_chunked_audio(websocket, data):
  231 + """处理分块重组后的音频文件"""
  232 + try:
  233 + audio_file_path = data.get('audio_file_path')
  234 + filename = data.get('filename', 'chunked_audio.wav')
  235 +
  236 + if not audio_file_path or not os.path.exists(audio_file_path):
  237 + await websocket.send(json.dumps({"error": "重组音频文件不存在"}))
  238 + return
  239 +
  240 + print(f"处理分块重组音频: {filename}, 文件路径: {audio_file_path}")
  241 +
  242 + # 热词配置
  243 + param_dict = {"sentence_timestamp": False}
  244 + try:
  245 + with open("data/hotword.txt", "r", encoding="utf-8") as f:
  246 + lines = f.readlines()
  247 + lines = [line.strip() for line in lines]
  248 + hotword = " ".join(lines)
  249 + print(f"热词:{hotword}")
  250 + param_dict["hotword"] = hotword
  251 + except FileNotFoundError:
  252 + print("热词文件不存在,跳过热词配置")
  253 +
  254 + # 进行语音识别
  255 + res = asr_model.generate(input=audio_file_path, is_final=True, **param_dict)
  256 + if res and websocket.open:
  257 + if 'text' in res[0]:
  258 + result_text = res[0]['text']
  259 + print(f"分块音频识别结果: {result_text}")
  260 + await websocket.send(result_text)
  261 + else:
  262 + await websocket.send("识别失败:无法获取文本结果")
  263 +
  264 + except Exception as e:
  265 + print(f"处理分块音频时出错: {e}")
  266 + if websocket.open:
  267 + await websocket.send(f"分块音频识别错误: {str(e)}")
  268 + finally:
  269 + # 注释掉临时文件删除操作,保留用于测试
  270 + # if 'audio_file_path' in locals() and os.path.exists(audio_file_path):
  271 + # os.remove(audio_file_path)
  272 + if 'audio_file_path' in locals():
  273 + print(f"保留分块重组音频文件用于测试: {audio_file_path}")
  274 +
  275 +async def process_audio_data(websocket, data):
  276 + """处理音频数据"""
  277 + import base64
  278 + import tempfile
  279 +
  280 + try:
  281 + # 获取音频数据
  282 + audio_data = data.get('audio_data')
  283 + filename = data.get('filename', 'audio.wav')
  284 +
  285 + if not audio_data:
  286 + await websocket.send(json.dumps({"error": "No audio data provided"}))
  287 + return
  288 +
  289 + # 解码Base64音频数据
  290 + audio_bytes = base64.b64decode(audio_data)
  291 +
  292 + # 创建临时文件
  293 + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
  294 + temp_file.write(audio_bytes)
  295 + temp_path = temp_file.name
  296 +
  297 + print(f"处理音频文件: {filename}, 临时路径: {temp_path}")
  298 +
  299 + # 热词配置
  300 + param_dict = {"sentence_timestamp": False}
  301 + try:
  302 + with open("data/hotword.txt", "r", encoding="utf-8") as f:
  303 + lines = f.readlines()
  304 + lines = [line.strip() for line in lines]
  305 + hotword = " ".join(lines)
  306 + print(f"热词:{hotword}")
  307 + param_dict["hotword"] = hotword
  308 + except FileNotFoundError:
  309 + print("热词文件不存在,跳过热词配置")
  310 +
  311 + # 进行语音识别
  312 + res = asr_model.generate(input=temp_path, is_final=True, **param_dict)
  313 + if res and websocket.open:
  314 + if 'text' in res[0]:
  315 + result_text = res[0]['text']
  316 + print(f"识别结果: {result_text}")
  317 + await websocket.send(result_text)
  318 + else:
  319 + await websocket.send("识别失败:无法获取文本结果")
  320 +
  321 + except Exception as e:
  322 + print(f"处理音频数据时出错: {e}")
  323 + if websocket.open:
  324 + await websocket.send(f"识别错误: {str(e)}")
  325 + finally:
  326 + # 注释掉临时文件删除操作,保留用于测试
  327 + # if 'temp_path' in locals() and os.path.exists(temp_path):
  328 + # os.remove(temp_path)
  329 + if 'temp_path' in locals():
  330 + print(f"保留临时音频文件用于测试: {temp_path}")
82 331
83 async def main(): 332 async def main():
84 server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10) 333 server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10)
@@ -87,6 +336,7 @@ async def main(): @@ -87,6 +336,7 @@ async def main():
87 try: 336 try:
88 # 保持服务器运行,直到被手动中断 337 # 保持服务器运行,直到被手动中断
89 print(f"ASR服务器已启动,监听地址: {args.host}:{args.port}") 338 print(f"ASR服务器已启动,监听地址: {args.host}:{args.port}")
  339 + print("注意:此版本已禁用文件自动删除功能,用于测试分析")
90 await asyncio.Future() # 永久等待,直到程序被中断 340 await asyncio.Future() # 永久等待,直到程序被中断
91 except asyncio.CancelledError: 341 except asyncio.CancelledError:
92 print("服务器正在关闭...") 342 print("服务器正在关闭...")
@@ -101,4 +351,11 @@ async def main(): @@ -101,4 +351,11 @@ async def main():
101 await server.wait_closed() 351 await server.wait_closed()
102 352
103 # 使用 asyncio 运行主函数 353 # 使用 asyncio 运行主函数
104 -asyncio.run(main()) 354 +try:
  355 + asyncio.run(main())
  356 +except KeyboardInterrupt:
  357 + logging.info("服务器已关闭")
  358 +except Exception as e:
  359 + logging.error(f"服务器启动失败: {e}")
  360 + import traceback
  361 + traceback.print_exc()