ASR WebSocket服务实现:1.FunASR本地方案 2.豆包
1.已实现音频文件的处理,包括小文件直接转换以及大文件分割识别。 2.豆包接入流式识别,但封装仍需要修改
Showing
28 changed files
with
3287 additions
and
1242 deletions
Too many changes to show.
To preserve performance only 28 of 28+ files are displayed.
| @@ -49,120 +49,42 @@ import gc | @@ -49,120 +49,42 @@ import gc | ||
| 49 | import weakref | 49 | import weakref |
| 50 | import time | 50 | import time |
| 51 | 51 | ||
| 52 | +# 注意:server_recording_api模块已移除,相关功能已迁移到其他模块 | ||
| 53 | +# 导入新的统一WebSocket管理架构 | ||
| 54 | +from core.app_websocket_migration import ( | ||
| 55 | + get_app_websocket_migration, | ||
| 56 | + initialize_app_websocket_migration, | ||
| 57 | + setup_app_websocket_routes, | ||
| 58 | + broadcast_message_to_session, | ||
| 59 | + handle_asr_audio_data, | ||
| 60 | + handle_start_asr_recognition, | ||
| 61 | + handle_stop_asr_recognition, | ||
| 62 | + send_asr_result, | ||
| 63 | + send_normal_asr_result | ||
| 64 | +) | ||
| 52 | 65 | ||
| 53 | app = Flask(__name__) | 66 | app = Flask(__name__) |
| 54 | #sockets = Sockets(app) | 67 | #sockets = Sockets(app) |
| 55 | nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal | 68 | nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal |
| 56 | -websocket_connections:Dict[int, weakref.WeakSet] = {} #sessionid:websocket_connections | 69 | +# WebSocket连接管理已迁移到统一架构 |
| 70 | +# websocket_connections和asr_connections现在通过迁移层管理 | ||
| 71 | +# 全局事件循环引用,用于跨线程异步调用 | ||
| 72 | +main_event_loop = None | ||
| 57 | opt = None | 73 | opt = None |
| 58 | model = None | 74 | model = None |
| 59 | avatar = None | 75 | avatar = None |
| 76 | +# WebSocket迁移实例 | ||
| 77 | +websocket_migration = None | ||
| 60 | 78 | ||
| 61 | 79 | ||
| 62 | #####webrtc############################### | 80 | #####webrtc############################### |
| 63 | pcs = set() | 81 | pcs = set() |
| 64 | 82 | ||
| 65 | -# WebSocket消息推送函数 | ||
| 66 | -async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "数字人回复", model_info: str = None, request_source: str = "页面"): | ||
| 67 | - """向指定会话的所有WebSocket连接推送消息""" | ||
| 68 | - logger.info(f'[SessionID:{sessionid}] 开始推送消息: {message_type}, source: {source}, content: {content[:50]}...') | ||
| 69 | - logger.info(f'[SessionID:{sessionid}] 当前websocket_connections keys: {list(websocket_connections.keys())}') | ||
| 70 | - | ||
| 71 | - if sessionid not in websocket_connections: | ||
| 72 | - logger.warning(f'[SessionID:{sessionid}] 会话不存在于websocket_connections中') | ||
| 73 | - return | ||
| 74 | - | ||
| 75 | - logger.info(f'[SessionID:{sessionid}] 找到会话,连接数量: {len(websocket_connections[sessionid])}') | ||
| 76 | - | ||
| 77 | - message = { | ||
| 78 | - "type": "chat_message", | ||
| 79 | - "data": { | ||
| 80 | - "sessionid": sessionid, | ||
| 81 | - "message_type": message_type, | ||
| 82 | - "content": content, | ||
| 83 | - "source": source, | ||
| 84 | - "model_info": model_info, | ||
| 85 | - "request_source": request_source, | ||
| 86 | - "timestamp": time.time() | ||
| 87 | - } | ||
| 88 | - } | ||
| 89 | - | ||
| 90 | - # 获取该会话的所有WebSocket连接 | ||
| 91 | - connections = list(websocket_connections[sessionid]) | ||
| 92 | - | ||
| 93 | - # 向所有连接发送消息 | ||
| 94 | - logger.info(f'[SessionID:{sessionid}] 准备向{len(connections)}个连接发送消息') | ||
| 95 | - for i, ws in enumerate(connections): | ||
| 96 | - try: | ||
| 97 | - logger.info(f'[SessionID:{sessionid}] 检查连接{i+1}: closed={ws.closed}') | ||
| 98 | - if not ws.closed: | ||
| 99 | - logger.info(f'[SessionID:{sessionid}] 向连接{i+1}发送消息: {json.dumps(message)}') | ||
| 100 | - await ws.send_str(json.dumps(message)) | ||
| 101 | - logger.info(f'[SessionID:{sessionid}] 连接{i+1}消息发送成功: {message_type} from {request_source}') | ||
| 102 | - else: | ||
| 103 | - logger.warning(f'[SessionID:{sessionid}] 连接{i+1}已关闭,跳过发送') | ||
| 104 | - except Exception as e: | ||
| 105 | - logger.error(f'[SessionID:{sessionid}] 连接{i+1}发送失败: {e}') | ||
| 106 | - | ||
| 107 | -# WebSocket处理器 | ||
| 108 | -async def websocket_handler(request): | ||
| 109 | - """处理WebSocket连接""" | ||
| 110 | - ws = web.WebSocketResponse() | ||
| 111 | - await ws.prepare(request) | ||
| 112 | - | ||
| 113 | - sessionid = None | ||
| 114 | - logger.info('New WebSocket connection established') | ||
| 115 | - | ||
| 116 | - try: | ||
| 117 | - async for msg in ws: | ||
| 118 | - if msg.type == WSMsgType.TEXT: | ||
| 119 | - try: | ||
| 120 | - data = json.loads(msg.data) | ||
| 121 | - | ||
| 122 | - if data.get('type') == 'login': | ||
| 123 | - sessionid = data.get('sessionid', 0) | ||
| 124 | - logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(websocket_connections.keys())}') | ||
| 125 | - | ||
| 126 | - # 初始化该会话的WebSocket连接集合 | ||
| 127 | - if sessionid not in websocket_connections: | ||
| 128 | - websocket_connections[sessionid] = weakref.WeakSet() | ||
| 129 | - logger.info(f'[SessionID:{sessionid}] 创建新的连接集合') | ||
| 130 | - | ||
| 131 | - # 添加当前连接到会话 | ||
| 132 | - websocket_connections[sessionid].add(ws) | ||
| 133 | - logger.info(f'[SessionID:{sessionid}] 连接已添加,当前会话连接数: {len(websocket_connections[sessionid])}') | ||
| 134 | - | ||
| 135 | - logger.info(f'[SessionID:{sessionid}] WebSocket client logged in') | ||
| 136 | - | ||
| 137 | - # 发送登录确认 | ||
| 138 | - await ws.send_str(json.dumps({ | ||
| 139 | - "type": "login_success", | ||
| 140 | - "sessionid": sessionid, | ||
| 141 | - "message": "WebSocket连接成功" | ||
| 142 | - })) | ||
| 143 | - | ||
| 144 | - elif data.get('type') == 'ping': | ||
| 145 | - # 心跳检测 | ||
| 146 | - await ws.send_str(json.dumps({"type": "pong"})) | ||
| 147 | - | ||
| 148 | - except json.JSONDecodeError: | ||
| 149 | - logger.error('Invalid JSON received from WebSocket') | ||
| 150 | - except Exception as e: | ||
| 151 | - logger.error(f'Error processing WebSocket message: {e}') | ||
| 152 | - | ||
| 153 | - elif msg.type == WSMsgType.ERROR: | ||
| 154 | - logger.error(f'WebSocket error: {ws.exception()}') | ||
| 155 | - break | ||
| 156 | - | ||
| 157 | - except Exception as e: | ||
| 158 | - logger.error(f'WebSocket connection error: {e}') | ||
| 159 | - finally: | ||
| 160 | - if sessionid is not None: | ||
| 161 | - logger.info(f'[SessionID:{sessionid}] WebSocket connection closed') | ||
| 162 | - else: | ||
| 163 | - logger.info('WebSocket connection closed') | 83 | +# WebSocket消息推送函数已迁移到统一架构 |
| 84 | +# 通过 core.app_websocket_migration 模块提供兼容性接口 | ||
| 164 | 85 | ||
| 165 | - return ws | 86 | +# WebSocket处理器已迁移到统一架构 |
| 87 | +# 通过 core.app_websocket_migration 模块提供 | ||
| 166 | 88 | ||
| 167 | def randN(N)->int: | 89 | def randN(N)->int: |
| 168 | '''生成长度为 N的随机数 ''' | 90 | '''生成长度为 N的随机数 ''' |
| @@ -456,41 +378,187 @@ async def interrupt_talk(request): | @@ -456,41 +378,187 @@ async def interrupt_talk(request): | ||
| 456 | ) | 378 | ) |
| 457 | from pydub import AudioSegment | 379 | from pydub import AudioSegment |
| 458 | from io import BytesIO | 380 | from io import BytesIO |
| 459 | -async def humanaudio(request): | 381 | + |
| 382 | +async def ensure_asr_connection(sessionid: int) -> bool: | ||
| 383 | + """确保ASR连接可用""" | ||
| 384 | + # 通过迁移实例获取ASR连接 | ||
| 385 | + migration = get_app_websocket_migration() | ||
| 386 | + if sessionid not in migration.asr_connections: | ||
| 387 | + return await create_asr_connection(sessionid) | ||
| 388 | + | ||
| 389 | + asr_client = migration.asr_connections[sessionid] | ||
| 390 | + | ||
| 391 | + # 检查连接状态 | ||
| 392 | + if not asr_client.is_connected(): | ||
| 393 | + logger.warning(f"[SessionID:{sessionid}] ASR连接已断开,尝试重连") | ||
| 460 | try: | 394 | try: |
| 461 | - params = await request.json() | ||
| 462 | - sessionid = int(params.get('sessionid', 0)) | ||
| 463 | - fileobj = params.get('file_url') | 395 | + # 重新连接 |
| 396 | + success = await asyncio.get_event_loop().run_in_executor( | ||
| 397 | + None, asr_client.connect | ||
| 398 | + ) | ||
| 399 | + if success: | ||
| 400 | + logger.info(f"[SessionID:{sessionid}] ASR重连成功") | ||
| 401 | + return True | ||
| 402 | + else: | ||
| 403 | + logger.error(f"[SessionID:{sessionid}] ASR重连失败") | ||
| 404 | + # 清理失效连接 | ||
| 405 | + migration = get_app_websocket_migration() | ||
| 406 | + if sessionid in migration.asr_connections: | ||
| 407 | + del migration.asr_connections[sessionid] | ||
| 408 | + return False | ||
| 409 | + except Exception as e: | ||
| 410 | + logger.error(f"[SessionID:{sessionid}] ASR重连异常: {e}") | ||
| 411 | + del asr_connections[sessionid] | ||
| 412 | + return False | ||
| 464 | 413 | ||
| 465 | - # 获取音频文件数据 | ||
| 466 | - if isinstance(fileobj, str) and fileobj.startswith("http"): | ||
| 467 | - async with aiohttp.ClientSession() as session: | ||
| 468 | - async with session.get(fileobj) as response: | ||
| 469 | - if response.status == 200: | ||
| 470 | - filebytes = await response.read() | 414 | + return True |
| 415 | + | ||
| 416 | +async def create_asr_connection(sessionid: int) -> bool: | ||
| 417 | + """创建新的ASR连接""" | ||
| 418 | + try: | ||
| 419 | + from funasr_asr_sync import FunASRSync | ||
| 420 | + username = f'User_{sessionid}' # 修复大小写不一致:user_ -> User_ | ||
| 421 | + asr_client = FunASRSync(username) | ||
| 422 | + | ||
| 423 | + # 设置结果回调 | ||
| 424 | + def on_asr_result(result): | ||
| 425 | + if isinstance(result, str): | ||
| 426 | + result_data = { | ||
| 427 | + 'text': result, | ||
| 428 | + 'is_final': True, | ||
| 429 | + 'confidence': 1.0 | ||
| 430 | + } | ||
| 431 | + else: | ||
| 432 | + result_data = result | ||
| 433 | + | ||
| 434 | + # 线程安全地调度异步任务 | ||
| 435 | + try: | ||
| 436 | + # 优先使用全局事件循环引用 | ||
| 437 | + if main_event_loop is not None and not main_event_loop.is_closed(): | ||
| 438 | + # 使用全局事件循环进行跨线程调用 | ||
| 439 | + asyncio.run_coroutine_threadsafe( | ||
| 440 | + # send_asr_result(sessionid, result_data), main_event_loop | ||
| 441 | + send_normal_asr_result(sessionid, result_data), main_event_loop | ||
| 442 | + ) | ||
| 443 | + logger.debug(f"[SessionID:{sessionid}] 使用全局事件循环发送ASR结果") | ||
| 444 | + else: | ||
| 445 | + # 降级处理:尝试获取当前线程的事件循环 | ||
| 446 | + try: | ||
| 447 | + loop = asyncio.get_event_loop() | ||
| 448 | + if loop.is_running(): | ||
| 449 | + loop.call_soon_threadsafe( | ||
| 450 | + lambda: asyncio.create_task(send_normal_asr_result(sessionid, result_data)) | ||
| 451 | + ) | ||
| 452 | + else: | ||
| 453 | + asyncio.create_task(send_normal_asr_result(sessionid, result_data)) | ||
| 454 | + except RuntimeError: | ||
| 455 | + # 最终降级:仅记录日志 | ||
| 456 | + logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}") | ||
| 457 | + logger.warning(f"[SessionID:{sessionid}] 无法发送ASR结果到客户端,事件循环不可用") | ||
| 458 | + except Exception as e: | ||
| 459 | + logger.error(f"[SessionID:{sessionid}] ASR结果处理异常: {e}") | ||
| 460 | + # 至少记录识别结果 | ||
| 461 | + logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}") | ||
| 462 | + | ||
| 463 | + asr_client.set_result_callback(on_asr_result) | ||
| 464 | + | ||
| 465 | + # 异步连接 | ||
| 466 | + success = await asyncio.get_event_loop().run_in_executor( | ||
| 467 | + None, asr_client.connect | ||
| 468 | + ) | ||
| 469 | + | ||
| 470 | + if success: | ||
| 471 | + # 通过迁移实例存储ASR连接 | ||
| 472 | + migration = get_app_websocket_migration() | ||
| 473 | + migration.asr_connections[sessionid] = asr_client | ||
| 474 | + logger.info(f"[SessionID:{sessionid}] ASR连接创建成功") | ||
| 475 | + return True | ||
| 471 | else: | 476 | else: |
| 477 | + logger.error(f"[SessionID:{sessionid}] ASR连接创建失败") | ||
| 478 | + return False | ||
| 479 | + | ||
| 480 | + except Exception as e: | ||
| 481 | + logger.error(f"[SessionID:{sessionid}] 创建ASR连接异常: {e}") | ||
| 482 | + return False | ||
| 483 | + | ||
| 484 | +async def humanaudio(request): | ||
| 485 | + try: | ||
| 486 | + # 检查请求内容类型,支持FormData和JSON两种格式 | ||
| 487 | + content_type = request.headers.get('content-type', '') | ||
| 488 | + | ||
| 489 | + | ||
| 490 | + # 处理FormData格式(文件上传) | ||
| 491 | + reader = await request.multipart() | ||
| 492 | + sessionid = 0 | ||
| 493 | + fileobj = None | ||
| 494 | + # 默认启用语音本地服务 | ||
| 495 | + asr_service = "funasr" | ||
| 496 | + | ||
| 497 | + # 读取FormData字段 | ||
| 498 | + async for field in reader: | ||
| 499 | + if field.name == 'sessionid': | ||
| 500 | + sessionid = int(await field.text()) | ||
| 501 | + logger.info(f'Parsed sessionid: {sessionid}') | ||
| 502 | + elif field.name == 'audio': | ||
| 503 | + fileobj = field | ||
| 504 | + filename = field.filename | ||
| 505 | + filebytes = await field.read() | ||
| 506 | + # 输出文件大小信息 | ||
| 507 | + logger.info(f'Audio file content size: {len(filebytes)} bytes') | ||
| 508 | + if not fileobj: | ||
| 472 | return web.Response( | 509 | return web.Response( |
| 473 | content_type="application/json", | 510 | content_type="application/json", |
| 474 | - text=json.dumps({"code": -1, "msg": "Error downloading file"}) | 511 | + text=json.dumps({"code": -1, "msg": "No audio file provided"}) |
| 475 | ) | 512 | ) |
| 476 | - # 根据 URL 后缀判断是否为 MP3 文件 | ||
| 477 | - is_mp3 = fileobj.lower().endswith('.mp3') | ||
| 478 | - else: | ||
| 479 | - filename = fileobj.filename | ||
| 480 | - filebytes = fileobj.file.read() | ||
| 481 | - is_mp3 = filename.lower().endswith('.mp3') | 513 | + elif field.name == 'asr_service': |
| 514 | + asr_service = (await field.text()).strip().lower() | ||
| 482 | 515 | ||
| 516 | + # 根据文件名判断是否为 MP3 文件 | ||
| 517 | + is_mp3 = filename.lower().endswith('.mp3') if filename else False | ||
| 518 | + | ||
| 519 | + # 处理MP3转WAV | ||
| 483 | if is_mp3: | 520 | if is_mp3: |
| 484 | - audio = AudioSegment.from_file(BytesIO(filebytes), format="mp3") | 521 | + try: |
| 522 | + with BytesIO(filebytes) as audio_buffer: | ||
| 523 | + audio = AudioSegment.from_file(audio_buffer, format="mp3") | ||
| 485 | out_io = BytesIO() | 524 | out_io = BytesIO() |
| 486 | audio.export(out_io, format="wav") | 525 | audio.export(out_io, format="wav") |
| 487 | filebytes = out_io.getvalue() | 526 | filebytes = out_io.getvalue() |
| 527 | + except Exception as e: | ||
| 528 | + logger.error(f"[SessionID:{sessionid}] 音频处理失败: {e}") | ||
| 529 | + raise | ||
| 530 | + | ||
| 531 | + # 获取WebSocket迁移实例来访问连接信息 | ||
| 532 | + migration = get_app_websocket_migration() | ||
| 533 | + active_sessions = migration.get_websocket_connections() | ||
| 534 | + logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(active_sessions.keys())}') | ||
| 535 | + # 验证sessionid是否存在 | ||
| 536 | + if sessionid not in nerfreals: | ||
| 537 | + return web.Response( | ||
| 538 | + content_type="application/json", | ||
| 539 | + text=json.dumps({"code": -1, "msg": f"Session {sessionid} not found. Please establish WebRTC connection first."}) | ||
| 540 | + ) | ||
| 488 | 541 | ||
| 542 | + # 发送音频数据进行处理 数字人播报 | ||
| 489 | nerfreals[sessionid].put_audio_file(filebytes) | 543 | nerfreals[sessionid].put_audio_file(filebytes) |
| 490 | 544 | ||
| 545 | + | ||
| 546 | + # ---------- ASR 分流 ---------- | ||
| 547 | + if asr_service == 'funasr': | ||
| 548 | + await handle_funasr(sessionid, filebytes) | ||
| 549 | + elif asr_service == 'doubao': | ||
| 550 | + await handle_doubao(sessionid, filebytes) | ||
| 551 | + else: | ||
| 552 | + logger.warning(f'[SessionID:{sessionid}] 未指定或未知 asr_service,跳过 ASR') | ||
| 553 | + | ||
| 554 | + | ||
| 555 | + | ||
| 556 | + # 通过迁移实例检查ASR连接状态 | ||
| 557 | + migration = get_app_websocket_migration() | ||
| 558 | + asr_enabled = sessionid in migration.asr_connections | ||
| 491 | return web.Response( | 559 | return web.Response( |
| 492 | content_type="application/json", | 560 | content_type="application/json", |
| 493 | - text=json.dumps({"code": 0, "msg": "ok"}) | 561 | + text=json.dumps({"code": 0, "msg": "ok", "asr_enabled": asr_enabled}) |
| 494 | ) | 562 | ) |
| 495 | 563 | ||
| 496 | except Exception as e: | 564 | except Exception as e: |
| @@ -500,6 +568,66 @@ async def humanaudio(request): | @@ -500,6 +568,66 @@ async def humanaudio(request): | ||
| 500 | text=json.dumps( {"code": -1, "msg": str(e)}) | 568 | text=json.dumps( {"code": -1, "msg": str(e)}) |
| 501 | ) | 569 | ) |
| 502 | 570 | ||
| 571 | +async def handle_funasr(sessionid: int, audio_bytes: bytes): | ||
| 572 | + # ASR识别处理 - 使用新的连接管理机制 | ||
| 573 | + try: | ||
| 574 | + # 确保ASR连接可用 | ||
| 575 | + asr_available = await ensure_asr_connection(sessionid) | ||
| 576 | + | ||
| 577 | + if asr_available: | ||
| 578 | + # 发送音频数据到ASR服务进行识别 | ||
| 579 | + # 通过迁移实例获取ASR连接 | ||
| 580 | + migration = get_app_websocket_migration() | ||
| 581 | + asr_client = migration.asr_connections[sessionid] | ||
| 582 | + if hasattr(asr_client, 'send_audio_data'): | ||
| 583 | + asr_client.send_audio_data(audio_bytes) | ||
| 584 | + logger.info(f'[SessionID:{sessionid}] 音频数据已发送到ASR服务进行识别') | ||
| 585 | + else: | ||
| 586 | + logger.warning(f'[SessionID:{sessionid}] ASR客户端不支持send_audio_data方法') | ||
| 587 | + else: | ||
| 588 | + logger.warning(f'[SessionID:{sessionid}] ASR连接不可用,跳过语音识别') | ||
| 589 | + | ||
| 590 | + except Exception as asr_error: | ||
| 591 | + logger.error(f'[SessionID:{sessionid}] ASR处理错误: {asr_error}') | ||
| 592 | + # ASR错误不影响主要功能,继续返回成功 | ||
| 593 | + | ||
| 594 | +# 导入 Doubao ASR 服务 | ||
| 595 | +from asr.doubao.service_factory import recognize_audio_data | ||
| 596 | +import os | ||
| 597 | +import json | ||
| 598 | + | ||
| 599 | +async def handle_doubao(sessionid: int, audio_bytes: bytes): | ||
| 600 | + """云端 Doubao 调用""" | ||
| 601 | + try: | ||
| 602 | + logger.info(f"[SessionID:{sessionid}] 使用云端 Doubao 识别") | ||
| 603 | + | ||
| 604 | + # 读取豆包ASR配置文件 | ||
| 605 | + config_path = os.path.join(os.path.dirname(__file__), 'asr', 'doubao', 'config.json') | ||
| 606 | + with open(config_path, 'r', encoding='utf-8') as f: | ||
| 607 | + config = json.load(f) | ||
| 608 | + | ||
| 609 | + # 获取认证配置 | ||
| 610 | + auth_config = config.get('auth_config', {}) | ||
| 611 | + app_key = auth_config.get('app_key') | ||
| 612 | + access_key = auth_config.get('access_key') | ||
| 613 | + | ||
| 614 | + if not app_key or not access_key: | ||
| 615 | + raise ValueError("豆包ASR认证配置缺失:app_key 或 access_key 未配置") | ||
| 616 | + | ||
| 617 | + result = await recognize_audio_data( | ||
| 618 | + audio_data=audio_bytes, | ||
| 619 | + app_key=app_key, | ||
| 620 | + access_key=access_key, | ||
| 621 | + streaming=True, | ||
| 622 | + result_callback=lambda res: logger.info(f"[SessionID:{sessionid}] Doubao 识别结果: {res}") | ||
| 623 | + ) | ||
| 624 | + return result | ||
| 625 | + except Exception as e: | ||
| 626 | + logger.error(f"[SessionID:{sessionid}] Doubao 错误: {e}") | ||
| 627 | + raise | ||
| 628 | + | ||
| 629 | + | ||
| 630 | + | ||
| 503 | async def set_audiotype(request): | 631 | async def set_audiotype(request): |
| 504 | try: | 632 | try: |
| 505 | params = await request.json() | 633 | params = await request.json() |
| @@ -787,6 +915,10 @@ if __name__ == '__main__': | @@ -787,6 +915,10 @@ if __name__ == '__main__': | ||
| 787 | rendthrd.start() | 915 | rendthrd.start() |
| 788 | 916 | ||
| 789 | ############################################################################# | 917 | ############################################################################# |
| 918 | + # ASR处理函数已迁移到统一架构 | ||
| 919 | + # 通过 core.app_websocket_migration 模块提供 | ||
| 920 | + | ||
| 921 | + ############################################################################# | ||
| 790 | appasync = web.Application() | 922 | appasync = web.Application() |
| 791 | appasync.on_shutdown.append(on_shutdown) | 923 | appasync.on_shutdown.append(on_shutdown) |
| 792 | appasync.router.add_post("/offer", offer) | 924 | appasync.router.add_post("/offer", offer) |
| @@ -796,9 +928,27 @@ if __name__ == '__main__': | @@ -796,9 +928,27 @@ if __name__ == '__main__': | ||
| 796 | appasync.router.add_post("/record", record) | 928 | appasync.router.add_post("/record", record) |
| 797 | appasync.router.add_post("/interrupt_talk", interrupt_talk) | 929 | appasync.router.add_post("/interrupt_talk", interrupt_talk) |
| 798 | appasync.router.add_post("/is_speaking", is_speaking) | 930 | appasync.router.add_post("/is_speaking", is_speaking) |
| 799 | - appasync.router.add_get("/ws", websocket_handler) | 931 | + |
| 932 | + # 初始化统一WebSocket管理架构 | ||
| 933 | + websocket_migration = get_app_websocket_migration() | ||
| 934 | + | ||
| 935 | + # 注册WebSocket接口 - 使用新的统一架构 | ||
| 936 | + setup_app_websocket_routes(appasync) | ||
| 937 | + | ||
| 938 | + # 异步初始化将在服务器启动时进行 | ||
| 939 | + async def init_websocket_migration(): | ||
| 940 | + await initialize_app_websocket_migration() | ||
| 941 | + logger.info("WebSocket迁移架构初始化完成") | ||
| 942 | + | ||
| 943 | + # 添加启动时初始化 | ||
| 944 | + appasync.on_startup.append(lambda app: init_websocket_migration()) | ||
| 945 | + | ||
| 800 | appasync.router.add_static('/',path='web') | 946 | appasync.router.add_static('/',path='web') |
| 801 | 947 | ||
| 948 | + # 服务端录音WebSocket接口已集成到统一架构中 | ||
| 949 | + # 通过 /ws 路由和消息类型区分访问:wsa_register_web, wsa_register_human 等 | ||
| 950 | + logger.info("主应用路由配置完成,WebSocket接口已统一到 /ws 路由") | ||
| 951 | + | ||
| 802 | # Configure default CORS settings. | 952 | # Configure default CORS settings. |
| 803 | cors = aiohttp_cors.setup(appasync, defaults={ | 953 | cors = aiohttp_cors.setup(appasync, defaults={ |
| 804 | "*": aiohttp_cors.ResourceOptions( | 954 | "*": aiohttp_cors.ResourceOptions( |
| @@ -819,8 +969,13 @@ if __name__ == '__main__': | @@ -819,8 +969,13 @@ if __name__ == '__main__': | ||
| 819 | logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename) | 969 | logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename) |
| 820 | logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html') | 970 | logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html') |
| 821 | def run_server(runner): | 971 | def run_server(runner): |
| 972 | + global main_event_loop | ||
| 822 | loop = asyncio.new_event_loop() | 973 | loop = asyncio.new_event_loop() |
| 823 | asyncio.set_event_loop(loop) | 974 | asyncio.set_event_loop(loop) |
| 975 | + # 设置全局事件循环引用,用于跨线程异步调用 | ||
| 976 | + main_event_loop = loop | ||
| 977 | + logger.info("全局事件循环引用已设置") | ||
| 978 | + | ||
| 824 | loop.run_until_complete(runner.setup()) | 979 | loop.run_until_complete(runner.setup()) |
| 825 | site = web.TCPSite(runner, '0.0.0.0', opt.listenport) | 980 | site = web.TCPSite(runner, '0.0.0.0', opt.listenport) |
| 826 | loop.run_until_complete(site.start()) | 981 | loop.run_until_complete(site.start()) |
| 1 | # -*- coding: utf-8 -*- | 1 | # -*- coding: utf-8 -*- |
| 2 | """ | 2 | """ |
| 3 | -AIfeng/2025-01-27 | 3 | +AIfeng/2025-06-30 |
| 4 | 配置管理工具模块 | 4 | 配置管理工具模块 |
| 5 | 统一管理项目配置参数 | 5 | 统一管理项目配置参数 |
| 6 | """ | 6 | """ |
| @@ -24,6 +24,9 @@ class ConfigManager: | @@ -24,6 +24,9 @@ class ConfigManager: | ||
| 24 | 'local_asr_ip': '127.0.0.1', | 24 | 'local_asr_ip': '127.0.0.1', |
| 25 | 'local_asr_port': 10197, | 25 | 'local_asr_port': 10197, |
| 26 | 26 | ||
| 27 | + # 音频设备配置 | ||
| 28 | + 'local_audio_ip': '127.0.0.1', | ||
| 29 | + | ||
| 27 | # 阿里云NLS配置 | 30 | # 阿里云NLS配置 |
| 28 | 'key_ali_nls_key_id': '', | 31 | 'key_ali_nls_key_id': '', |
| 29 | 'key_ali_nls_key_secret': '', | 32 | 'key_ali_nls_key_secret': '', |
| @@ -82,6 +85,7 @@ _config_manager = ConfigManager() | @@ -82,6 +85,7 @@ _config_manager = ConfigManager() | ||
| 82 | # 兼容原有的属性访问方式 | 85 | # 兼容原有的属性访问方式 |
| 83 | local_asr_ip = _config_manager.local_asr_ip | 86 | local_asr_ip = _config_manager.local_asr_ip |
| 84 | local_asr_port = _config_manager.local_asr_port | 87 | local_asr_port = _config_manager.local_asr_port |
| 88 | +local_audio_ip = _config_manager.local_audio_ip | ||
| 85 | key_ali_nls_key_id = _config_manager.key_ali_nls_key_id | 89 | key_ali_nls_key_id = _config_manager.key_ali_nls_key_id |
| 86 | key_ali_nls_key_secret = _config_manager.key_ali_nls_key_secret | 90 | key_ali_nls_key_secret = _config_manager.key_ali_nls_key_secret |
| 87 | key_ali_nls_app_key = _config_manager.key_ali_nls_app_key | 91 | key_ali_nls_app_key = _config_manager.key_ali_nls_app_key |
| 1 | # -*- coding: utf-8 -*- | 1 | # -*- coding: utf-8 -*- |
| 2 | """ | 2 | """ |
| 3 | -AIfeng/2025-01-27 | ||
| 4 | -Core模块初始化文件 | 3 | +AIfeng/2025-07-17 14:15:27 |
| 4 | +Core模块初始化 | ||
| 5 | +已迁移到统一WebSocket架构,移除旧的wsa_server依赖 | ||
| 5 | """ | 6 | """ |
| 6 | 7 | ||
| 7 | -from .wsa_server import get_web_instance, get_instance | 8 | +# 统一WebSocket架构导入 |
| 9 | +from .wsa_websocket_service import get_web_instance, get_instance | ||
| 8 | 10 | ||
| 9 | __all__ = ['get_web_instance', 'get_instance'] | 11 | __all__ = ['get_web_instance', 'get_instance'] |
core/app_websocket_migration.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +app.py WebSocket功能迁移脚本 | ||
| 5 | +将app.py中的WebSocket功能迁移到统一架构 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +import json | ||
| 10 | +import weakref | ||
| 11 | +from typing import Dict, Any, Optional | ||
| 12 | +from aiohttp import web | ||
| 13 | +from logger import logger | ||
| 14 | +from .websocket_router import get_websocket_router, get_websocket_compatibility_api | ||
| 15 | +from .asr_websocket_service import get_asr_service | ||
| 16 | +from .digital_human_websocket_service import get_digital_human_service | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +class AppWebSocketMigration: | ||
| 20 | + """app.py WebSocket功能迁移类""" | ||
| 21 | + | ||
| 22 | + def __init__(self): | ||
| 23 | + self.router = get_websocket_router() | ||
| 24 | + self.compatibility_api = get_websocket_compatibility_api() | ||
| 25 | + self.asr_service = get_asr_service() | ||
| 26 | + self.digital_human_service = get_digital_human_service() | ||
| 27 | + | ||
| 28 | + # 兼容性变量(保持与原app.py的接口一致) | ||
| 29 | + self.websocket_connections = {} | ||
| 30 | + self.asr_connections = {} | ||
| 31 | + | ||
| 32 | + async def initialize(self): | ||
| 33 | + """初始化迁移组件""" | ||
| 34 | + await self.router.initialize() | ||
| 35 | + logger.info('WebSocket迁移组件初始化完成') | ||
| 36 | + | ||
| 37 | + async def shutdown(self): | ||
| 38 | + """关闭迁移组件""" | ||
| 39 | + await self.router.shutdown() | ||
| 40 | + logger.info('WebSocket迁移组件已关闭') | ||
| 41 | + | ||
| 42 | + def setup_routes(self, app: web.Application): | ||
| 43 | + """设置路由(替换原app.py中的WebSocket路由)""" | ||
| 44 | + # 使用新的统一WebSocket处理器 | ||
| 45 | + self.router.setup_routes(app, '/ws') | ||
| 46 | + | ||
| 47 | + # 添加兼容性路由(如果需要) | ||
| 48 | + app.router.add_get('/ws_legacy', self._legacy_websocket_handler) | ||
| 49 | + | ||
| 50 | + async def _legacy_websocket_handler(self, request: web.Request): | ||
| 51 | + """兼容性WebSocket处理器(保持原有接口)""" | ||
| 52 | + # 直接转发到新的统一处理器 | ||
| 53 | + return await self.router.websocket_handler(request) | ||
| 54 | + | ||
| 55 | + # 兼容性接口方法 | ||
| 56 | + async def broadcast_message_to_session(self, sessionid: int, message_type: str, | ||
| 57 | + content: str, source: str = "数字人回复", | ||
| 58 | + model_info: str = None, request_source: str = "页面"): | ||
| 59 | + """兼容原app.py的消息推送接口""" | ||
| 60 | + message_data = { | ||
| 61 | + "sessionid": sessionid, | ||
| 62 | + "message_type": message_type, | ||
| 63 | + "content": content, | ||
| 64 | + "source": source, | ||
| 65 | + "model_info": model_info, | ||
| 66 | + "request_source": request_source, | ||
| 67 | + "timestamp": asyncio.get_event_loop().time() | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + return await self.router.send_to_session(str(sessionid), 'chat_message', message_data) | ||
| 71 | + | ||
| 72 | + async def handle_asr_audio_data(self, data: Dict[str, Any], sessionid: int, ws): | ||
| 73 | + """兼容原app.py的ASR音频数据处理""" | ||
| 74 | + # 转换为新架构的消息格式 | ||
| 75 | + message_data = { | ||
| 76 | + 'audio_data': data.get('audio_data'), | ||
| 77 | + 'sessionid': sessionid | ||
| 78 | + } | ||
| 79 | + | ||
| 80 | + # 通过新的ASR服务处理 | ||
| 81 | + session = self.router.manager.get_session(ws) | ||
| 82 | + if session: | ||
| 83 | + await self.asr_service._handle_asr_audio_data(ws, message_data) | ||
| 84 | + | ||
| 85 | + async def handle_start_asr_recognition(self, sessionid: int, ws): | ||
| 86 | + """兼容原app.py的开始ASR识别""" | ||
| 87 | + session = self.router.manager.get_session(ws) | ||
| 88 | + if session: | ||
| 89 | + await self.asr_service._handle_start_asr_recognition(ws, {'sessionid': sessionid}) | ||
| 90 | + | ||
| 91 | + async def handle_stop_asr_recognition(self, sessionid: int, ws): | ||
| 92 | + """兼容原app.py的停止ASR识别""" | ||
| 93 | + session = self.router.manager.get_session(ws) | ||
| 94 | + if session: | ||
| 95 | + await self.asr_service._handle_stop_asr_recognition(ws, {'sessionid': sessionid}) | ||
| 96 | + | ||
| 97 | + async def send_asr_result(self, sessionid: int, result: Dict[str, Any]): | ||
| 98 | + """兼容原app.py的ASR结果发送""" | ||
| 99 | + return await self.router.send_to_session(str(sessionid), 'asr_result', { | ||
| 100 | + "text": result.get('text', ''), | ||
| 101 | + "is_final": result.get('is_final', False), | ||
| 102 | + "confidence": result.get('confidence', 0.0) | ||
| 103 | + }) | ||
| 104 | + | ||
| 105 | + async def send_normal_asr_result(self, sessionid: int, result: Dict[str, Any]): | ||
| 106 | + """业务层决定传输内容以及结构""" | ||
| 107 | + return await self.router.send_raw_to_session(str(sessionid), result) | ||
| 108 | + | ||
| 109 | + | ||
| 110 | + def get_websocket_connections(self): | ||
| 111 | + """获取WebSocket连接(兼容性接口)""" | ||
| 112 | + # 返回兼容性字典格式,键为会话ID,值为WebSocket对象 | ||
| 113 | + sessions_dict = self.router.manager._sessions | ||
| 114 | + result = {} | ||
| 115 | + for session_id, session_set in sessions_dict.items(): | ||
| 116 | + # 取集合中的第一个WebSocket连接(通常每个session_id只有一个连接) | ||
| 117 | + if session_set: | ||
| 118 | + session = next(iter(session_set)) | ||
| 119 | + result[session_id] = session.websocket | ||
| 120 | + return result | ||
| 121 | + | ||
| 122 | + def get_session_count(self): | ||
| 123 | + """获取会话数量(兼容性接口)""" | ||
| 124 | + return self.compatibility_api.get_session_count() | ||
| 125 | + | ||
| 126 | + async def cleanup_session(self, sessionid: int): | ||
| 127 | + """清理会话(兼容性接口)""" | ||
| 128 | + # 清理ASR连接 | ||
| 129 | + if sessionid in self.asr_connections: | ||
| 130 | + del self.asr_connections[sessionid] | ||
| 131 | + | ||
| 132 | + # 通过新架构清理会话 | ||
| 133 | + sessions = self.router.manager._sessions | ||
| 134 | + session_id_str = str(sessionid) | ||
| 135 | + | ||
| 136 | + for ws, session in list(sessions.items()): | ||
| 137 | + if session.session_id == session_id_str: | ||
| 138 | + await self.router.manager.remove_session(ws) | ||
| 139 | + break | ||
| 140 | + | ||
| 141 | + def get_migration_stats(self) -> Dict[str, Any]: | ||
| 142 | + """获取迁移统计信息""" | ||
| 143 | + return { | ||
| 144 | + "router_stats": self.router.get_router_stats(), | ||
| 145 | + "asr_stats": self.asr_service.get_asr_stats(), | ||
| 146 | + "digital_human_stats": self.digital_human_service.get_digital_human_stats(), | ||
| 147 | + "compatibility_sessions": len(self.websocket_connections), | ||
| 148 | + "compatibility_asr_connections": len(self.asr_connections) | ||
| 149 | + } | ||
| 150 | + | ||
| 151 | + | ||
| 152 | +# 全局迁移实例 | ||
| 153 | +_migration_instance = None | ||
| 154 | + | ||
| 155 | + | ||
| 156 | +def get_app_websocket_migration() -> AppWebSocketMigration: | ||
| 157 | + """获取app.py WebSocket迁移实例""" | ||
| 158 | + global _migration_instance | ||
| 159 | + if _migration_instance is None: | ||
| 160 | + _migration_instance = AppWebSocketMigration() | ||
| 161 | + return _migration_instance | ||
| 162 | + | ||
| 163 | + | ||
| 164 | +async def initialize_app_websocket_migration(): | ||
| 165 | + """初始化app.py WebSocket迁移""" | ||
| 166 | + migration = get_app_websocket_migration() | ||
| 167 | + await migration.initialize() | ||
| 168 | + return migration | ||
| 169 | + | ||
| 170 | + | ||
| 171 | +async def shutdown_app_websocket_migration(): | ||
| 172 | + """关闭app.py WebSocket迁移""" | ||
| 173 | + global _migration_instance | ||
| 174 | + if _migration_instance: | ||
| 175 | + await _migration_instance.shutdown() | ||
| 176 | + _migration_instance = None | ||
| 177 | + | ||
| 178 | + | ||
| 179 | +def setup_app_websocket_routes(app: web.Application): | ||
| 180 | + """设置app.py WebSocket路由(便捷函数)""" | ||
| 181 | + migration = get_app_websocket_migration() | ||
| 182 | + migration.setup_routes(app) | ||
| 183 | + return migration | ||
| 184 | + | ||
| 185 | + | ||
| 186 | +# 兼容性函数(保持与原app.py的接口一致) | ||
| 187 | +async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, | ||
| 188 | + source: str = "数字人回复", model_info: str = None, | ||
| 189 | + request_source: str = "页面"): | ||
| 190 | + """兼容原app.py的消息推送函数""" | ||
| 191 | + migration = get_app_websocket_migration() | ||
| 192 | + return await migration.broadcast_message_to_session( | ||
| 193 | + sessionid, message_type, content, source, model_info, request_source | ||
| 194 | + ) | ||
| 195 | + | ||
| 196 | + | ||
| 197 | +async def handle_asr_audio_data(data: Dict[str, Any], sessionid: int, ws): | ||
| 198 | + """兼容原app.py的ASR音频数据处理函数""" | ||
| 199 | + migration = get_app_websocket_migration() | ||
| 200 | + return await migration.handle_asr_audio_data(data, sessionid, ws) | ||
| 201 | + | ||
| 202 | + | ||
| 203 | +async def handle_start_asr_recognition(sessionid: int, ws): | ||
| 204 | + """兼容原app.py的开始ASR识别函数""" | ||
| 205 | + migration = get_app_websocket_migration() | ||
| 206 | + return await migration.handle_start_asr_recognition(sessionid, ws) | ||
| 207 | + | ||
| 208 | + | ||
| 209 | +async def handle_stop_asr_recognition(sessionid: int, ws): | ||
| 210 | + """兼容原app.py的停止ASR识别函数""" | ||
| 211 | + migration = get_app_websocket_migration() | ||
| 212 | + return await migration.handle_stop_asr_recognition(sessionid, ws) | ||
| 213 | + | ||
| 214 | + | ||
| 215 | +async def send_asr_result(sessionid: int, result: Dict[str, Any]): | ||
| 216 | + """兼容原app.py的ASR结果发送函数""" | ||
| 217 | + migration = get_app_websocket_migration() | ||
| 218 | + return await migration.send_asr_result(sessionid, result) | ||
| 219 | + | ||
| 220 | +async def send_normal_asr_result(sessionid: int, result: Dict[str, Any]): | ||
| 221 | + """兼容原app.py的ASR结果发送函数""" | ||
| 222 | + migration = get_app_websocket_migration() | ||
| 223 | + return await migration.send_normal_asr_result(sessionid, result) | ||
| 224 | + | ||
| 225 | + | ||
| 226 | +# 全局变量兼容性接口 | ||
| 227 | +def get_websocket_connections(): | ||
| 228 | + """获取WebSocket连接字典(兼容性接口)""" | ||
| 229 | + migration = get_app_websocket_migration() | ||
| 230 | + return migration.websocket_connections | ||
| 231 | + | ||
| 232 | + | ||
| 233 | +def get_asr_connections(): | ||
| 234 | + """获取ASR连接字典(兼容性接口)""" | ||
| 235 | + migration = get_app_websocket_migration() | ||
| 236 | + return migration.asr_connections |
core/asr_websocket_service.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +ASR WebSocket服务实现 | ||
| 5 | +从app.py中抽离的ASR相关WebSocket功能 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +import json | ||
| 10 | +import weakref | ||
| 11 | +from typing import Dict, Any, Optional | ||
| 12 | +from aiohttp import web | ||
| 13 | +from logger import logger | ||
| 14 | +from .websocket_service_base import WebSocketServiceBase | ||
| 15 | +from .unified_websocket_manager import WebSocketSession | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +class ASRWebSocketService(WebSocketServiceBase): | ||
| 19 | + """ASR WebSocket服务""" | ||
| 20 | + | ||
| 21 | + def __init__(self): | ||
| 22 | + super().__init__("asr_service") | ||
| 23 | + # ASR连接管理 | ||
| 24 | + self.asr_connections: Dict[str, Any] = {} # sessionid -> asr_connection | ||
| 25 | + self._heartbeat_task = None | ||
| 26 | + | ||
| 27 | + async def _register_message_handlers(self): | ||
| 28 | + """注册ASR相关消息处理器""" | ||
| 29 | + # 注册消息处理器 | ||
| 30 | + self.manager.register_message_handler('login', self._handle_login) | ||
| 31 | + self.manager.register_message_handler('heartbeat', self._handle_heartbeat) | ||
| 32 | + self.manager.register_message_handler('asr_audio_data', self._handle_asr_audio_data) | ||
| 33 | + self.manager.register_message_handler('start_asr_recognition', self._handle_start_asr_recognition) | ||
| 34 | + self.manager.register_message_handler('stop_asr_recognition', self._handle_stop_asr_recognition) | ||
| 35 | + | ||
| 36 | + async def _start_background_tasks(self): | ||
| 37 | + """启动心跳检测任务""" | ||
| 38 | + self._heartbeat_task = self.add_background_task(self._heartbeat_monitor()) | ||
| 39 | + | ||
| 40 | + async def _cleanup(self): | ||
| 41 | + """清理ASR连接""" | ||
| 42 | + # 关闭所有ASR连接 | ||
| 43 | + for session_id, asr_conn in list(self.asr_connections.items()): | ||
| 44 | + try: | ||
| 45 | + if hasattr(asr_conn, 'close'): | ||
| 46 | + await asr_conn.close() | ||
| 47 | + except Exception as e: | ||
| 48 | + logger.error(f'关闭ASR连接失败 {session_id}: {e}') | ||
| 49 | + | ||
| 50 | + self.asr_connections.clear() | ||
| 51 | + | ||
| 52 | + async def _on_session_disconnected(self, session: WebSocketSession): | ||
| 53 | + """会话断开时清理ASR连接""" | ||
| 54 | + await super()._on_session_disconnected(session) | ||
| 55 | + | ||
| 56 | + # 清理对应的ASR连接 | ||
| 57 | + if session.session_id in self.asr_connections: | ||
| 58 | + asr_conn = self.asr_connections.pop(session.session_id) | ||
| 59 | + try: | ||
| 60 | + if hasattr(asr_conn, 'close'): | ||
| 61 | + await asr_conn.close() | ||
| 62 | + logger.info(f'已清理ASR连接: {session.session_id}') | ||
| 63 | + except Exception as e: | ||
| 64 | + logger.error(f'清理ASR连接失败 {session.session_id}: {e}') | ||
| 65 | + | ||
| 66 | + async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 67 | + """处理登录消息""" | ||
| 68 | + session = self.manager.get_session(websocket) | ||
| 69 | + if not session: | ||
| 70 | + return | ||
| 71 | + | ||
| 72 | + session_id = data.get('sessionid') | ||
| 73 | + if session_id: | ||
| 74 | + # 更新会话ID | ||
| 75 | + old_session_id = session.session_id | ||
| 76 | + session.session_id = session_id | ||
| 77 | + | ||
| 78 | + # 更新管理器中的会话映射 | ||
| 79 | + self.manager._update_session_id(websocket, old_session_id, session_id) | ||
| 80 | + | ||
| 81 | + # 发送登录成功响应 | ||
| 82 | + await session.send_message({ | ||
| 83 | + "type": "login_response", | ||
| 84 | + "data": { | ||
| 85 | + "status": "success", | ||
| 86 | + "sessionid": session_id, | ||
| 87 | + "message": "登录成功" | ||
| 88 | + } | ||
| 89 | + }) | ||
| 90 | + | ||
| 91 | + logger.info(f'用户登录成功: {session_id}') | ||
| 92 | + else: | ||
| 93 | + await session.send_message({ | ||
| 94 | + "type": "login_response", | ||
| 95 | + "data": { | ||
| 96 | + "status": "error", | ||
| 97 | + "message": "缺少sessionid" | ||
| 98 | + } | ||
| 99 | + }) | ||
| 100 | + | ||
| 101 | + async def _handle_heartbeat(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 102 | + """处理心跳消息""" | ||
| 103 | + session = self.manager.get_session(websocket) | ||
| 104 | + if session: | ||
| 105 | + session.update_last_heartbeat() | ||
| 106 | + await session.send_message({ | ||
| 107 | + "type": "heartbeat_response", | ||
| 108 | + "data": {"status": "ok"} | ||
| 109 | + }) | ||
| 110 | + | ||
| 111 | + async def _handle_asr_audio_data(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 112 | + """处理ASR音频数据""" | ||
| 113 | + session = self.manager.get_session(websocket) | ||
| 114 | + if not session: | ||
| 115 | + return | ||
| 116 | + | ||
| 117 | + session_id = session.session_id | ||
| 118 | + audio_data = data.get('audio_data') | ||
| 119 | + | ||
| 120 | + if not audio_data: | ||
| 121 | + await session.send_message({ | ||
| 122 | + "type": "error", | ||
| 123 | + "data": {"message": "缺少音频数据"} | ||
| 124 | + }) | ||
| 125 | + return | ||
| 126 | + | ||
| 127 | + # 获取或创建ASR连接 | ||
| 128 | + asr_conn = self.asr_connections.get(session_id) | ||
| 129 | + if not asr_conn: | ||
| 130 | + logger.warning(f'ASR连接不存在: {session_id}') | ||
| 131 | + await session.send_message({ | ||
| 132 | + "type": "error", | ||
| 133 | + "data": {"message": "ASR连接未建立"} | ||
| 134 | + }) | ||
| 135 | + return | ||
| 136 | + | ||
| 137 | + try: | ||
| 138 | + # 转发音频数据到ASR服务 | ||
| 139 | + await self._forward_audio_to_asr(asr_conn, audio_data) | ||
| 140 | + except Exception as e: | ||
| 141 | + logger.error(f'转发音频数据失败 {session_id}: {e}') | ||
| 142 | + await session.send_message({ | ||
| 143 | + "type": "error", | ||
| 144 | + "data": {"message": f"音频处理失败: {str(e)}"} | ||
| 145 | + }) | ||
| 146 | + | ||
| 147 | + async def _handle_start_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 148 | + """处理开始ASR识别""" | ||
| 149 | + session = self.manager.get_session(websocket) | ||
| 150 | + if not session: | ||
| 151 | + return | ||
| 152 | + | ||
| 153 | + session_id = session.session_id | ||
| 154 | + | ||
| 155 | + try: | ||
| 156 | + # 创建ASR连接 | ||
| 157 | + asr_conn = await self._create_asr_connection(session_id) | ||
| 158 | + if asr_conn: | ||
| 159 | + self.asr_connections[session_id] = asr_conn | ||
| 160 | + | ||
| 161 | + await session.send_message({ | ||
| 162 | + "type": "asr_recognition_started", | ||
| 163 | + "data": { | ||
| 164 | + "status": "success", | ||
| 165 | + "message": "ASR识别已开始" | ||
| 166 | + } | ||
| 167 | + }) | ||
| 168 | + | ||
| 169 | + logger.info(f'ASR识别已开始: {session_id}') | ||
| 170 | + else: | ||
| 171 | + await session.send_message({ | ||
| 172 | + "type": "error", | ||
| 173 | + "data": {"message": "创建ASR连接失败"} | ||
| 174 | + }) | ||
| 175 | + except Exception as e: | ||
| 176 | + logger.error(f'开始ASR识别失败 {session_id}: {e}') | ||
| 177 | + await session.send_message({ | ||
| 178 | + "type": "error", | ||
| 179 | + "data": {"message": f"开始识别失败: {str(e)}"} | ||
| 180 | + }) | ||
| 181 | + | ||
| 182 | + async def _handle_stop_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 183 | + """处理停止ASR识别""" | ||
| 184 | + session = self.manager.get_session(websocket) | ||
| 185 | + if not session: | ||
| 186 | + return | ||
| 187 | + | ||
| 188 | + session_id = session.session_id | ||
| 189 | + | ||
| 190 | + if session_id in self.asr_connections: | ||
| 191 | + asr_conn = self.asr_connections.pop(session_id) | ||
| 192 | + try: | ||
| 193 | + if hasattr(asr_conn, 'close'): | ||
| 194 | + await asr_conn.close() | ||
| 195 | + | ||
| 196 | + await session.send_message({ | ||
| 197 | + "type": "asr_recognition_stopped", | ||
| 198 | + "data": { | ||
| 199 | + "status": "success", | ||
| 200 | + "message": "ASR识别已停止" | ||
| 201 | + } | ||
| 202 | + }) | ||
| 203 | + | ||
| 204 | + logger.info(f'ASR识别已停止: {session_id}') | ||
| 205 | + except Exception as e: | ||
| 206 | + logger.error(f'停止ASR识别失败 {session_id}: {e}') | ||
| 207 | + await session.send_message({ | ||
| 208 | + "type": "error", | ||
| 209 | + "data": {"message": f"停止识别失败: {str(e)}"} | ||
| 210 | + }) | ||
| 211 | + else: | ||
| 212 | + await session.send_message({ | ||
| 213 | + "type": "asr_recognition_stopped", | ||
| 214 | + "data": { | ||
| 215 | + "status": "success", | ||
| 216 | + "message": "ASR识别未在运行" | ||
| 217 | + } | ||
| 218 | + }) | ||
| 219 | + | ||
| 220 | + async def _create_asr_connection(self, session_id: str): | ||
| 221 | + """创建ASR连接(需要根据实际ASR服务实现)""" | ||
| 222 | + # TODO: 这里需要根据实际的ASR服务(如FunASR)来实现连接逻辑 | ||
| 223 | + # 暂时返回一个模拟连接对象 | ||
| 224 | + logger.info(f'创建ASR连接: {session_id}') | ||
| 225 | + | ||
| 226 | + # 示例:创建到FunASR的WebSocket连接 | ||
| 227 | + try: | ||
| 228 | + # 这里应该是实际的ASR连接逻辑 | ||
| 229 | + # 例如:asr_conn = await create_funasr_connection(session_id, self._on_asr_result) | ||
| 230 | + asr_conn = MockASRConnection(session_id, self._on_asr_result) | ||
| 231 | + return asr_conn | ||
| 232 | + except Exception as e: | ||
| 233 | + logger.error(f'创建ASR连接失败 {session_id}: {e}') | ||
| 234 | + return None | ||
| 235 | + | ||
| 236 | + async def _forward_audio_to_asr(self, asr_conn, audio_data): | ||
| 237 | + """转发音频数据到ASR服务""" | ||
| 238 | + if hasattr(asr_conn, 'send_audio'): | ||
| 239 | + await asr_conn.send_audio(audio_data) | ||
| 240 | + else: | ||
| 241 | + logger.warning('ASR连接不支持发送音频数据') | ||
| 242 | + | ||
| 243 | + async def _on_asr_result(self, session_id: str, result: Dict[str, Any]): | ||
| 244 | + """ASR结果回调""" | ||
| 245 | + try: | ||
| 246 | + await self.broadcast_to_session(session_id, 'asr_result', result) | ||
| 247 | + logger.debug(f'ASR结果已发送: {session_id}') | ||
| 248 | + except Exception as e: | ||
| 249 | + logger.error(f'发送ASR结果失败 {session_id}: {e}') | ||
| 250 | + | ||
| 251 | + async def _heartbeat_monitor(self): | ||
| 252 | + """心跳监控任务""" | ||
| 253 | + while True: | ||
| 254 | + try: | ||
| 255 | + await asyncio.sleep(40) # 每40秒检查一次 | ||
| 256 | + | ||
| 257 | + # 检查会话心跳 | ||
| 258 | + expired_sessions = self.manager.get_expired_sessions(timeout=60) | ||
| 259 | + for session in expired_sessions: | ||
| 260 | + logger.info(f'会话心跳超时,断开连接: {session.session_id}') | ||
| 261 | + await session.close() | ||
| 262 | + | ||
| 263 | + except asyncio.CancelledError: | ||
| 264 | + break | ||
| 265 | + except Exception as e: | ||
| 266 | + logger.error(f'心跳监控异常: {e}') | ||
| 267 | + await asyncio.sleep(5) | ||
| 268 | + | ||
| 269 | + def get_asr_stats(self) -> Dict[str, Any]: | ||
| 270 | + """获取ASR统计信息""" | ||
| 271 | + return { | ||
| 272 | + "active_asr_connections": len(self.asr_connections), | ||
| 273 | + "asr_sessions": list(self.asr_connections.keys()) | ||
| 274 | + } | ||
| 275 | + | ||
| 276 | + | ||
| 277 | +class MockASRConnection: | ||
| 278 | + """模拟ASR连接(用于测试)""" | ||
| 279 | + | ||
| 280 | + def __init__(self, session_id: str, result_callback): | ||
| 281 | + self.session_id = session_id | ||
| 282 | + self.result_callback = result_callback | ||
| 283 | + self.is_closed = False | ||
| 284 | + | ||
| 285 | + async def send_audio(self, audio_data): | ||
| 286 | + """发送音频数据""" | ||
| 287 | + if self.is_closed: | ||
| 288 | + return | ||
| 289 | + | ||
| 290 | + # 模拟ASR处理 | ||
| 291 | + await asyncio.sleep(0.1) | ||
| 292 | + | ||
| 293 | + # 模拟返回识别结果 | ||
| 294 | + result = { | ||
| 295 | + "text": "模拟识别结果", | ||
| 296 | + "confidence": 0.95, | ||
| 297 | + "timestamp": asyncio.get_event_loop().time() | ||
| 298 | + } | ||
| 299 | + | ||
| 300 | + if self.result_callback: | ||
| 301 | + await self.result_callback(self.session_id, result) | ||
| 302 | + | ||
| 303 | + async def close(self): | ||
| 304 | + """关闭连接""" | ||
| 305 | + self.is_closed = True | ||
| 306 | + logger.info(f'模拟ASR连接已关闭: {self.session_id}') | ||
| 307 | + | ||
| 308 | + | ||
| 309 | +# 创建ASR服务实例 | ||
| 310 | +asr_service = ASRWebSocketService() | ||
| 311 | + | ||
| 312 | + | ||
| 313 | +def get_asr_service() -> ASRWebSocketService: | ||
| 314 | + """获取ASR服务实例""" | ||
| 315 | + return asr_service |
core/digital_human_websocket_service.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +数字人WebSocket服务实现 | ||
| 5 | +处理数字人相关的WebSocket通信和状态管理 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +import json | ||
| 10 | +from typing import Dict, Any, Optional, List | ||
| 11 | +from aiohttp import web | ||
| 12 | +from logger import logger | ||
| 13 | +from .websocket_service_base import WebSocketServiceBase | ||
| 14 | +from .unified_websocket_manager import WebSocketSession | ||
| 15 | + | ||
| 16 | + | ||
| 17 | +class DigitalHumanWebSocketService(WebSocketServiceBase): | ||
| 18 | + """数字人WebSocket服务""" | ||
| 19 | + | ||
| 20 | + def __init__(self): | ||
| 21 | + super().__init__("digital_human_service") | ||
| 22 | + # 数字人状态管理 | ||
| 23 | + self.digital_humans: Dict[str, Dict[str, Any]] = {} # human_id -> human_info | ||
| 24 | + self.human_sessions: Dict[str, str] = {} # session_id -> human_id | ||
| 25 | + self.session_humans: Dict[str, List[str]] = {} # session_id -> [human_ids] | ||
| 26 | + | ||
| 27 | + async def _register_message_handlers(self): | ||
| 28 | + """注册数字人相关消息处理器""" | ||
| 29 | + self.manager.register_message_handler('register_digital_human', self._handle_register_digital_human) | ||
| 30 | + self.manager.register_message_handler('unregister_digital_human', self._handle_unregister_digital_human) | ||
| 31 | + self.manager.register_message_handler('digital_human_status', self._handle_digital_human_status) | ||
| 32 | + self.manager.register_message_handler('digital_human_action', self._handle_digital_human_action) | ||
| 33 | + self.manager.register_message_handler('digital_human_speak', self._handle_digital_human_speak) | ||
| 34 | + self.manager.register_message_handler('digital_human_emotion', self._handle_digital_human_emotion) | ||
| 35 | + self.manager.register_message_handler('get_digital_humans', self._handle_get_digital_humans) | ||
| 36 | + | ||
| 37 | + async def _on_session_disconnected(self, session: WebSocketSession): | ||
| 38 | + """会话断开时清理数字人注册""" | ||
| 39 | + await super()._on_session_disconnected(session) | ||
| 40 | + | ||
| 41 | + session_id = session.session_id | ||
| 42 | + | ||
| 43 | + # 清理该会话注册的数字人 | ||
| 44 | + if session_id in self.session_humans: | ||
| 45 | + human_ids = self.session_humans.pop(session_id) | ||
| 46 | + for human_id in human_ids: | ||
| 47 | + if human_id in self.digital_humans: | ||
| 48 | + human_info = self.digital_humans.pop(human_id) | ||
| 49 | + logger.info(f'数字人已注销: {human_id} (会话断开)') | ||
| 50 | + | ||
| 51 | + # 通知其他会话数字人已离线 | ||
| 52 | + await self.broadcast_to_all('digital_human_offline', { | ||
| 53 | + 'human_id': human_id, | ||
| 54 | + 'name': human_info.get('name', ''), | ||
| 55 | + 'reason': 'session_disconnected' | ||
| 56 | + }) | ||
| 57 | + | ||
| 58 | + # 清理会话到数字人的映射 | ||
| 59 | + if session_id in self.human_sessions: | ||
| 60 | + del self.human_sessions[session_id] | ||
| 61 | + | ||
| 62 | + async def _handle_register_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 63 | + """处理数字人注册""" | ||
| 64 | + session = self.manager.get_session(websocket) | ||
| 65 | + if not session: | ||
| 66 | + return | ||
| 67 | + | ||
| 68 | + human_id = data.get('human_id') | ||
| 69 | + human_name = data.get('name', '') | ||
| 70 | + human_type = data.get('type', 'default') | ||
| 71 | + capabilities = data.get('capabilities', []) | ||
| 72 | + | ||
| 73 | + if not human_id: | ||
| 74 | + await session.send_message({ | ||
| 75 | + "type": "error", | ||
| 76 | + "data": {"message": "缺少human_id"} | ||
| 77 | + }) | ||
| 78 | + return | ||
| 79 | + | ||
| 80 | + # 检查数字人是否已存在 | ||
| 81 | + if human_id in self.digital_humans: | ||
| 82 | + await session.send_message({ | ||
| 83 | + "type": "register_digital_human_response", | ||
| 84 | + "data": { | ||
| 85 | + "status": "error", | ||
| 86 | + "message": f"数字人已存在: {human_id}" | ||
| 87 | + } | ||
| 88 | + }) | ||
| 89 | + return | ||
| 90 | + | ||
| 91 | + # 注册数字人 | ||
| 92 | + human_info = { | ||
| 93 | + 'human_id': human_id, | ||
| 94 | + 'name': human_name, | ||
| 95 | + 'type': human_type, | ||
| 96 | + 'capabilities': capabilities, | ||
| 97 | + 'session_id': session.session_id, | ||
| 98 | + 'status': 'online', | ||
| 99 | + 'registered_at': asyncio.get_event_loop().time(), | ||
| 100 | + 'last_activity': asyncio.get_event_loop().time() | ||
| 101 | + } | ||
| 102 | + | ||
| 103 | + self.digital_humans[human_id] = human_info | ||
| 104 | + self.human_sessions[session.session_id] = human_id | ||
| 105 | + | ||
| 106 | + # 记录会话注册的数字人 | ||
| 107 | + if session.session_id not in self.session_humans: | ||
| 108 | + self.session_humans[session.session_id] = [] | ||
| 109 | + self.session_humans[session.session_id].append(human_id) | ||
| 110 | + | ||
| 111 | + # 发送注册成功响应 | ||
| 112 | + await session.send_message({ | ||
| 113 | + "type": "register_digital_human_response", | ||
| 114 | + "data": { | ||
| 115 | + "status": "success", | ||
| 116 | + "human_id": human_id, | ||
| 117 | + "message": "数字人注册成功" | ||
| 118 | + } | ||
| 119 | + }) | ||
| 120 | + | ||
| 121 | + # 通知其他会话新数字人上线 | ||
| 122 | + await self.broadcast_to_all('digital_human_online', { | ||
| 123 | + 'human_id': human_id, | ||
| 124 | + 'name': human_name, | ||
| 125 | + 'type': human_type, | ||
| 126 | + 'capabilities': capabilities | ||
| 127 | + }, metadata={'exclude_session': session.session_id}) | ||
| 128 | + | ||
| 129 | + logger.info(f'数字人已注册: {human_id} ({human_name})') | ||
| 130 | + | ||
| 131 | + async def _handle_unregister_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 132 | + """处理数字人注销""" | ||
| 133 | + session = self.manager.get_session(websocket) | ||
| 134 | + if not session: | ||
| 135 | + return | ||
| 136 | + | ||
| 137 | + human_id = data.get('human_id') | ||
| 138 | + | ||
| 139 | + if not human_id: | ||
| 140 | + await session.send_message({ | ||
| 141 | + "type": "error", | ||
| 142 | + "data": {"message": "缺少human_id"} | ||
| 143 | + }) | ||
| 144 | + return | ||
| 145 | + | ||
| 146 | + # 检查数字人是否存在且属于当前会话 | ||
| 147 | + if human_id not in self.digital_humans: | ||
| 148 | + await session.send_message({ | ||
| 149 | + "type": "unregister_digital_human_response", | ||
| 150 | + "data": { | ||
| 151 | + "status": "error", | ||
| 152 | + "message": f"数字人不存在: {human_id}" | ||
| 153 | + } | ||
| 154 | + }) | ||
| 155 | + return | ||
| 156 | + | ||
| 157 | + human_info = self.digital_humans[human_id] | ||
| 158 | + if human_info['session_id'] != session.session_id: | ||
| 159 | + await session.send_message({ | ||
| 160 | + "type": "unregister_digital_human_response", | ||
| 161 | + "data": { | ||
| 162 | + "status": "error", | ||
| 163 | + "message": "无权注销该数字人" | ||
| 164 | + } | ||
| 165 | + }) | ||
| 166 | + return | ||
| 167 | + | ||
| 168 | + # 注销数字人 | ||
| 169 | + del self.digital_humans[human_id] | ||
| 170 | + if session.session_id in self.human_sessions: | ||
| 171 | + del self.human_sessions[session.session_id] | ||
| 172 | + | ||
| 173 | + if session.session_id in self.session_humans: | ||
| 174 | + self.session_humans[session.session_id].remove(human_id) | ||
| 175 | + if not self.session_humans[session.session_id]: | ||
| 176 | + del self.session_humans[session.session_id] | ||
| 177 | + | ||
| 178 | + # 发送注销成功响应 | ||
| 179 | + await session.send_message({ | ||
| 180 | + "type": "unregister_digital_human_response", | ||
| 181 | + "data": { | ||
| 182 | + "status": "success", | ||
| 183 | + "human_id": human_id, | ||
| 184 | + "message": "数字人注销成功" | ||
| 185 | + } | ||
| 186 | + }) | ||
| 187 | + | ||
| 188 | + # 通知其他会话数字人已离线 | ||
| 189 | + await self.broadcast_to_all('digital_human_offline', { | ||
| 190 | + 'human_id': human_id, | ||
| 191 | + 'name': human_info.get('name', ''), | ||
| 192 | + 'reason': 'manual_unregister' | ||
| 193 | + }, metadata={'exclude_session': session.session_id}) | ||
| 194 | + | ||
| 195 | + logger.info(f'数字人已注销: {human_id}') | ||
| 196 | + | ||
| 197 | + async def _handle_digital_human_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 198 | + """处理数字人状态更新""" | ||
| 199 | + session = self.manager.get_session(websocket) | ||
| 200 | + if not session: | ||
| 201 | + return | ||
| 202 | + | ||
| 203 | + human_id = data.get('human_id') | ||
| 204 | + status = data.get('status') | ||
| 205 | + | ||
| 206 | + if not human_id or not status: | ||
| 207 | + await session.send_message({ | ||
| 208 | + "type": "error", | ||
| 209 | + "data": {"message": "缺少human_id或status"} | ||
| 210 | + }) | ||
| 211 | + return | ||
| 212 | + | ||
| 213 | + if human_id not in self.digital_humans: | ||
| 214 | + await session.send_message({ | ||
| 215 | + "type": "error", | ||
| 216 | + "data": {"message": f"数字人不存在: {human_id}"} | ||
| 217 | + }) | ||
| 218 | + return | ||
| 219 | + | ||
| 220 | + human_info = self.digital_humans[human_id] | ||
| 221 | + if human_info['session_id'] != session.session_id: | ||
| 222 | + await session.send_message({ | ||
| 223 | + "type": "error", | ||
| 224 | + "data": {"message": "无权更新该数字人状态"} | ||
| 225 | + }) | ||
| 226 | + return | ||
| 227 | + | ||
| 228 | + # 更新状态 | ||
| 229 | + old_status = human_info['status'] | ||
| 230 | + human_info['status'] = status | ||
| 231 | + human_info['last_activity'] = asyncio.get_event_loop().time() | ||
| 232 | + | ||
| 233 | + # 广播状态变化 | ||
| 234 | + await self.broadcast_to_all('digital_human_status_changed', { | ||
| 235 | + 'human_id': human_id, | ||
| 236 | + 'old_status': old_status, | ||
| 237 | + 'new_status': status, | ||
| 238 | + 'name': human_info.get('name', '') | ||
| 239 | + }) | ||
| 240 | + | ||
| 241 | + logger.info(f'数字人状态更新: {human_id} {old_status} -> {status}') | ||
| 242 | + | ||
| 243 | + async def _handle_digital_human_action(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 244 | + """处理数字人动作指令""" | ||
| 245 | + session = self.manager.get_session(websocket) | ||
| 246 | + if not session: | ||
| 247 | + return | ||
| 248 | + | ||
| 249 | + human_id = data.get('human_id') | ||
| 250 | + action = data.get('action') | ||
| 251 | + params = data.get('params', {}) | ||
| 252 | + | ||
| 253 | + if not human_id or not action: | ||
| 254 | + await session.send_message({ | ||
| 255 | + "type": "error", | ||
| 256 | + "data": {"message": "缺少human_id或action"} | ||
| 257 | + }) | ||
| 258 | + return | ||
| 259 | + | ||
| 260 | + if human_id not in self.digital_humans: | ||
| 261 | + await session.send_message({ | ||
| 262 | + "type": "error", | ||
| 263 | + "data": {"message": f"数字人不存在: {human_id}"} | ||
| 264 | + }) | ||
| 265 | + return | ||
| 266 | + | ||
| 267 | + human_info = self.digital_humans[human_id] | ||
| 268 | + | ||
| 269 | + # 更新活动时间 | ||
| 270 | + human_info['last_activity'] = asyncio.get_event_loop().time() | ||
| 271 | + | ||
| 272 | + # 转发动作指令到数字人会话 | ||
| 273 | + target_session_id = human_info['session_id'] | ||
| 274 | + await self.broadcast_to_session(target_session_id, 'digital_human_action_command', { | ||
| 275 | + 'human_id': human_id, | ||
| 276 | + 'action': action, | ||
| 277 | + 'params': params, | ||
| 278 | + 'from_session': session.session_id | ||
| 279 | + }) | ||
| 280 | + | ||
| 281 | + logger.info(f'数字人动作指令: {human_id} -> {action}') | ||
| 282 | + | ||
| 283 | + async def _handle_digital_human_speak(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 284 | + """处理数字人说话指令""" | ||
| 285 | + session = self.manager.get_session(websocket) | ||
| 286 | + if not session: | ||
| 287 | + return | ||
| 288 | + | ||
| 289 | + human_id = data.get('human_id') | ||
| 290 | + text = data.get('text') | ||
| 291 | + voice_config = data.get('voice_config', {}) | ||
| 292 | + | ||
| 293 | + if not human_id or not text: | ||
| 294 | + await session.send_message({ | ||
| 295 | + "type": "error", | ||
| 296 | + "data": {"message": "缺少human_id或text"} | ||
| 297 | + }) | ||
| 298 | + return | ||
| 299 | + | ||
| 300 | + if human_id not in self.digital_humans: | ||
| 301 | + await session.send_message({ | ||
| 302 | + "type": "error", | ||
| 303 | + "data": {"message": f"数字人不存在: {human_id}"} | ||
| 304 | + }) | ||
| 305 | + return | ||
| 306 | + | ||
| 307 | + human_info = self.digital_humans[human_id] | ||
| 308 | + | ||
| 309 | + # 更新活动时间 | ||
| 310 | + human_info['last_activity'] = asyncio.get_event_loop().time() | ||
| 311 | + | ||
| 312 | + # 转发说话指令到数字人会话 | ||
| 313 | + target_session_id = human_info['session_id'] | ||
| 314 | + await self.broadcast_to_session(target_session_id, 'digital_human_speak_command', { | ||
| 315 | + 'human_id': human_id, | ||
| 316 | + 'text': text, | ||
| 317 | + 'voice_config': voice_config, | ||
| 318 | + 'from_session': session.session_id | ||
| 319 | + }) | ||
| 320 | + | ||
| 321 | + # 广播数字人说话事件 | ||
| 322 | + await self.broadcast_to_all('digital_human_speaking', { | ||
| 323 | + 'human_id': human_id, | ||
| 324 | + 'name': human_info.get('name', ''), | ||
| 325 | + 'text': text | ||
| 326 | + }, metadata={'exclude_session': target_session_id}) | ||
| 327 | + | ||
| 328 | + logger.info(f'数字人说话指令: {human_id} -> "{text}"') | ||
| 329 | + | ||
| 330 | + async def _handle_digital_human_emotion(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 331 | + """处理数字人情感状态""" | ||
| 332 | + session = self.manager.get_session(websocket) | ||
| 333 | + if not session: | ||
| 334 | + return | ||
| 335 | + | ||
| 336 | + human_id = data.get('human_id') | ||
| 337 | + emotion = data.get('emotion') | ||
| 338 | + intensity = data.get('intensity', 1.0) | ||
| 339 | + | ||
| 340 | + if not human_id or not emotion: | ||
| 341 | + await session.send_message({ | ||
| 342 | + "type": "error", | ||
| 343 | + "data": {"message": "缺少human_id或emotion"} | ||
| 344 | + }) | ||
| 345 | + return | ||
| 346 | + | ||
| 347 | + if human_id not in self.digital_humans: | ||
| 348 | + await session.send_message({ | ||
| 349 | + "type": "error", | ||
| 350 | + "data": {"message": f"数字人不存在: {human_id}"} | ||
| 351 | + }) | ||
| 352 | + return | ||
| 353 | + | ||
| 354 | + human_info = self.digital_humans[human_id] | ||
| 355 | + | ||
| 356 | + # 更新情感状态 | ||
| 357 | + human_info['emotion'] = emotion | ||
| 358 | + human_info['emotion_intensity'] = intensity | ||
| 359 | + human_info['last_activity'] = asyncio.get_event_loop().time() | ||
| 360 | + | ||
| 361 | + # 广播情感变化 | ||
| 362 | + await self.broadcast_to_all('digital_human_emotion_changed', { | ||
| 363 | + 'human_id': human_id, | ||
| 364 | + 'name': human_info.get('name', ''), | ||
| 365 | + 'emotion': emotion, | ||
| 366 | + 'intensity': intensity | ||
| 367 | + }) | ||
| 368 | + | ||
| 369 | + logger.info(f'数字人情感更新: {human_id} -> {emotion} ({intensity})') | ||
| 370 | + | ||
| 371 | + async def _handle_get_digital_humans(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 372 | + """处理获取数字人列表请求""" | ||
| 373 | + session = self.manager.get_session(websocket) | ||
| 374 | + if not session: | ||
| 375 | + return | ||
| 376 | + | ||
| 377 | + # 返回所有在线数字人信息 | ||
| 378 | + humans_list = [] | ||
| 379 | + for human_id, human_info in self.digital_humans.items(): | ||
| 380 | + humans_list.append({ | ||
| 381 | + 'human_id': human_id, | ||
| 382 | + 'name': human_info.get('name', ''), | ||
| 383 | + 'type': human_info.get('type', 'default'), | ||
| 384 | + 'status': human_info.get('status', 'unknown'), | ||
| 385 | + 'capabilities': human_info.get('capabilities', []), | ||
| 386 | + 'emotion': human_info.get('emotion', 'neutral'), | ||
| 387 | + 'emotion_intensity': human_info.get('emotion_intensity', 1.0) | ||
| 388 | + }) | ||
| 389 | + | ||
| 390 | + await session.send_message({ | ||
| 391 | + "type": "digital_humans_list", | ||
| 392 | + "data": { | ||
| 393 | + "humans": humans_list, | ||
| 394 | + "total": len(humans_list) | ||
| 395 | + } | ||
| 396 | + }) | ||
| 397 | + | ||
| 398 | + def get_digital_human_stats(self) -> Dict[str, Any]: | ||
| 399 | + """获取数字人统计信息""" | ||
| 400 | + online_count = len([h for h in self.digital_humans.values() if h.get('status') == 'online']) | ||
| 401 | + | ||
| 402 | + return { | ||
| 403 | + "total_digital_humans": len(self.digital_humans), | ||
| 404 | + "online_digital_humans": online_count, | ||
| 405 | + "active_sessions": len(self.session_humans), | ||
| 406 | + "human_types": list(set(h.get('type', 'default') for h in self.digital_humans.values())) | ||
| 407 | + } | ||
| 408 | + | ||
| 409 | + async def send_to_digital_human(self, human_id: str, message_type: str, content: Any): | ||
| 410 | + """向指定数字人发送消息""" | ||
| 411 | + if human_id not in self.digital_humans: | ||
| 412 | + logger.warning(f'数字人不存在: {human_id}') | ||
| 413 | + return False | ||
| 414 | + | ||
| 415 | + human_info = self.digital_humans[human_id] | ||
| 416 | + target_session_id = human_info['session_id'] | ||
| 417 | + | ||
| 418 | + return await self.broadcast_to_session(target_session_id, message_type, content) | ||
| 419 | + | ||
| 420 | + async def broadcast_to_digital_humans(self, message_type: str, content: Any, | ||
| 421 | + human_filter: Optional[callable] = None): | ||
| 422 | + """向数字人广播消息""" | ||
| 423 | + sent_count = 0 | ||
| 424 | + | ||
| 425 | + for human_id, human_info in self.digital_humans.items(): | ||
| 426 | + if human_filter and not human_filter(human_info): | ||
| 427 | + continue | ||
| 428 | + | ||
| 429 | + target_session_id = human_info['session_id'] | ||
| 430 | + success = await self.broadcast_to_session(target_session_id, message_type, content) | ||
| 431 | + if success: | ||
| 432 | + sent_count += 1 | ||
| 433 | + | ||
| 434 | + return sent_count | ||
| 435 | + | ||
| 436 | + | ||
| 437 | +# 创建数字人服务实例 | ||
| 438 | +digital_human_service = DigitalHumanWebSocketService() | ||
| 439 | + | ||
| 440 | + | ||
| 441 | +def get_digital_human_service() -> DigitalHumanWebSocketService: | ||
| 442 | + """获取数字人服务实例""" | ||
| 443 | + return digital_human_service |
core/unified_websocket_manager.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +统一WebSocket管理模块 | ||
| 5 | +提供统一的WebSocket连接管理、消息推送和事件处理功能 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import json | ||
| 9 | +import time | ||
| 10 | +import asyncio | ||
| 11 | +import weakref | ||
| 12 | +from typing import Dict, Any, Optional, Callable, List, Set | ||
| 13 | +from threading import Lock | ||
| 14 | +from aiohttp import web, WSMsgType | ||
| 15 | +from logger import logger | ||
| 16 | + | ||
| 17 | + | ||
| 18 | +class WebSocketSession: | ||
| 19 | + """WebSocket会话管理类""" | ||
| 20 | + | ||
| 21 | + def __init__(self, session_id: str, websocket: web.WebSocketResponse): | ||
| 22 | + self.session_id = session_id | ||
| 23 | + self.websocket = websocket | ||
| 24 | + self.created_at = time.time() | ||
| 25 | + self.last_ping = time.time() | ||
| 26 | + self.metadata = {} | ||
| 27 | + | ||
| 28 | + def __eq__(self, other): | ||
| 29 | + """基于websocket对象判断会话是否相等""" | ||
| 30 | + if not isinstance(other, WebSocketSession): | ||
| 31 | + return False | ||
| 32 | + return self.websocket is other.websocket | ||
| 33 | + | ||
| 34 | + def __hash__(self): | ||
| 35 | + """基于websocket对象的id生成哈希值""" | ||
| 36 | + return hash(id(self.websocket)) | ||
| 37 | + | ||
| 38 | + def is_alive(self) -> bool: | ||
| 39 | + """检查连接是否存活""" | ||
| 40 | + return not self.websocket.closed | ||
| 41 | + | ||
| 42 | + async def send_message(self, message: Dict[str, Any]) -> bool: | ||
| 43 | + """发送消息到WebSocket客户端""" | ||
| 44 | + try: | ||
| 45 | + if not self.is_alive(): | ||
| 46 | + return False | ||
| 47 | + await self.websocket.send_str(json.dumps(message)) | ||
| 48 | + return True | ||
| 49 | + except ConnectionResetError: | ||
| 50 | + logger.warning(f'[Session:{self.session_id}] 客户端连接已重置') | ||
| 51 | + return False | ||
| 52 | + except ConnectionAbortedError: | ||
| 53 | + logger.warning(f'[Session:{self.session_id}] 客户端连接已中止') | ||
| 54 | + return False | ||
| 55 | + except Exception as e: | ||
| 56 | + logger.error(f'[Session:{self.session_id}] 发送消息失败: {e}') | ||
| 57 | + return False | ||
| 58 | + | ||
| 59 | + def update_ping(self): | ||
| 60 | + """更新心跳时间""" | ||
| 61 | + self.last_ping = time.time() | ||
| 62 | + | ||
| 63 | + def set_metadata(self, key: str, value: Any): | ||
| 64 | + """设置会话元数据""" | ||
| 65 | + self.metadata[key] = value | ||
| 66 | + | ||
| 67 | + def get_metadata(self, key: str, default=None): | ||
| 68 | + """获取会话元数据""" | ||
| 69 | + return self.metadata.get(key, default) | ||
| 70 | + | ||
| 71 | + async def close(self): | ||
| 72 | + """关闭WebSocket连接""" | ||
| 73 | + try: | ||
| 74 | + if not self.websocket.closed: | ||
| 75 | + await self.websocket.close() | ||
| 76 | + logger.info(f'[Session:{self.session_id}] WebSocket连接已关闭') | ||
| 77 | + except ConnectionResetError: | ||
| 78 | + logger.warning(f'[Session:{self.session_id}] 连接已被远程主机重置,无需关闭') | ||
| 79 | + except ConnectionAbortedError: | ||
| 80 | + logger.warning(f'[Session:{self.session_id}] 连接已被中止,无需关闭') | ||
| 81 | + except Exception as e: | ||
| 82 | + logger.error(f'[Session:{self.session_id}] 关闭WebSocket连接失败: {e}') | ||
| 83 | + | ||
| 84 | + | ||
| 85 | +class UnifiedWebSocketManager: | ||
| 86 | + """统一WebSocket管理器""" | ||
| 87 | + | ||
| 88 | + def __init__(self): | ||
| 89 | + self._sessions: Dict[str, Set[WebSocketSession]] = {} # session_id -> WebSocketSession集合 | ||
| 90 | + self._websockets: Dict[web.WebSocketResponse, WebSocketSession] = {} # websocket -> session映射 | ||
| 91 | + self._message_handlers: Dict[str, Callable] = {} # 消息类型处理器 | ||
| 92 | + self._event_handlers: Dict[str, List[Callable]] = {} # 事件处理器 | ||
| 93 | + self._lock = Lock() | ||
| 94 | + | ||
| 95 | + # 注册默认消息处理器 | ||
| 96 | + self.register_message_handler('ping', self._handle_ping) | ||
| 97 | + self.register_message_handler('login', self._handle_login) | ||
| 98 | + | ||
| 99 | + def register_message_handler(self, message_type: str, handler: Callable): | ||
| 100 | + """注册消息处理器""" | ||
| 101 | + self._message_handlers[message_type] = handler | ||
| 102 | + logger.info(f'注册消息处理器: {message_type}') | ||
| 103 | + | ||
| 104 | + def register_event_handler(self, event_type: str, handler: Callable): | ||
| 105 | + """注册事件处理器""" | ||
| 106 | + if event_type not in self._event_handlers: | ||
| 107 | + self._event_handlers[event_type] = [] | ||
| 108 | + self._event_handlers[event_type].append(handler) | ||
| 109 | + logger.info(f'注册事件处理器: {event_type}') | ||
| 110 | + | ||
| 111 | + async def _emit_event(self, event_type: str, **kwargs): | ||
| 112 | + """触发事件""" | ||
| 113 | + if event_type in self._event_handlers: | ||
| 114 | + for handler in self._event_handlers[event_type]: | ||
| 115 | + try: | ||
| 116 | + if asyncio.iscoroutinefunction(handler): | ||
| 117 | + await handler(**kwargs) | ||
| 118 | + else: | ||
| 119 | + handler(**kwargs) | ||
| 120 | + except Exception as e: | ||
| 121 | + logger.error(f'事件处理器执行失败 {event_type}: {e}') | ||
| 122 | + | ||
| 123 | + def add_session(self, session_id: str, websocket: web.WebSocketResponse) -> WebSocketSession: | ||
| 124 | + """添加WebSocket会话""" | ||
| 125 | + with self._lock: | ||
| 126 | + # 检查是否已存在相同的websocket连接 | ||
| 127 | + if websocket in self._websockets: | ||
| 128 | + existing_session = self._websockets[websocket] | ||
| 129 | + logger.warning(f'[Session:{session_id}] WebSocket连接已存在 (WebSocket={id(websocket)}, 原Session={existing_session.session_id})') | ||
| 130 | + return existing_session | ||
| 131 | + | ||
| 132 | + session = WebSocketSession(session_id, websocket) | ||
| 133 | + | ||
| 134 | + # 初始化会话集合 | ||
| 135 | + if session_id not in self._sessions: | ||
| 136 | + self._sessions[session_id] = set() | ||
| 137 | + | ||
| 138 | + # 检查Set添加前后的大小变化 | ||
| 139 | + before_count = len(self._sessions[session_id]) | ||
| 140 | + self._sessions[session_id].add(session) | ||
| 141 | + after_count = len(self._sessions[session_id]) | ||
| 142 | + | ||
| 143 | + self._websockets[websocket] = session | ||
| 144 | + | ||
| 145 | + logger.info(f'[Session:{session_id}] 添加WebSocket会话 (WebSocket={id(websocket)}), 连接数变化: {before_count} -> {after_count}') | ||
| 146 | + | ||
| 147 | + # 如果Set大小没有变化,说明可能存在重复 | ||
| 148 | + if before_count == after_count: | ||
| 149 | + logger.warning(f'[Session:{session_id}] 检测到可能的重复会话添加!Set大小未变化') | ||
| 150 | + | ||
| 151 | + return session | ||
| 152 | + | ||
| 153 | + def remove_session(self, websocket: web.WebSocketResponse): | ||
| 154 | + """移除WebSocket会话""" | ||
| 155 | + with self._lock: | ||
| 156 | + if websocket in self._websockets: | ||
| 157 | + session = self._websockets[websocket] | ||
| 158 | + session_id = session.session_id | ||
| 159 | + | ||
| 160 | + # 从会话集合中移除 | ||
| 161 | + if session_id in self._sessions: | ||
| 162 | + self._sessions[session_id].discard(session) | ||
| 163 | + if not self._sessions[session_id]: # 如果集合为空,删除键 | ||
| 164 | + del self._sessions[session_id] | ||
| 165 | + | ||
| 166 | + # 从websocket映射中移除 | ||
| 167 | + del self._websockets[websocket] | ||
| 168 | + | ||
| 169 | + logger.info(f'[Session:{session_id}] 移除WebSocket会话') | ||
| 170 | + return session | ||
| 171 | + return None | ||
| 172 | + | ||
| 173 | + def get_session(self, websocket: web.WebSocketResponse) -> Optional[WebSocketSession]: | ||
| 174 | + """获取WebSocket会话""" | ||
| 175 | + return self._websockets.get(websocket) | ||
| 176 | + | ||
| 177 | + def get_sessions_by_id(self, session_id: str) -> Set[WebSocketSession]: | ||
| 178 | + """根据会话ID获取所有WebSocket会话""" | ||
| 179 | + with self._lock: | ||
| 180 | + # 尝试使用原始session_id查找 | ||
| 181 | + sessions = self._sessions.get(session_id, set()) | ||
| 182 | + if sessions: | ||
| 183 | + return sessions.copy() | ||
| 184 | + | ||
| 185 | + # 如果是字符串类型但存储的是整数类型,尝试转换 | ||
| 186 | + if isinstance(session_id, str) and session_id.isdigit(): | ||
| 187 | + int_session_id = int(session_id) | ||
| 188 | + sessions = self._sessions.get(int_session_id, set()) | ||
| 189 | + if sessions: | ||
| 190 | + return sessions.copy() | ||
| 191 | + | ||
| 192 | + # 如果是整数类型但存储的是字符串类型,尝试转换 | ||
| 193 | + elif isinstance(session_id, int): | ||
| 194 | + str_session_id = str(session_id) | ||
| 195 | + sessions = self._sessions.get(str_session_id, set()) | ||
| 196 | + if sessions: | ||
| 197 | + return sessions.copy() | ||
| 198 | + | ||
| 199 | + return set() | ||
| 200 | + | ||
| 201 | + def _update_session_id(self, websocket: web.WebSocketResponse, old_session_id: str, new_session_id: str): | ||
| 202 | + """更新WebSocket会话的session_id""" | ||
| 203 | + with self._lock: | ||
| 204 | + if websocket in self._websockets: | ||
| 205 | + session = self._websockets[websocket] | ||
| 206 | + | ||
| 207 | + # 从旧的session_id集合中移除 | ||
| 208 | + if old_session_id in self._sessions: | ||
| 209 | + self._sessions[old_session_id].discard(session) | ||
| 210 | + if not self._sessions[old_session_id]: # 如果集合为空,删除键 | ||
| 211 | + del self._sessions[old_session_id] | ||
| 212 | + | ||
| 213 | + # 更新session的session_id | ||
| 214 | + session.session_id = new_session_id | ||
| 215 | + | ||
| 216 | + # 添加到新的session_id集合 | ||
| 217 | + if new_session_id not in self._sessions: | ||
| 218 | + self._sessions[new_session_id] = set() | ||
| 219 | + self._sessions[new_session_id].add(session) | ||
| 220 | + | ||
| 221 | + logger.info(f'[Session] 更新会话ID: {old_session_id} -> {new_session_id}') | ||
| 222 | + return True | ||
| 223 | + return False | ||
| 224 | + | ||
| 225 | + async def broadcast_raw_message_to_session(self, session_id: str, message: Dict,source: str = "原数据") -> int: | ||
| 226 | + """直接广播原始消息到指定会话的所有WebSocket连接""" | ||
| 227 | + # 确保session_id为字符串类型,保持一致性 | ||
| 228 | + # 确保session_id为字符串类型,保持一致性 | ||
| 229 | + if isinstance(session_id, int): | ||
| 230 | + session_id = str(session_id) | ||
| 231 | + elif not isinstance(session_id, str): | ||
| 232 | + session_id = str(session_id) | ||
| 233 | + | ||
| 234 | + sessions = self.get_sessions_by_id(session_id) | ||
| 235 | + if not sessions: | ||
| 236 | + logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接') | ||
| 237 | + return 0 | ||
| 238 | + | ||
| 239 | + # 详细调试日志:显示会话详情 | ||
| 240 | + logger.info(f'[Session:{session_id}] 开始广播消息,找到 {len(sessions)} 个连接') | ||
| 241 | + for i, session in enumerate(sessions): | ||
| 242 | + logger.info(f'[Session:{session_id}] 连接{i+1}: WebSocket={id(session.websocket)}, 创建时间={session.created_at}, 存活状态={session.is_alive()}') | ||
| 243 | + | ||
| 244 | + success_count = 0 | ||
| 245 | + failed_sessions = [] | ||
| 246 | + | ||
| 247 | + for i, session in enumerate(sessions): | ||
| 248 | + logger.info(f'[Session:{session_id}] 正在向连接{i+1}发送消息 (WebSocket={id(session.websocket)})') | ||
| 249 | + if await session.send_message(message): | ||
| 250 | + success_count += 1 | ||
| 251 | + logger.info(f'[Session:{session_id}] 连接{i+1}发送成功') | ||
| 252 | + else: | ||
| 253 | + failed_sessions.append(session) | ||
| 254 | + logger.warning(f'[Session:{session_id}] 连接{i+1}发送失败') | ||
| 255 | + | ||
| 256 | + # 清理失败的连接 | ||
| 257 | + for session in failed_sessions: | ||
| 258 | + self.remove_session(session.websocket) | ||
| 259 | + | ||
| 260 | + logger.info(f'[Session:{session_id}] 广播原始消息完成: 成功{success_count}/总计{len(sessions)}, 失败{len(failed_sessions)}') | ||
| 261 | + return success_count | ||
| 262 | + | ||
| 263 | + async def broadcast_to_session(self, session_id: str, message_type: str, content: Any, | ||
| 264 | + source: str = "系统", metadata: Dict = None) -> int: | ||
| 265 | + """向指定会话的所有WebSocket连接广播消息""" | ||
| 266 | + # 确保session_id为字符串类型,保持一致性 | ||
| 267 | + if isinstance(session_id, int): | ||
| 268 | + session_id = str(session_id) | ||
| 269 | + elif not isinstance(session_id, str): | ||
| 270 | + session_id = str(session_id) | ||
| 271 | + | ||
| 272 | + sessions = self.get_sessions_by_id(session_id) | ||
| 273 | + if not sessions: | ||
| 274 | + logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接') | ||
| 275 | + return 0 | ||
| 276 | + | ||
| 277 | + message = { | ||
| 278 | + "type": message_type, | ||
| 279 | + "session_id": session_id, | ||
| 280 | + "content": content, | ||
| 281 | + "source": source, | ||
| 282 | + "timestamp": time.time(), | ||
| 283 | + **(metadata or {}) | ||
| 284 | + } | ||
| 285 | + | ||
| 286 | + success_count = 0 | ||
| 287 | + failed_sessions = [] | ||
| 288 | + | ||
| 289 | + for session in sessions: | ||
| 290 | + if await session.send_message(message): | ||
| 291 | + success_count += 1 | ||
| 292 | + else: | ||
| 293 | + failed_sessions.append(session) | ||
| 294 | + | ||
| 295 | + # 清理失败的连接 | ||
| 296 | + for session in failed_sessions: | ||
| 297 | + self.remove_session(session.websocket) | ||
| 298 | + | ||
| 299 | + logger.info(f'[Session:{session_id}] 广播消息成功: {success_count}/{len(sessions)}') | ||
| 300 | + return success_count | ||
| 301 | + | ||
| 302 | + async def broadcast_to_all(self, message_type: str, content: Any, | ||
| 303 | + source: str = "系统", metadata: Dict = None) -> int: | ||
| 304 | + """向所有WebSocket连接广播消息""" | ||
| 305 | + total_sent = 0 | ||
| 306 | + with self._lock: | ||
| 307 | + session_ids = list(self._sessions.keys()) | ||
| 308 | + | ||
| 309 | + for session_id in session_ids: | ||
| 310 | + sent = await self.broadcast_to_session(session_id, message_type, content, source, metadata) | ||
| 311 | + total_sent += sent | ||
| 312 | + | ||
| 313 | + logger.info(f'全局广播消息完成,总发送数: {total_sent}') | ||
| 314 | + return total_sent | ||
| 315 | + | ||
| 316 | + def get_session_count(self) -> int: | ||
| 317 | + """获取会话总数""" | ||
| 318 | + with self._lock: | ||
| 319 | + return len(self._sessions) | ||
| 320 | + | ||
| 321 | + def get_connection_count(self) -> int: | ||
| 322 | + """获取连接总数""" | ||
| 323 | + with self._lock: | ||
| 324 | + return len(self._websockets) | ||
| 325 | + | ||
| 326 | + def get_session_stats(self) -> Dict[str, Any]: | ||
| 327 | + """获取会话统计信息""" | ||
| 328 | + with self._lock: | ||
| 329 | + stats = { | ||
| 330 | + "total_sessions": len(self._sessions), | ||
| 331 | + "total_connections": len(self._websockets), | ||
| 332 | + "session_details": {} | ||
| 333 | + } | ||
| 334 | + | ||
| 335 | + for session_id, sessions in self._sessions.items(): | ||
| 336 | + stats["session_details"][session_id] = { | ||
| 337 | + "connection_count": len(sessions), | ||
| 338 | + "connections": [ | ||
| 339 | + { | ||
| 340 | + "created_at": session.created_at, | ||
| 341 | + "last_ping": session.last_ping, | ||
| 342 | + "is_alive": session.is_alive(), | ||
| 343 | + "metadata": session.metadata | ||
| 344 | + } for session in sessions | ||
| 345 | + ] | ||
| 346 | + } | ||
| 347 | + | ||
| 348 | + return stats | ||
| 349 | + | ||
| 350 | + async def _handle_ping(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 351 | + """处理心跳消息""" | ||
| 352 | + session = self.get_session(websocket) | ||
| 353 | + if session: | ||
| 354 | + session.update_ping() | ||
| 355 | + await session.send_message({"type": "pong", "timestamp": time.time()}) | ||
| 356 | + | ||
| 357 | + async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 358 | + """处理登录消息""" | ||
| 359 | + session_id = data.get('session_id', data.get('sessionid', str(int(time.time())))) | ||
| 360 | + | ||
| 361 | + # 确保session_id为字符串类型,避免类型不一致问题 | ||
| 362 | + if isinstance(session_id, int): | ||
| 363 | + session_id = str(session_id) | ||
| 364 | + elif not isinstance(session_id, str): | ||
| 365 | + session_id = str(session_id) | ||
| 366 | + | ||
| 367 | + # 添加会话 | ||
| 368 | + session = self.add_session(session_id, websocket) | ||
| 369 | + | ||
| 370 | + # 触发连接事件 | ||
| 371 | + await self._emit_event('session_connected', session=session, data=data) | ||
| 372 | + | ||
| 373 | + # 发送登录确认 | ||
| 374 | + await session.send_message({ | ||
| 375 | + "type": "login_success", | ||
| 376 | + "data": { | ||
| 377 | + "session_id": session_id, | ||
| 378 | + "message": "WebSocket连接成功", | ||
| 379 | + "timestamp": time.time() | ||
| 380 | + } | ||
| 381 | + }) | ||
| 382 | + | ||
| 383 | + async def handle_websocket_message(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 384 | + """处理WebSocket消息""" | ||
| 385 | + message_type = data.get('type') | ||
| 386 | + | ||
| 387 | + if message_type in self._message_handlers: | ||
| 388 | + try: | ||
| 389 | + await self._message_handlers[message_type](websocket, data) | ||
| 390 | + except Exception as e: | ||
| 391 | + logger.error(f'消息处理器执行失败 {message_type}: {e}') | ||
| 392 | + session = self.get_session(websocket) | ||
| 393 | + if session: | ||
| 394 | + await session.send_message({ | ||
| 395 | + "type": "error", | ||
| 396 | + "data": { | ||
| 397 | + "message": f"消息处理失败: {str(e)}", | ||
| 398 | + "original_type": message_type | ||
| 399 | + } | ||
| 400 | + }) | ||
| 401 | + else: | ||
| 402 | + logger.warning(f'未知消息类型: {message_type}') | ||
| 403 | + | ||
| 404 | + async def websocket_handler(self, request) -> web.WebSocketResponse: | ||
| 405 | + """WebSocket连接处理器""" | ||
| 406 | + ws = web.WebSocketResponse() | ||
| 407 | + await ws.prepare(request) | ||
| 408 | + | ||
| 409 | + logger.info('新的WebSocket连接建立') | ||
| 410 | + | ||
| 411 | + try: | ||
| 412 | + async for msg in ws: | ||
| 413 | + if msg.type == WSMsgType.TEXT: | ||
| 414 | + try: | ||
| 415 | + data = json.loads(msg.data) | ||
| 416 | + await self.handle_websocket_message(ws, data) | ||
| 417 | + except json.JSONDecodeError: | ||
| 418 | + logger.error('收到无效的JSON数据') | ||
| 419 | + except Exception as e: | ||
| 420 | + logger.error(f'处理WebSocket消息时出错: {e}') | ||
| 421 | + | ||
| 422 | + elif msg.type == WSMsgType.ERROR: | ||
| 423 | + logger.error(f'WebSocket错误: {ws.exception()}') | ||
| 424 | + break | ||
| 425 | + elif msg.type == WSMsgType.CLOSE: | ||
| 426 | + logger.info('WebSocket连接正常关闭') | ||
| 427 | + break | ||
| 428 | + | ||
| 429 | + except ConnectionResetError: | ||
| 430 | + logger.warning('WebSocket连接被远程主机重置') | ||
| 431 | + except ConnectionAbortedError: | ||
| 432 | + logger.warning('WebSocket连接被中止') | ||
| 433 | + except Exception as e: | ||
| 434 | + logger.error(f'WebSocket连接错误: {e}') | ||
| 435 | + finally: | ||
| 436 | + # 清理会话 | ||
| 437 | + session = self.remove_session(ws) | ||
| 438 | + if session: | ||
| 439 | + await self._emit_event('session_disconnected', session=session) | ||
| 440 | + logger.info('WebSocket连接已关闭') | ||
| 441 | + | ||
| 442 | + return ws | ||
| 443 | + | ||
| 444 | + def get_expired_sessions(self, timeout: int = 60) -> List[WebSocketSession]: | ||
| 445 | + """获取过期的会话列表""" | ||
| 446 | + current_time = time.time() | ||
| 447 | + expired_sessions = [] | ||
| 448 | + | ||
| 449 | + with self._lock: | ||
| 450 | + for session in self._websockets.values(): | ||
| 451 | + if current_time - session.last_ping > timeout: | ||
| 452 | + expired_sessions.append(session) | ||
| 453 | + | ||
| 454 | + return expired_sessions | ||
| 455 | + | ||
| 456 | + async def cleanup_dead_connections(self): | ||
| 457 | + """清理死连接""" | ||
| 458 | + dead_websockets = [] | ||
| 459 | + | ||
| 460 | + with self._lock: | ||
| 461 | + for websocket, session in self._websockets.items(): | ||
| 462 | + if not session.is_alive(): | ||
| 463 | + dead_websockets.append(websocket) | ||
| 464 | + | ||
| 465 | + for websocket in dead_websockets: | ||
| 466 | + self.remove_session(websocket) | ||
| 467 | + | ||
| 468 | + if dead_websockets: | ||
| 469 | + logger.info(f'清理了 {len(dead_websockets)} 个死连接') | ||
| 470 | + | ||
| 471 | + return len(dead_websockets) | ||
| 472 | + | ||
| 473 | + | ||
| 474 | +# 全局统一WebSocket管理器实例 | ||
| 475 | +_unified_manager = UnifiedWebSocketManager() | ||
| 476 | + | ||
| 477 | + | ||
| 478 | +def get_unified_manager() -> UnifiedWebSocketManager: | ||
| 479 | + """获取统一WebSocket管理器实例""" | ||
| 480 | + return _unified_manager | ||
| 481 | + | ||
| 482 | + | ||
| 483 | +# 兼容性接口,保持与原有代码的兼容 | ||
| 484 | +async def broadcast_message_to_session(session_id: str, message_type: str, content: str, | ||
| 485 | + source: str = "数字人回复", model_info: str = None, | ||
| 486 | + request_source: str = "页面"): | ||
| 487 | + """兼容性接口:向指定会话广播消息""" | ||
| 488 | + metadata = {} | ||
| 489 | + if model_info: | ||
| 490 | + metadata['model_info'] = model_info | ||
| 491 | + if request_source: | ||
| 492 | + metadata['request_source'] = request_source | ||
| 493 | + | ||
| 494 | + return await _unified_manager.broadcast_to_session( | ||
| 495 | + session_id, message_type, content, source, metadata | ||
| 496 | + ) |
core/websocket_router.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +WebSocket路由管理器 | ||
| 5 | +统一管理所有WebSocket服务的路由和初始化 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +import time | ||
| 10 | +from typing import Dict, Any, Optional | ||
| 11 | +from aiohttp import web, WSMsgType | ||
| 12 | +import json | ||
| 13 | +from logger import logger | ||
| 14 | +from .unified_websocket_manager import get_unified_manager | ||
| 15 | +from .websocket_service_base import get_service_registry | ||
| 16 | +from .asr_websocket_service import get_asr_service | ||
| 17 | +from .digital_human_websocket_service import get_digital_human_service | ||
| 18 | + | ||
| 19 | + | ||
| 20 | +class WebSocketRouter: | ||
| 21 | + """WebSocket路由管理器""" | ||
| 22 | + | ||
| 23 | + def __init__(self): | ||
| 24 | + self.manager = get_unified_manager() | ||
| 25 | + self.service_registry = get_service_registry() | ||
| 26 | + self.is_initialized = False | ||
| 27 | + | ||
| 28 | + async def initialize(self): | ||
| 29 | + """初始化路由器和所有服务""" | ||
| 30 | + if self.is_initialized: | ||
| 31 | + return | ||
| 32 | + | ||
| 33 | + logger.info('初始化WebSocket路由器...') | ||
| 34 | + | ||
| 35 | + # 注册所有服务 | ||
| 36 | + await self._register_services() | ||
| 37 | + | ||
| 38 | + # 初始化所有服务 | ||
| 39 | + await self.service_registry.initialize_all() | ||
| 40 | + | ||
| 41 | + self.is_initialized = True | ||
| 42 | + logger.info('WebSocket路由器初始化完成') | ||
| 43 | + | ||
| 44 | + async def shutdown(self): | ||
| 45 | + """关闭路由器和所有服务""" | ||
| 46 | + if not self.is_initialized: | ||
| 47 | + return | ||
| 48 | + | ||
| 49 | + logger.info('关闭WebSocket路由器...') | ||
| 50 | + | ||
| 51 | + # 关闭所有服务 | ||
| 52 | + await self.service_registry.shutdown_all() | ||
| 53 | + | ||
| 54 | + # 关闭管理器 | ||
| 55 | + await self.manager.shutdown() | ||
| 56 | + | ||
| 57 | + self.is_initialized = False | ||
| 58 | + logger.info('WebSocket路由器已关闭') | ||
| 59 | + | ||
| 60 | + async def _register_services(self): | ||
| 61 | + """注册所有WebSocket服务""" | ||
| 62 | + logger.info('注册WebSocket服务...') | ||
| 63 | + | ||
| 64 | + # 注册ASR服务 | ||
| 65 | + asr_service = get_asr_service() | ||
| 66 | + self.service_registry.register_service(asr_service) | ||
| 67 | + | ||
| 68 | + # 注册数字人服务 | ||
| 69 | + digital_human_service = get_digital_human_service() | ||
| 70 | + self.service_registry.register_service(digital_human_service) | ||
| 71 | + | ||
| 72 | + # 注册WSA服务 | ||
| 73 | + from .wsa_websocket_service import WSAWebSocketService, initialize_wsa_service | ||
| 74 | + wsa_service = WSAWebSocketService(self.manager) | ||
| 75 | + self.service_registry.register_service(wsa_service) | ||
| 76 | + | ||
| 77 | + # 初始化WSA兼容性接口 | ||
| 78 | + initialize_wsa_service(wsa_service) | ||
| 79 | + | ||
| 80 | + logger.info(f'已注册 {len(self.service_registry.list_services())} 个WebSocket服务') | ||
| 81 | + | ||
| 82 | + async def websocket_handler(self, request: web.Request) -> web.WebSocketResponse: | ||
| 83 | + """统一的WebSocket处理器""" | ||
| 84 | + ws = web.WebSocketResponse() | ||
| 85 | + await ws.prepare(request) | ||
| 86 | + | ||
| 87 | + # 创建会话ID | ||
| 88 | + session_id = request.headers.get('X-Session-ID', str(int(time.time()))) | ||
| 89 | + session = self.manager.add_session(session_id, ws) | ||
| 90 | + logger.info(f'WebSocket连接建立: {session.session_id}') | ||
| 91 | + | ||
| 92 | + try: | ||
| 93 | + async for msg in ws: | ||
| 94 | + if msg.type == WSMsgType.TEXT: | ||
| 95 | + try: | ||
| 96 | + data = json.loads(msg.data) | ||
| 97 | + await self._handle_message(ws, data) | ||
| 98 | + except json.JSONDecodeError as e: | ||
| 99 | + logger.error(f'JSON解析失败: {e}') | ||
| 100 | + await session.send_message({ | ||
| 101 | + "type": "error", | ||
| 102 | + "data": {"message": "消息格式错误"} | ||
| 103 | + }) | ||
| 104 | + except Exception as e: | ||
| 105 | + logger.error(f'消息处理失败: {e}') | ||
| 106 | + await session.send_message({ | ||
| 107 | + "type": "error", | ||
| 108 | + "data": {"message": f"处理失败: {str(e)}"} | ||
| 109 | + }) | ||
| 110 | + elif msg.type == WSMsgType.ERROR: | ||
| 111 | + logger.error(f'WebSocket错误: {ws.exception()}') | ||
| 112 | + break | ||
| 113 | + elif msg.type == WSMsgType.CLOSE: | ||
| 114 | + logger.info(f'WebSocket连接关闭: {session.session_id}') | ||
| 115 | + break | ||
| 116 | + | ||
| 117 | + except ConnectionResetError: | ||
| 118 | + logger.warning(f'WebSocket连接被远程主机重置: {session.session_id}') | ||
| 119 | + except ConnectionAbortedError: | ||
| 120 | + logger.warning(f'WebSocket连接被中止: {session.session_id}') | ||
| 121 | + except Exception as e: | ||
| 122 | + logger.error(f'WebSocket处理异常: {e}') | ||
| 123 | + finally: | ||
| 124 | + # 清理会话 | ||
| 125 | + self.manager.remove_session(ws) | ||
| 126 | + | ||
| 127 | + return ws | ||
| 128 | + | ||
| 129 | + async def _handle_message(self, ws: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 130 | + """处理WebSocket消息""" | ||
| 131 | + message_type = data.get('type') | ||
| 132 | + if not message_type: | ||
| 133 | + session = self.manager.get_session(ws) | ||
| 134 | + if session: | ||
| 135 | + await session.send_message({ | ||
| 136 | + "type": "error", | ||
| 137 | + "data": {"message": "缺少消息类型"} | ||
| 138 | + }) | ||
| 139 | + return | ||
| 140 | + | ||
| 141 | + # 通过管理器处理消息 | ||
| 142 | + await self.manager.handle_websocket_message(ws, data) | ||
| 143 | + | ||
| 144 | + def get_router_stats(self) -> Dict[str, Any]: | ||
| 145 | + """获取路由器统计信息""" | ||
| 146 | + stats = { | ||
| 147 | + "initialized": self.is_initialized, | ||
| 148 | + "manager_stats": self.manager.get_session_stats(), | ||
| 149 | + "service_stats": self.service_registry.get_all_stats() | ||
| 150 | + } | ||
| 151 | + | ||
| 152 | + # 添加各服务的详细统计 | ||
| 153 | + asr_service = self.service_registry.get_service("asr_service") | ||
| 154 | + if asr_service: | ||
| 155 | + stats["asr_stats"] = asr_service.get_asr_stats() | ||
| 156 | + | ||
| 157 | + digital_human_service = self.service_registry.get_service("digital_human_service") | ||
| 158 | + if digital_human_service: | ||
| 159 | + stats["digital_human_stats"] = digital_human_service.get_digital_human_stats() | ||
| 160 | + | ||
| 161 | + return stats | ||
| 162 | + | ||
| 163 | + def setup_routes(self, app: web.Application, path: str = '/ws'): | ||
| 164 | + """设置WebSocket路由""" | ||
| 165 | + app.router.add_get(path, self.websocket_handler) | ||
| 166 | + logger.info(f'WebSocket路由已设置: {path}') | ||
| 167 | + | ||
| 168 | + async def broadcast_system_message(self, message: str, level: str = 'info'): | ||
| 169 | + """广播系统消息""" | ||
| 170 | + await self.manager.broadcast_to_all('system_message', { | ||
| 171 | + 'message': message, | ||
| 172 | + 'level': level, | ||
| 173 | + 'timestamp': asyncio.get_event_loop().time() | ||
| 174 | + }, source='system') | ||
| 175 | + | ||
| 176 | + async def send_to_session(self, session_id: str, message_type: str, content: Any): | ||
| 177 | + """向指定会话发送消息""" | ||
| 178 | + return await self.manager.broadcast_to_session(session_id, message_type, content, source='router') | ||
| 179 | + | ||
| 180 | + async def send_raw_to_session(self, session_id: str, message: Dict): | ||
| 181 | + """向指定会话发送消息""" | ||
| 182 | + return await self.manager.broadcast_raw_message_to_session(str(session_id), message) | ||
| 183 | + | ||
| 184 | + | ||
| 185 | + | ||
| 186 | + async def send_to_digital_human(self, human_id: str, message_type: str, content: Any): | ||
| 187 | + """向指定数字人发送消息""" | ||
| 188 | + digital_human_service = self.service_registry.get_service("digital_human_service") | ||
| 189 | + if digital_human_service: | ||
| 190 | + return await digital_human_service.send_to_digital_human(human_id, message_type, content) | ||
| 191 | + return False | ||
| 192 | + | ||
| 193 | + async def get_asr_stats(self) -> Optional[Dict[str, Any]]: | ||
| 194 | + """获取ASR统计信息""" | ||
| 195 | + asr_service = self.service_registry.get_service("asr_service") | ||
| 196 | + if asr_service: | ||
| 197 | + return asr_service.get_asr_stats() | ||
| 198 | + return None | ||
| 199 | + | ||
| 200 | + async def get_digital_human_stats(self) -> Optional[Dict[str, Any]]: | ||
| 201 | + """获取数字人统计信息""" | ||
| 202 | + digital_human_service = self.service_registry.get_service("digital_human_service") | ||
| 203 | + if digital_human_service: | ||
| 204 | + return digital_human_service.get_digital_human_stats() | ||
| 205 | + return None | ||
| 206 | + | ||
| 207 | + | ||
| 208 | +# 全局路由器实例 | ||
| 209 | +_websocket_router = None | ||
| 210 | + | ||
| 211 | + | ||
| 212 | +def get_websocket_router() -> WebSocketRouter: | ||
| 213 | + """获取WebSocket路由器实例""" | ||
| 214 | + global _websocket_router | ||
| 215 | + if _websocket_router is None: | ||
| 216 | + _websocket_router = WebSocketRouter() | ||
| 217 | + return _websocket_router | ||
| 218 | + | ||
| 219 | + | ||
| 220 | +async def initialize_websocket_router(): | ||
| 221 | + """初始化WebSocket路由器""" | ||
| 222 | + router = get_websocket_router() | ||
| 223 | + await router.initialize() | ||
| 224 | + return router | ||
| 225 | + | ||
| 226 | + | ||
| 227 | +async def shutdown_websocket_router(): | ||
| 228 | + """关闭WebSocket路由器""" | ||
| 229 | + global _websocket_router | ||
| 230 | + if _websocket_router: | ||
| 231 | + await _websocket_router.shutdown() | ||
| 232 | + _websocket_router = None | ||
| 233 | + | ||
| 234 | + | ||
| 235 | +def setup_websocket_routes(app: web.Application, path: str = '/ws'): | ||
| 236 | + """设置WebSocket路由(便捷函数)""" | ||
| 237 | + router = get_websocket_router() | ||
| 238 | + router.setup_routes(app, path) | ||
| 239 | + return router | ||
| 240 | + | ||
| 241 | + | ||
| 242 | +# 兼容性接口 | ||
| 243 | +class WebSocketCompatibilityAPI: | ||
| 244 | + """WebSocket兼容性API | ||
| 245 | + | ||
| 246 | + 为了保持与现有代码的兼容性,提供简化的接口 | ||
| 247 | + """ | ||
| 248 | + | ||
| 249 | + def __init__(self): | ||
| 250 | + self.router = get_websocket_router() | ||
| 251 | + | ||
| 252 | + async def broadcast_message_to_session(self, session_id: str, message: Dict[str, Any]): | ||
| 253 | + """向指定会话广播消息(兼容app.py接口)""" | ||
| 254 | + message_type = message.get('type', 'message') | ||
| 255 | + content = message.get('data', message) | ||
| 256 | + return await self.router.send_to_session(session_id, message_type, content) | ||
| 257 | + | ||
| 258 | + async def broadcast_to_all_sessions(self, message: Dict[str, Any]): | ||
| 259 | + """向所有会话广播消息""" | ||
| 260 | + message_type = message.get('type', 'message') | ||
| 261 | + content = message.get('data', message) | ||
| 262 | + return await self.router.manager.broadcast_to_all(message_type, content, source='compatibility') | ||
| 263 | + | ||
| 264 | + def get_active_sessions(self): | ||
| 265 | + """获取活跃会话列表""" | ||
| 266 | + return list(self.router.manager._sessions.keys()) | ||
| 267 | + | ||
| 268 | + def get_session_count(self): | ||
| 269 | + """获取会话数量""" | ||
| 270 | + return len(self.router.manager._sessions) | ||
| 271 | + | ||
| 272 | + async def send_asr_result(self, session_id: str, result: Dict[str, Any]): | ||
| 273 | + """发送ASR结果(兼容app.py接口)""" | ||
| 274 | + return await self.router.send_to_session(session_id, 'asr_result', result) | ||
| 275 | + | ||
| 276 | + | ||
| 277 | +# 全局兼容性API实例 | ||
| 278 | +_compatibility_api = None | ||
| 279 | + | ||
| 280 | + | ||
| 281 | +def get_websocket_compatibility_api() -> WebSocketCompatibilityAPI: | ||
| 282 | + """获取WebSocket兼容性API实例""" | ||
| 283 | + global _compatibility_api | ||
| 284 | + if _compatibility_api is None: | ||
| 285 | + _compatibility_api = WebSocketCompatibilityAPI() | ||
| 286 | + return _compatibility_api |
core/websocket_service_base.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-07-15 14:41:21 | ||
| 4 | +WebSocket服务抽象基类 | ||
| 5 | +为不同类型的WebSocket服务提供统一接口和生命周期管理 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +from abc import ABC, abstractmethod | ||
| 10 | +from typing import Dict, Any, Optional, List | ||
| 11 | +from aiohttp import web | ||
| 12 | +from logger import logger | ||
| 13 | +from .unified_websocket_manager import get_unified_manager, WebSocketSession | ||
| 14 | + | ||
| 15 | + | ||
| 16 | +class WebSocketServiceBase(ABC): | ||
| 17 | + """WebSocket服务抽象基类""" | ||
| 18 | + | ||
| 19 | + def __init__(self, service_name: str): | ||
| 20 | + self.service_name = service_name | ||
| 21 | + self.manager = get_unified_manager() | ||
| 22 | + self.is_initialized = False | ||
| 23 | + self._background_tasks: List[asyncio.Task] = [] | ||
| 24 | + | ||
| 25 | + async def initialize(self): | ||
| 26 | + """初始化服务""" | ||
| 27 | + if self.is_initialized: | ||
| 28 | + return | ||
| 29 | + | ||
| 30 | + logger.info(f'初始化WebSocket服务: {self.service_name}') | ||
| 31 | + | ||
| 32 | + # 注册消息处理器 | ||
| 33 | + await self._register_message_handlers() | ||
| 34 | + | ||
| 35 | + # 注册事件处理器 | ||
| 36 | + await self._register_event_handlers() | ||
| 37 | + | ||
| 38 | + # 启动后台任务 | ||
| 39 | + await self._start_background_tasks() | ||
| 40 | + | ||
| 41 | + self.is_initialized = True | ||
| 42 | + logger.info(f'WebSocket服务初始化完成: {self.service_name}') | ||
| 43 | + | ||
| 44 | + async def shutdown(self): | ||
| 45 | + """关闭服务""" | ||
| 46 | + if not self.is_initialized: | ||
| 47 | + return | ||
| 48 | + | ||
| 49 | + logger.info(f'关闭WebSocket服务: {self.service_name}') | ||
| 50 | + | ||
| 51 | + # 停止后台任务 | ||
| 52 | + for task in self._background_tasks: | ||
| 53 | + if not task.done(): | ||
| 54 | + task.cancel() | ||
| 55 | + try: | ||
| 56 | + await task | ||
| 57 | + except asyncio.CancelledError: | ||
| 58 | + pass | ||
| 59 | + | ||
| 60 | + self._background_tasks.clear() | ||
| 61 | + | ||
| 62 | + # 执行自定义清理 | ||
| 63 | + await self._cleanup() | ||
| 64 | + | ||
| 65 | + self.is_initialized = False | ||
| 66 | + logger.info(f'WebSocket服务已关闭: {self.service_name}') | ||
| 67 | + | ||
| 68 | + @abstractmethod | ||
| 69 | + async def _register_message_handlers(self): | ||
| 70 | + """注册消息处理器(子类实现)""" | ||
| 71 | + pass | ||
| 72 | + | ||
| 73 | + async def _register_event_handlers(self): | ||
| 74 | + """注册事件处理器(可选重写)""" | ||
| 75 | + # 注册通用事件处理器 | ||
| 76 | + self.manager.register_event_handler('session_connected', self._on_session_connected) | ||
| 77 | + self.manager.register_event_handler('session_disconnected', self._on_session_disconnected) | ||
| 78 | + | ||
| 79 | + async def _start_background_tasks(self): | ||
| 80 | + """启动后台任务(可选重写)""" | ||
| 81 | + pass | ||
| 82 | + | ||
| 83 | + async def _cleanup(self): | ||
| 84 | + """清理资源(可选重写)""" | ||
| 85 | + pass | ||
| 86 | + | ||
| 87 | + async def _on_session_connected(self, session: WebSocketSession, data: Dict[str, Any]): | ||
| 88 | + """会话连接事件处理(可选重写)""" | ||
| 89 | + logger.info(f'[{self.service_name}] 会话连接: {session.session_id}') | ||
| 90 | + | ||
| 91 | + async def _on_session_disconnected(self, session: WebSocketSession): | ||
| 92 | + """会话断开事件处理(可选重写)""" | ||
| 93 | + logger.info(f'[{self.service_name}] 会话断开: {session.session_id}') | ||
| 94 | + | ||
| 95 | + def add_background_task(self, coro): | ||
| 96 | + """添加后台任务""" | ||
| 97 | + task = asyncio.create_task(coro) | ||
| 98 | + self._background_tasks.append(task) | ||
| 99 | + return task | ||
| 100 | + | ||
| 101 | + async def broadcast_to_session(self, session_id: str, message_type: str, content: Any, | ||
| 102 | + source: str = None, metadata: Dict = None): | ||
| 103 | + """向指定会话广播消息""" | ||
| 104 | + if source is None: | ||
| 105 | + source = self.service_name | ||
| 106 | + return await self.manager.broadcast_to_session(session_id, message_type, content, source, metadata) | ||
| 107 | + | ||
| 108 | + async def broadcast_to_all(self, message_type: str, content: Any, | ||
| 109 | + source: str = None, metadata: Dict = None): | ||
| 110 | + """向所有会话广播消息""" | ||
| 111 | + if source is None: | ||
| 112 | + source = self.service_name | ||
| 113 | + return await self.manager.broadcast_to_all(message_type, content, source, metadata) | ||
| 114 | + | ||
| 115 | + def get_session_stats(self) -> Dict[str, Any]: | ||
| 116 | + """获取会话统计信息""" | ||
| 117 | + return self.manager.get_session_stats() | ||
| 118 | + | ||
| 119 | + def create_message_handler(self, message_type: str): | ||
| 120 | + """装饰器:创建消息处理器""" | ||
| 121 | + def decorator(func): | ||
| 122 | + async def wrapper(websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 123 | + session = self.manager.get_session(websocket) | ||
| 124 | + if session: | ||
| 125 | + try: | ||
| 126 | + await func(session, data) | ||
| 127 | + except Exception as e: | ||
| 128 | + logger.error(f'[{self.service_name}] 消息处理器 {message_type} 执行失败: {e}') | ||
| 129 | + await session.send_message({ | ||
| 130 | + "type": "error", | ||
| 131 | + "data": { | ||
| 132 | + "message": f"处理 {message_type} 消息失败: {str(e)}", | ||
| 133 | + "service": self.service_name | ||
| 134 | + } | ||
| 135 | + }) | ||
| 136 | + else: | ||
| 137 | + logger.warning(f'[{self.service_name}] 未找到会话,无法处理消息: {message_type}') | ||
| 138 | + | ||
| 139 | + self.manager.register_message_handler(message_type, wrapper) | ||
| 140 | + return func | ||
| 141 | + return decorator | ||
| 142 | + | ||
| 143 | + def create_event_handler(self, event_type: str): | ||
| 144 | + """装饰器:创建事件处理器""" | ||
| 145 | + def decorator(func): | ||
| 146 | + self.manager.register_event_handler(event_type, func) | ||
| 147 | + return func | ||
| 148 | + return decorator | ||
| 149 | + | ||
| 150 | + | ||
| 151 | +class WebSocketServiceRegistry: | ||
| 152 | + """WebSocket服务注册表""" | ||
| 153 | + | ||
| 154 | + def __init__(self): | ||
| 155 | + self._services: Dict[str, WebSocketServiceBase] = {} | ||
| 156 | + | ||
| 157 | + def register_service(self, service: WebSocketServiceBase): | ||
| 158 | + """注册服务""" | ||
| 159 | + if service.service_name in self._services: | ||
| 160 | + logger.warning(f'服务已存在,将被覆盖: {service.service_name}') | ||
| 161 | + | ||
| 162 | + self._services[service.service_name] = service | ||
| 163 | + logger.info(f'注册WebSocket服务: {service.service_name}') | ||
| 164 | + | ||
| 165 | + def get_service(self, service_name: str) -> Optional[WebSocketServiceBase]: | ||
| 166 | + """获取服务""" | ||
| 167 | + return self._services.get(service_name) | ||
| 168 | + | ||
| 169 | + def list_services(self) -> List[str]: | ||
| 170 | + """列出所有服务名称""" | ||
| 171 | + return list(self._services.keys()) | ||
| 172 | + | ||
| 173 | + async def initialize_all(self): | ||
| 174 | + """初始化所有服务""" | ||
| 175 | + logger.info('初始化所有WebSocket服务...') | ||
| 176 | + for service in self._services.values(): | ||
| 177 | + await service.initialize() | ||
| 178 | + logger.info('所有WebSocket服务初始化完成') | ||
| 179 | + | ||
| 180 | + async def shutdown_all(self): | ||
| 181 | + """关闭所有服务""" | ||
| 182 | + logger.info('关闭所有WebSocket服务...') | ||
| 183 | + for service in self._services.values(): | ||
| 184 | + await service.shutdown() | ||
| 185 | + logger.info('所有WebSocket服务已关闭') | ||
| 186 | + | ||
| 187 | + def get_all_stats(self) -> Dict[str, Any]: | ||
| 188 | + """获取所有服务的统计信息""" | ||
| 189 | + stats = { | ||
| 190 | + "services": {}, | ||
| 191 | + "total_services": len(self._services) | ||
| 192 | + } | ||
| 193 | + | ||
| 194 | + for name, service in self._services.items(): | ||
| 195 | + stats["services"][name] = { | ||
| 196 | + "initialized": service.is_initialized, | ||
| 197 | + "background_tasks": len(service._background_tasks) | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + return stats | ||
| 201 | + | ||
| 202 | + | ||
| 203 | +# 全局服务注册表 | ||
| 204 | +_service_registry = WebSocketServiceRegistry() | ||
| 205 | + | ||
| 206 | + | ||
| 207 | +def get_service_registry() -> WebSocketServiceRegistry: | ||
| 208 | + """获取服务注册表""" | ||
| 209 | + return _service_registry | ||
| 210 | + | ||
| 211 | + | ||
| 212 | +def register_websocket_service(service: WebSocketServiceBase): | ||
| 213 | + """注册WebSocket服务""" | ||
| 214 | + _service_registry.register_service(service) | ||
| 215 | + | ||
| 216 | + | ||
| 217 | +def get_websocket_service(service_name: str) -> Optional[WebSocketServiceBase]: | ||
| 218 | + """获取WebSocket服务""" | ||
| 219 | + return _service_registry.get_service(service_name) |
core/wsa_server.py
deleted
100644 → 0
| 1 | -# -*- coding: utf-8 -*- | ||
| 2 | -""" | ||
| 3 | -AIfeng/2025-01-27 | ||
| 4 | -WebSocket服务器管理模块 | ||
| 5 | -提供Web和Human连接的管理功能 | ||
| 6 | -""" | ||
| 7 | - | ||
| 8 | -import queue | ||
| 9 | -from typing import Dict, Any, Optional | ||
| 10 | -from threading import Lock | ||
| 11 | - | ||
| 12 | -class WebSocketManager: | ||
| 13 | - """WebSocket连接管理器""" | ||
| 14 | - | ||
| 15 | - def __init__(self): | ||
| 16 | - self._connections = {} | ||
| 17 | - self._command_queue = queue.Queue() | ||
| 18 | - self._lock = Lock() | ||
| 19 | - | ||
| 20 | - def is_connected(self, username: str) -> bool: | ||
| 21 | - """检查用户是否已连接 | ||
| 22 | - | ||
| 23 | - Args: | ||
| 24 | - username: 用户名 | ||
| 25 | - | ||
| 26 | - Returns: | ||
| 27 | - 是否已连接 | ||
| 28 | - """ | ||
| 29 | - with self._lock: | ||
| 30 | - return username in self._connections | ||
| 31 | - | ||
| 32 | - def is_connected_human(self, username: str) -> bool: | ||
| 33 | - """检查人类用户是否已连接 | ||
| 34 | - | ||
| 35 | - Args: | ||
| 36 | - username: 用户名 | ||
| 37 | - | ||
| 38 | - Returns: | ||
| 39 | - 是否已连接 | ||
| 40 | - """ | ||
| 41 | - # 简化实现,与is_connected相同 | ||
| 42 | - return self.is_connected(username) | ||
| 43 | - | ||
| 44 | - def add_connection(self, username: str, connection: Any): | ||
| 45 | - """添加连接 | ||
| 46 | - | ||
| 47 | - Args: | ||
| 48 | - username: 用户名 | ||
| 49 | - connection: 连接对象 | ||
| 50 | - """ | ||
| 51 | - with self._lock: | ||
| 52 | - self._connections[username] = connection | ||
| 53 | - | ||
| 54 | - def remove_connection(self, username: str): | ||
| 55 | - """移除连接 | ||
| 56 | - | ||
| 57 | - Args: | ||
| 58 | - username: 用户名 | ||
| 59 | - """ | ||
| 60 | - with self._lock: | ||
| 61 | - self._connections.pop(username, None) | ||
| 62 | - | ||
| 63 | - def add_cmd(self, command: Dict[str, Any]): | ||
| 64 | - """添加命令到队列 | ||
| 65 | - | ||
| 66 | - Args: | ||
| 67 | - command: 命令字典 | ||
| 68 | - """ | ||
| 69 | - try: | ||
| 70 | - self._command_queue.put(command, timeout=1.0) | ||
| 71 | - except queue.Full: | ||
| 72 | - print(f"警告: 命令队列已满,丢弃命令: {command}") | ||
| 73 | - | ||
| 74 | - def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]: | ||
| 75 | - """从队列获取命令 | ||
| 76 | - | ||
| 77 | - Args: | ||
| 78 | - timeout: 超时时间 | ||
| 79 | - | ||
| 80 | - Returns: | ||
| 81 | - 命令字典或None | ||
| 82 | - """ | ||
| 83 | - try: | ||
| 84 | - return self._command_queue.get(timeout=timeout) | ||
| 85 | - except queue.Empty: | ||
| 86 | - return None | ||
| 87 | - | ||
| 88 | - def get_connection_count(self) -> int: | ||
| 89 | - """获取连接数量""" | ||
| 90 | - with self._lock: | ||
| 91 | - return len(self._connections) | ||
| 92 | - | ||
| 93 | - def get_usernames(self) -> list: | ||
| 94 | - """获取所有用户名列表""" | ||
| 95 | - with self._lock: | ||
| 96 | - return list(self._connections.keys()) | ||
| 97 | - | ||
| 98 | -# 全局实例 | ||
| 99 | -_web_instance = WebSocketManager() | ||
| 100 | -_human_instance = WebSocketManager() | ||
| 101 | - | ||
| 102 | -def get_web_instance() -> WebSocketManager: | ||
| 103 | - """获取Web WebSocket管理器实例""" | ||
| 104 | - return _web_instance | ||
| 105 | - | ||
| 106 | -def get_instance() -> WebSocketManager: | ||
| 107 | - """获取Human WebSocket管理器实例""" | ||
| 108 | - return _human_instance |
core/wsa_websocket_service.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-01-27 16:30:00 | ||
| 4 | +WSA WebSocket服务 | ||
| 5 | +将原有wsa_server功能集成到统一WebSocket架构中 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import asyncio | ||
| 9 | +import json | ||
| 10 | +import queue | ||
| 11 | +from typing import Dict, Any, Optional, Set | ||
| 12 | +from threading import Lock | ||
| 13 | +from aiohttp import web | ||
| 14 | + | ||
| 15 | +from .websocket_service_base import WebSocketServiceBase | ||
| 16 | +from .unified_websocket_manager import WebSocketSession | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +class WSAWebSocketService(WebSocketServiceBase): | ||
| 20 | + """WSA WebSocket服务 | ||
| 21 | + | ||
| 22 | + 提供与原wsa_server兼容的功能: | ||
| 23 | + - Web连接管理 | ||
| 24 | + - Human连接管理 | ||
| 25 | + - 命令队列处理 | ||
| 26 | + - 消息转发 | ||
| 27 | + """ | ||
| 28 | + | ||
| 29 | + def __init__(self, manager): | ||
| 30 | + super().__init__("wsa") | ||
| 31 | + | ||
| 32 | + # 连接管理 | ||
| 33 | + self._web_connections: Dict[str, Set[WebSocketSession]] = {} | ||
| 34 | + self._human_connections: Dict[str, Set[WebSocketSession]] = {} | ||
| 35 | + self._connection_lock = Lock() | ||
| 36 | + | ||
| 37 | + # 命令队列 | ||
| 38 | + self._web_command_queue = queue.Queue() | ||
| 39 | + self._human_command_queue = queue.Queue() | ||
| 40 | + | ||
| 41 | + # 后台任务 | ||
| 42 | + self._queue_processor_task: Optional[asyncio.Task] = None | ||
| 43 | + | ||
| 44 | + async def _register_message_handlers(self): | ||
| 45 | + """注册消息处理器""" | ||
| 46 | + self.manager.register_message_handler("wsa_register_web", self._handle_register_web) | ||
| 47 | + self.manager.register_message_handler("wsa_register_human", self._handle_register_human) | ||
| 48 | + self.manager.register_message_handler("wsa_unregister", self._handle_unregister) | ||
| 49 | + self.manager.register_message_handler("wsa_get_status", self._handle_get_status) | ||
| 50 | + | ||
| 51 | + async def _start_background_tasks(self): | ||
| 52 | + """启动后台任务""" | ||
| 53 | + self._queue_processor_task = asyncio.create_task(self._process_command_queues()) | ||
| 54 | + | ||
| 55 | + async def _cleanup(self): | ||
| 56 | + """清理资源""" | ||
| 57 | + if self._queue_processor_task: | ||
| 58 | + self._queue_processor_task.cancel() | ||
| 59 | + try: | ||
| 60 | + await self._queue_processor_task | ||
| 61 | + except asyncio.CancelledError: | ||
| 62 | + pass | ||
| 63 | + | ||
| 64 | + async def _on_session_disconnected(self, session: WebSocketSession): | ||
| 65 | + """会话断开处理""" | ||
| 66 | + with self._connection_lock: | ||
| 67 | + # 从web连接中移除 | ||
| 68 | + for username, sessions in list(self._web_connections.items()): | ||
| 69 | + if session in sessions: | ||
| 70 | + sessions.discard(session) | ||
| 71 | + if not sessions: | ||
| 72 | + del self._web_connections[username] | ||
| 73 | + | ||
| 74 | + # 从human连接中移除 | ||
| 75 | + for username, sessions in list(self._human_connections.items()): | ||
| 76 | + if session in sessions: | ||
| 77 | + sessions.discard(session) | ||
| 78 | + if not sessions: | ||
| 79 | + del self._human_connections[username] | ||
| 80 | + | ||
| 81 | + async def _handle_register_web(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 82 | + """注册Web连接""" | ||
| 83 | + username = data.get('username') | ||
| 84 | + if not username: | ||
| 85 | + await websocket.send_str(json.dumps({ | ||
| 86 | + "type": "wsa_error", | ||
| 87 | + "message": "用户名不能为空" | ||
| 88 | + })) | ||
| 89 | + return | ||
| 90 | + | ||
| 91 | + session = self.manager.get_session(websocket) | ||
| 92 | + if not session: | ||
| 93 | + await websocket.send_str(json.dumps({ | ||
| 94 | + "type": "wsa_error", | ||
| 95 | + "message": "会话未找到" | ||
| 96 | + })) | ||
| 97 | + return | ||
| 98 | + | ||
| 99 | + with self._connection_lock: | ||
| 100 | + if username not in self._web_connections: | ||
| 101 | + self._web_connections[username] = set() | ||
| 102 | + self._web_connections[username].add(session) | ||
| 103 | + | ||
| 104 | + await websocket.send_str(json.dumps({ | ||
| 105 | + "type": "wsa_registered", | ||
| 106 | + "connection_type": "web", | ||
| 107 | + "username": username | ||
| 108 | + })) | ||
| 109 | + | ||
| 110 | + async def _handle_register_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 111 | + """注册Human连接""" | ||
| 112 | + username = data.get('username') | ||
| 113 | + if not username: | ||
| 114 | + await websocket.send_str(json.dumps({ | ||
| 115 | + "type": "wsa_error", | ||
| 116 | + "message": "用户名不能为空" | ||
| 117 | + })) | ||
| 118 | + return | ||
| 119 | + | ||
| 120 | + session = self.manager.get_session(websocket) | ||
| 121 | + if not session: | ||
| 122 | + await websocket.send_str(json.dumps({ | ||
| 123 | + "type": "wsa_error", | ||
| 124 | + "message": "会话未找到" | ||
| 125 | + })) | ||
| 126 | + return | ||
| 127 | + | ||
| 128 | + with self._connection_lock: | ||
| 129 | + if username not in self._human_connections: | ||
| 130 | + self._human_connections[username] = set() | ||
| 131 | + self._human_connections[username].add(session) | ||
| 132 | + | ||
| 133 | + await websocket.send_str(json.dumps({ | ||
| 134 | + "type": "wsa_registered", | ||
| 135 | + "connection_type": "human", | ||
| 136 | + "username": username | ||
| 137 | + })) | ||
| 138 | + | ||
| 139 | + async def _handle_unregister(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 140 | + """注销连接""" | ||
| 141 | + username = data.get('username') | ||
| 142 | + connection_type = data.get('connection_type', 'both') | ||
| 143 | + | ||
| 144 | + session = self.manager.get_session(websocket) | ||
| 145 | + if not session: | ||
| 146 | + return | ||
| 147 | + | ||
| 148 | + with self._connection_lock: | ||
| 149 | + if connection_type in ['web', 'both'] and username in self._web_connections: | ||
| 150 | + self._web_connections[username].discard(session) | ||
| 151 | + if not self._web_connections[username]: | ||
| 152 | + del self._web_connections[username] | ||
| 153 | + | ||
| 154 | + if connection_type in ['human', 'both'] and username in self._human_connections: | ||
| 155 | + self._human_connections[username].discard(session) | ||
| 156 | + if not self._human_connections[username]: | ||
| 157 | + del self._human_connections[username] | ||
| 158 | + | ||
| 159 | + await websocket.send_str(json.dumps({ | ||
| 160 | + "type": "wsa_unregistered", | ||
| 161 | + "username": username, | ||
| 162 | + "connection_type": connection_type | ||
| 163 | + })) | ||
| 164 | + | ||
| 165 | + async def _handle_get_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]): | ||
| 166 | + """获取连接状态""" | ||
| 167 | + with self._connection_lock: | ||
| 168 | + web_users = list(self._web_connections.keys()) | ||
| 169 | + human_users = list(self._human_connections.keys()) | ||
| 170 | + | ||
| 171 | + await websocket.send_str(json.dumps({ | ||
| 172 | + "type": "wsa_status", | ||
| 173 | + "data": { | ||
| 174 | + "web_connections": len(self._web_connections), | ||
| 175 | + "human_connections": len(self._human_connections), | ||
| 176 | + "web_users": web_users, | ||
| 177 | + "human_users": human_users, | ||
| 178 | + "web_queue_size": self._web_command_queue.qsize(), | ||
| 179 | + "human_queue_size": self._human_command_queue.qsize() | ||
| 180 | + } | ||
| 181 | + })) | ||
| 182 | + | ||
| 183 | + async def _process_command_queues(self): | ||
| 184 | + """处理命令队列""" | ||
| 185 | + while True: | ||
| 186 | + try: | ||
| 187 | + # 处理Web命令队列 | ||
| 188 | + await self._process_web_commands() | ||
| 189 | + | ||
| 190 | + # 处理Human命令队列 | ||
| 191 | + await self._process_human_commands() | ||
| 192 | + | ||
| 193 | + # 短暂休眠避免CPU占用过高 | ||
| 194 | + await asyncio.sleep(0.01) | ||
| 195 | + | ||
| 196 | + except asyncio.CancelledError: | ||
| 197 | + break | ||
| 198 | + except Exception as e: | ||
| 199 | + self.logger.error(f"命令队列处理错误: {e}") | ||
| 200 | + await asyncio.sleep(0.1) | ||
| 201 | + | ||
| 202 | + async def _process_web_commands(self): | ||
| 203 | + """处理Web命令队列""" | ||
| 204 | + try: | ||
| 205 | + while True: | ||
| 206 | + try: | ||
| 207 | + command = self._web_command_queue.get_nowait() | ||
| 208 | + await self._forward_web_command(command) | ||
| 209 | + except queue.Empty: | ||
| 210 | + break | ||
| 211 | + except Exception as e: | ||
| 212 | + self.logger.error(f"Web命令处理错误: {e}") | ||
| 213 | + | ||
| 214 | + async def _process_human_commands(self): | ||
| 215 | + """处理Human命令队列""" | ||
| 216 | + try: | ||
| 217 | + while True: | ||
| 218 | + try: | ||
| 219 | + command = self._human_command_queue.get_nowait() | ||
| 220 | + await self._forward_human_command(command) | ||
| 221 | + except queue.Empty: | ||
| 222 | + break | ||
| 223 | + except Exception as e: | ||
| 224 | + self.logger.error(f"Human命令处理错误: {e}") | ||
| 225 | + | ||
| 226 | + async def _forward_web_command(self, command: Dict[str, Any]): | ||
| 227 | + """转发Web命令""" | ||
| 228 | + username = command.get('Username') | ||
| 229 | + if not username: | ||
| 230 | + return | ||
| 231 | + | ||
| 232 | + with self._connection_lock: | ||
| 233 | + sessions = self._web_connections.get(username, set()) | ||
| 234 | + | ||
| 235 | + if sessions: | ||
| 236 | + message = { | ||
| 237 | + "type": "wsa_command", | ||
| 238 | + "source": "web", | ||
| 239 | + "data": command | ||
| 240 | + } | ||
| 241 | + | ||
| 242 | + for session in list(sessions): | ||
| 243 | + try: | ||
| 244 | + await session.send_message(message) | ||
| 245 | + except Exception as e: | ||
| 246 | + self.logger.error(f"发送Web命令失败 [{username}]: {e}") | ||
| 247 | + | ||
| 248 | + async def _forward_human_command(self, command: Dict[str, Any]): | ||
| 249 | + """转发Human命令""" | ||
| 250 | + username = command.get('Username') | ||
| 251 | + if not username: | ||
| 252 | + return | ||
| 253 | + | ||
| 254 | + with self._connection_lock: | ||
| 255 | + sessions = self._human_connections.get(username, set()) | ||
| 256 | + | ||
| 257 | + if sessions: | ||
| 258 | + message = { | ||
| 259 | + "type": "wsa_command", | ||
| 260 | + "source": "human", | ||
| 261 | + "data": command | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + for session in list(sessions): | ||
| 265 | + try: | ||
| 266 | + await session.send_message(message) | ||
| 267 | + except Exception as e: | ||
| 268 | + self.logger.error(f"发送Human命令失败 [{username}]: {e}") | ||
| 269 | + | ||
| 270 | + # 兼容性接口 | ||
| 271 | + def is_connected(self, username: str) -> bool: | ||
| 272 | + """检查Web用户是否已连接""" | ||
| 273 | + with self._connection_lock: | ||
| 274 | + return username in self._web_connections and bool(self._web_connections[username]) | ||
| 275 | + | ||
| 276 | + def is_connected_human(self, username: str) -> bool: | ||
| 277 | + """检查Human用户是否已连接""" | ||
| 278 | + with self._connection_lock: | ||
| 279 | + return username in self._human_connections and bool(self._human_connections[username]) | ||
| 280 | + | ||
| 281 | + def add_connection(self, username: str, connection: Any): | ||
| 282 | + """添加连接(兼容性接口,已废弃)""" | ||
| 283 | + self.logger.warning("add_connection方法已废弃,请使用消息注册机制") | ||
| 284 | + | ||
| 285 | + def remove_connection(self, username: str): | ||
| 286 | + """移除连接(兼容性接口,已废弃)""" | ||
| 287 | + self.logger.warning("remove_connection方法已废弃,连接会自动清理") | ||
| 288 | + | ||
| 289 | + def add_cmd(self, command: Dict[str, Any], target: str = "web"): | ||
| 290 | + """添加命令到队列""" | ||
| 291 | + try: | ||
| 292 | + if target == "web": | ||
| 293 | + self._web_command_queue.put(command, timeout=1.0) | ||
| 294 | + elif target == "human": | ||
| 295 | + self._human_command_queue.put(command, timeout=1.0) | ||
| 296 | + else: | ||
| 297 | + self.logger.warning(f"未知的目标类型: {target}") | ||
| 298 | + except queue.Full: | ||
| 299 | + self.logger.warning(f"命令队列已满,丢弃命令: {command}") | ||
| 300 | + | ||
| 301 | + async def send_direct_message(self, message: Dict[str, Any], target: str = "web"): | ||
| 302 | + """直接发送消息(不封装为wsa_command)""" | ||
| 303 | + username = message.get('Username') | ||
| 304 | + if not username: | ||
| 305 | + self.logger.warning("消息缺少Username字段") | ||
| 306 | + return | ||
| 307 | + | ||
| 308 | + with self._connection_lock: | ||
| 309 | + if target == "web": | ||
| 310 | + sessions = self._web_connections.get(username, set()) | ||
| 311 | + elif target == "human": | ||
| 312 | + sessions = self._human_connections.get(username, set()) | ||
| 313 | + else: | ||
| 314 | + self.logger.warning(f"未知的目标类型: {target}") | ||
| 315 | + return | ||
| 316 | + | ||
| 317 | + if sessions: | ||
| 318 | + for session in list(sessions): | ||
| 319 | + try: | ||
| 320 | + await session.send_message(message) | ||
| 321 | + except Exception as e: | ||
| 322 | + self.logger.error(f"直接发送消息失败 [{username}]: {e}") | ||
| 323 | + else: | ||
| 324 | + self.logger.debug(f"用户 {username} 未连接,无法发送直接消息") | ||
| 325 | + | ||
| 326 | + def get_cmd(self, timeout: float = 1.0, target: str = "web") -> Optional[Dict[str, Any]]: | ||
| 327 | + """从队列获取命令""" | ||
| 328 | + try: | ||
| 329 | + if target == "web": | ||
| 330 | + return self._web_command_queue.get(timeout=timeout) | ||
| 331 | + elif target == "human": | ||
| 332 | + return self._human_command_queue.get(timeout=timeout) | ||
| 333 | + else: | ||
| 334 | + self.logger.warning(f"未知的目标类型: {target}") | ||
| 335 | + return None | ||
| 336 | + except queue.Empty: | ||
| 337 | + return None | ||
| 338 | + | ||
| 339 | + def get_connection_count(self, target: str = "web") -> int: | ||
| 340 | + """获取连接数量""" | ||
| 341 | + with self._connection_lock: | ||
| 342 | + if target == "web": | ||
| 343 | + return len(self._web_connections) | ||
| 344 | + elif target == "human": | ||
| 345 | + return len(self._human_connections) | ||
| 346 | + else: | ||
| 347 | + return len(self._web_connections) + len(self._human_connections) | ||
| 348 | + | ||
| 349 | + def get_usernames(self, target: str = "web") -> list: | ||
| 350 | + """获取用户名列表""" | ||
| 351 | + with self._connection_lock: | ||
| 352 | + if target == "web": | ||
| 353 | + return list(self._web_connections.keys()) | ||
| 354 | + elif target == "human": | ||
| 355 | + return list(self._human_connections.keys()) | ||
| 356 | + else: | ||
| 357 | + return list(set(list(self._web_connections.keys()) + list(self._human_connections.keys()))) | ||
| 358 | + | ||
| 359 | + | ||
| 360 | +# 兼容性包装器 | ||
| 361 | +class WSAWebSocketManager: | ||
| 362 | + """WSA WebSocket管理器兼容性包装器""" | ||
| 363 | + | ||
| 364 | + def __init__(self, service: WSAWebSocketService): | ||
| 365 | + self.service = service | ||
| 366 | + | ||
| 367 | + def is_connected(self, username: str) -> bool: | ||
| 368 | + return self.service.is_connected(username) | ||
| 369 | + | ||
| 370 | + def is_connected_human(self, username: str) -> bool: | ||
| 371 | + return self.service.is_connected_human(username) | ||
| 372 | + | ||
| 373 | + def add_connection(self, username: str, connection: Any): | ||
| 374 | + self.service.add_connection(username, connection) | ||
| 375 | + | ||
| 376 | + def remove_connection(self, username: str): | ||
| 377 | + self.service.remove_connection(username) | ||
| 378 | + | ||
| 379 | + def add_cmd(self, command: Dict[str, Any]): | ||
| 380 | + self.service.add_cmd(command, "web") | ||
| 381 | + | ||
| 382 | + async def send_direct_message(self, message: Dict[str, Any]): | ||
| 383 | + """直接发送消息(不封装为wsa_command)""" | ||
| 384 | + await self.service.send_direct_message(message, "web") | ||
| 385 | + | ||
| 386 | + def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]: | ||
| 387 | + return self.service.get_cmd(timeout, "web") | ||
| 388 | + | ||
| 389 | + def get_connection_count(self) -> int: | ||
| 390 | + return self.service.get_connection_count("web") | ||
| 391 | + | ||
| 392 | + def get_usernames(self) -> list: | ||
| 393 | + return self.service.get_usernames("web") | ||
| 394 | + | ||
| 395 | + | ||
| 396 | +# 全局实例(兼容性) | ||
| 397 | +_wsa_service: Optional[WSAWebSocketService] = None | ||
| 398 | +_web_instance: Optional[WSAWebSocketManager] = None | ||
| 399 | +_human_instance: Optional[WSAWebSocketManager] = None | ||
| 400 | + | ||
| 401 | + | ||
| 402 | +def initialize_wsa_service(service: WSAWebSocketService): | ||
| 403 | + """初始化WSA服务""" | ||
| 404 | + global _wsa_service, _web_instance, _human_instance | ||
| 405 | + _wsa_service = service | ||
| 406 | + _web_instance = WSAWebSocketManager(service) | ||
| 407 | + _human_instance = WSAWebSocketManager(service) | ||
| 408 | + | ||
| 409 | + | ||
| 410 | +def get_web_instance() -> WSAWebSocketManager: | ||
| 411 | + """获取Web WebSocket管理器实例""" | ||
| 412 | + if _web_instance is None: | ||
| 413 | + raise RuntimeError("WSA服务未初始化") | ||
| 414 | + return _web_instance | ||
| 415 | + | ||
| 416 | + | ||
| 417 | +def get_instance() -> WSAWebSocketManager: | ||
| 418 | + """获取Human WebSocket管理器实例""" | ||
| 419 | + if _human_instance is None: | ||
| 420 | + raise RuntimeError("WSA服务未初始化") | ||
| 421 | + return _human_instance |
| @@ -29,7 +29,7 @@ def _play_frame(stream, exit_event, queue, chunk): | @@ -29,7 +29,7 @@ def _play_frame(stream, exit_event, queue, chunk): | ||
| 29 | print(f'[INFO] play frame thread ends') | 29 | print(f'[INFO] play frame thread ends') |
| 30 | break | 30 | break |
| 31 | frame = queue.get() | 31 | frame = queue.get() |
| 32 | - frame = (frame * 32767).astype(np.int16).tobytes() | 32 | + frame = bytes((frame * 32767).astype(np.int16).tobytes()) # Fix BufferError: memoryview has 1 exported buffer |
| 33 | stream.write(frame, chunk) | 33 | stream.write(frame, chunk) |
| 34 | 34 | ||
| 35 | class ASR: | 35 | class ASR: |
| @@ -71,18 +71,45 @@ class FunASRClient(BaseASR): | @@ -71,18 +71,45 @@ class FunASRClient(BaseASR): | ||
| 71 | async def _connect_websocket(self): | 71 | async def _connect_websocket(self): |
| 72 | """连接WebSocket服务器""" | 72 | """连接WebSocket服务器""" |
| 73 | try: | 73 | try: |
| 74 | - self.websocket = await websockets.connect( | ||
| 75 | - self.server_url, | ||
| 76 | - timeout=getattr(cfg, 'asr_timeout', 30) | 74 | + # 修复: websockets新版本不支持timeout参数,使用asyncio.wait_for包装 |
| 75 | + timeout_seconds = getattr(cfg, 'asr_timeout', 30) | ||
| 76 | + self.websocket = await asyncio.wait_for( | ||
| 77 | + websockets.connect(self.server_url), | ||
| 78 | + timeout=timeout_seconds | ||
| 77 | ) | 79 | ) |
| 78 | self.connected = True | 80 | self.connected = True |
| 79 | util.log(1, f"FunASR WebSocket连接成功: {self.server_url}") | 81 | util.log(1, f"FunASR WebSocket连接成功: {self.server_url}") |
| 82 | + | ||
| 83 | + # 发送初始化配置消息(参考funasr_client_api.py) | ||
| 84 | + await self._send_init_message() | ||
| 85 | + | ||
| 80 | return True | 86 | return True |
| 81 | except Exception as e: | 87 | except Exception as e: |
| 82 | util.log(3, f"FunASR WebSocket连接失败: {e}") | 88 | util.log(3, f"FunASR WebSocket连接失败: {e}") |
| 83 | self.connected = False | 89 | self.connected = False |
| 84 | return False | 90 | return False |
| 85 | 91 | ||
| 92 | + async def _send_init_message(self): | ||
| 93 | + """发送FunASR初始化配置消息""" | ||
| 94 | + try: | ||
| 95 | + # 根据参考项目funasr_client_api.py的格式 | ||
| 96 | + init_message = { | ||
| 97 | + "mode": "2pass", | ||
| 98 | + "chunk_size": [0, 10, 5], # [vad_need, chunk_size, chunk_interval] | ||
| 99 | + "encoder_chunk_look_back": 4, | ||
| 100 | + "decoder_chunk_look_back": 1, | ||
| 101 | + "chunk_interval": 10, | ||
| 102 | + "wav_name": self.username, | ||
| 103 | + "is_speaking": True | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + await self.websocket.send(json.dumps(init_message)) | ||
| 107 | + util.log(1, f"发送FunASR初始化消息: {init_message}") | ||
| 108 | + | ||
| 109 | + except Exception as e: | ||
| 110 | + util.log(3, f"发送初始化消息失败: {e}") | ||
| 111 | + raise e | ||
| 112 | + | ||
| 86 | async def _disconnect_websocket(self): | 113 | async def _disconnect_websocket(self): |
| 87 | """断开WebSocket连接""" | 114 | """断开WebSocket连接""" |
| 88 | if self.websocket: | 115 | if self.websocket: |
| @@ -141,11 +168,12 @@ class FunASRClient(BaseASR): | @@ -141,11 +168,12 @@ class FunASRClient(BaseASR): | ||
| 141 | message = self.message_queue.get_nowait() | 168 | message = self.message_queue.get_nowait() |
| 142 | 169 | ||
| 143 | if isinstance(message, dict): | 170 | if isinstance(message, dict): |
| 144 | - # JSON消息 | 171 | + # JSON消息(配置消息或结束信号) |
| 145 | await self.websocket.send(json.dumps(message)) | 172 | await self.websocket.send(json.dumps(message)) |
| 146 | util.log(1, f"发送JSON消息: {message}") | 173 | util.log(1, f"发送JSON消息: {message}") |
| 147 | elif isinstance(message, bytes): | 174 | elif isinstance(message, bytes): |
| 148 | - # 二进制音频数据 | 175 | + # 二进制音频数据(参考funasr_client_api.py的feed_chunk方法) |
| 176 | + # 确保音频数据以二进制格式发送 | ||
| 149 | await self.websocket.send(message) | 177 | await self.websocket.send(message) |
| 150 | util.log(1, f"发送音频数据: {len(message)} bytes") | 178 | util.log(1, f"发送音频数据: {len(message)} bytes") |
| 151 | else: | 179 | else: |
| @@ -203,23 +231,24 @@ class FunASRClient(BaseASR): | @@ -203,23 +231,24 @@ class FunASRClient(BaseASR): | ||
| 203 | text: 识别文本 | 231 | text: 识别文本 |
| 204 | """ | 232 | """ |
| 205 | try: | 233 | try: |
| 206 | - from core import wsa_server | 234 | + from core import get_web_instance, get_instance |
| 207 | 235 | ||
| 208 | # 发送到Web客户端 | 236 | # 发送到Web客户端 |
| 209 | - if wsa_server.get_web_instance().is_connected(self.username): | ||
| 210 | - wsa_server.get_web_instance().add_cmd({ | ||
| 211 | - "panelMsg": text, | ||
| 212 | - "Username": self.username | ||
| 213 | - }) | ||
| 214 | - | ||
| 215 | - # 发送到Human客户端 | ||
| 216 | - if wsa_server.get_instance().is_connected_human(self.username): | ||
| 217 | - content = { | ||
| 218 | - 'Topic': 'human', | ||
| 219 | - 'Data': {'Key': 'log', 'Value': text}, | ||
| 220 | - 'Username': self.username | 237 | + if get_web_instance().is_connected(self.username): |
| 238 | + import asyncio | ||
| 239 | + # 创建chat_message直接推送 | ||
| 240 | + chat_message = { | ||
| 241 | + "type": "chat_message", | ||
| 242 | + "sender": "回音", | ||
| 243 | + "content": text, # 修复字段名:panelMsg -> content | ||
| 244 | + "Username": self.username, | ||
| 245 | + "model_info": "FunASR" | ||
| 221 | } | 246 | } |
| 222 | - wsa_server.get_instance().add_cmd(content) | 247 | + # 使用直接发送方法,避免wsa_command封装 |
| 248 | + asyncio.create_task(get_web_instance().send_direct_message(chat_message)) | ||
| 249 | + | ||
| 250 | + # Human客户端通知改为日志记录(避免重复通知当前服务) | ||
| 251 | + util.log(1, f"FunASR识别结果[{self.username}]: {text}") | ||
| 223 | 252 | ||
| 224 | except Exception as e: | 253 | except Exception as e: |
| 225 | util.log(2, f"发送到Web客户端失败: {e}") | 254 | util.log(2, f"发送到Web客户端失败: {e}") |
| @@ -333,6 +362,19 @@ class FunASRClient(BaseASR): | @@ -333,6 +362,19 @@ class FunASRClient(BaseASR): | ||
| 333 | self.message_queue.put(audio_data) | 362 | self.message_queue.put(audio_data) |
| 334 | return True | 363 | return True |
| 335 | 364 | ||
| 365 | + def send_end_signal(self): | ||
| 366 | + """发送结束信号""" | ||
| 367 | + if not self.connected: | ||
| 368 | + return | ||
| 369 | + | ||
| 370 | + try: | ||
| 371 | + # 发送结束消息(参考funasr_client_api.py的close方法) | ||
| 372 | + end_message = {"is_speaking": False} | ||
| 373 | + self.message_queue.put(end_message) | ||
| 374 | + util.log(1, "发送FunASR结束信号") | ||
| 375 | + except Exception as e: | ||
| 376 | + util.log(3, f"发送结束信号失败: {e}") | ||
| 377 | + | ||
| 336 | def start_recognition(self): | 378 | def start_recognition(self): |
| 337 | """开始语音识别""" | 379 | """开始语音识别""" |
| 338 | if not self.connected: | 380 | if not self.connected: |
| @@ -418,6 +460,99 @@ class FunASRClient(BaseASR): | @@ -418,6 +460,99 @@ class FunASRClient(BaseASR): | ||
| 418 | # 简化实现,返回空特征 | 460 | # 简化实现,返回空特征 |
| 419 | return np.zeros((1, 50), dtype=np.float32) | 461 | return np.zeros((1, 50), dtype=np.float32) |
| 420 | 462 | ||
| 463 | + async def connect(self): | ||
| 464 | + """异步连接到FunASR服务器""" | ||
| 465 | + if self.connected: | ||
| 466 | + util.log(1, "FunASR客户端已连接") | ||
| 467 | + return True | ||
| 468 | + | ||
| 469 | + try: | ||
| 470 | + success = await self._connect_websocket() | ||
| 471 | + if success: | ||
| 472 | + # 启动消息处理任务 | ||
| 473 | + self.receive_task = asyncio.create_task(self._receive_messages()) | ||
| 474 | + self.send_task = asyncio.create_task(self._send_message_loop()) | ||
| 475 | + util.log(1, "FunASR异步连接建立成功") | ||
| 476 | + return success | ||
| 477 | + except Exception as e: | ||
| 478 | + util.log(3, f"FunASR异步连接失败: {e}") | ||
| 479 | + return False | ||
| 480 | + | ||
| 481 | + async def disconnect(self): | ||
| 482 | + """异步断开连接""" | ||
| 483 | + try: | ||
| 484 | + # 取消任务 | ||
| 485 | + if hasattr(self, 'receive_task'): | ||
| 486 | + self.receive_task.cancel() | ||
| 487 | + if hasattr(self, 'send_task'): | ||
| 488 | + self.send_task.cancel() | ||
| 489 | + | ||
| 490 | + # 断开WebSocket连接 | ||
| 491 | + await self._disconnect_websocket() | ||
| 492 | + util.log(1, "FunASR异步连接已断开") | ||
| 493 | + except Exception as e: | ||
| 494 | + util.log(2, f"断开FunASR连接时出错: {e}") | ||
| 495 | + | ||
| 496 | + async def send_audio_data(self, audio_data): | ||
| 497 | + """异步发送音频数据""" | ||
| 498 | + try: | ||
| 499 | + if isinstance(audio_data, str): | ||
| 500 | + # Base64编码的音频数据,需要解码 | ||
| 501 | + import base64 | ||
| 502 | + audio_bytes = base64.b64decode(audio_data) | ||
| 503 | + util.log(1, f"解码Base64音频数据: {len(audio_bytes)} bytes") | ||
| 504 | + elif isinstance(audio_data, bytes): | ||
| 505 | + audio_bytes = audio_data | ||
| 506 | + util.log(1, f"接收字节音频数据: {len(audio_bytes)} bytes") | ||
| 507 | + elif isinstance(audio_data, np.ndarray): | ||
| 508 | + # NumPy数组转换为字节 | ||
| 509 | + if audio_data.dtype != np.int16: | ||
| 510 | + audio_data = audio_data.astype(np.int16) | ||
| 511 | + audio_bytes = bytes(audio_data.tobytes()) # Fix BufferError: memoryview has 1 exported buffer | ||
| 512 | + util.log(1, f"转换NumPy数组为字节: {len(audio_bytes)} bytes") | ||
| 513 | + else: | ||
| 514 | + util.log(3, f"不支持的音频数据类型: {type(audio_data)},尝试转换为字节") | ||
| 515 | + # 尝试强制转换 | ||
| 516 | + try: | ||
| 517 | + audio_bytes = bytes(audio_data) | ||
| 518 | + except Exception as convert_error: | ||
| 519 | + util.log(3, f"音频数据类型转换失败: {convert_error}") | ||
| 520 | + return False | ||
| 521 | + | ||
| 522 | + # 验证音频数据有效性 | ||
| 523 | + if len(audio_bytes) == 0: | ||
| 524 | + util.log(2, "音频数据为空,跳过发送") | ||
| 525 | + return False | ||
| 526 | + | ||
| 527 | + # 确保音频数据长度为偶数(16位采样) | ||
| 528 | + if len(audio_bytes) % 2 != 0: | ||
| 529 | + audio_bytes = audio_bytes[:-1] # 去掉最后一个字节 | ||
| 530 | + util.log(2, f"调整音频数据长度为偶数: {len(audio_bytes)} bytes") | ||
| 531 | + | ||
| 532 | + # 参考funasr_client_api.py,音频数据需要按chunk发送 | ||
| 533 | + # 计算stride(参考项目中的计算方式) | ||
| 534 | + chunk_interval = 10 # ms | ||
| 535 | + chunk_size = 10 # ms | ||
| 536 | + stride = int(60 * chunk_size / chunk_interval / 1000 * 16000 * 2) | ||
| 537 | + | ||
| 538 | + # 如果音频数据较大,分块发送 | ||
| 539 | + if len(audio_bytes) > stride: | ||
| 540 | + chunk_num = (len(audio_bytes) - 1) // stride + 1 | ||
| 541 | + for i in range(chunk_num): | ||
| 542 | + beg = i * stride | ||
| 543 | + chunk_data = audio_bytes[beg:beg + stride] | ||
| 544 | + self.message_queue.put(chunk_data) | ||
| 545 | + util.log(1, f"发送音频块 {i+1}/{chunk_num}: {len(chunk_data)} bytes") | ||
| 546 | + else: | ||
| 547 | + # 小数据直接发送 | ||
| 548 | + self.message_queue.put(audio_bytes) | ||
| 549 | + util.log(1, f"发送音频数据: {len(audio_bytes)} bytes") | ||
| 550 | + | ||
| 551 | + return True | ||
| 552 | + except Exception as e: | ||
| 553 | + util.log(3, f"发送音频数据失败: {e}") | ||
| 554 | + return False | ||
| 555 | + | ||
| 421 | def __del__(self): | 556 | def __del__(self): |
| 422 | """析构函数""" | 557 | """析构函数""" |
| 423 | self.stop() | 558 | self.stop() |
| @@ -293,7 +293,10 @@ class LightReal(BaseReal): | @@ -293,7 +293,10 @@ class LightReal(BaseReal): | ||
| 293 | frame,type_,eventpoint = audio_frame | 293 | frame,type_,eventpoint = audio_frame |
| 294 | frame = (frame * 32767).astype(np.int16) | 294 | frame = (frame * 32767).astype(np.int16) |
| 295 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) | 295 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) |
| 296 | - new_frame.planes[0].update(frame.tobytes()) | 296 | + # 修复 BufferError: memoryview has 1 exported buffer |
| 297 | + # 创建数据副本避免内存视图冲突 | ||
| 298 | + frame_bytes = bytes(frame.tobytes()) | ||
| 299 | + new_frame.planes[0].update(frame_bytes) | ||
| 297 | new_frame.sample_rate=16000 | 300 | new_frame.sample_rate=16000 |
| 298 | # if audio_track._queue.qsize()>10: | 301 | # if audio_track._queue.qsize()>10: |
| 299 | # time.sleep(0.1) | 302 | # time.sleep(0.1) |
| @@ -263,7 +263,41 @@ class LipReal(BaseReal): | @@ -263,7 +263,41 @@ class LipReal(BaseReal): | ||
| 263 | #print('blending time:',time.perf_counter()-t) | 263 | #print('blending time:',time.perf_counter()-t) |
| 264 | 264 | ||
| 265 | image = combine_frame #(outputs['image'] * 255).astype(np.uint8) | 265 | image = combine_frame #(outputs['image'] * 255).astype(np.uint8) |
| 266 | + | ||
| 267 | + # Fix MemoryError: 优化内存使用和错误处理 | ||
| 268 | + try: | ||
| 269 | + # 检查图像尺寸,如果过大则压缩 | ||
| 270 | + h, w = image.shape[:2] | ||
| 271 | + max_dimension = 1920 # 最大尺寸限制 | ||
| 272 | + if h > max_dimension or w > max_dimension: | ||
| 273 | + scale = max_dimension / max(h, w) | ||
| 274 | + new_h, new_w = int(h * scale), int(w * scale) | ||
| 275 | + image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA) | ||
| 276 | + logger.warning(f"Image resized from {w}x{h} to {new_w}x{new_h} to prevent MemoryError") | ||
| 277 | + | ||
| 278 | + # 确保数据类型正确并创建连续内存布局 | ||
| 279 | + if not image.flags['C_CONTIGUOUS']: | ||
| 280 | + image = np.ascontiguousarray(image) | ||
| 281 | + image = image.astype(np.uint8) | ||
| 282 | + | ||
| 266 | new_frame = VideoFrame.from_ndarray(image, format="bgr24") | 283 | new_frame = VideoFrame.from_ndarray(image, format="bgr24") |
| 284 | + except MemoryError as e: | ||
| 285 | + logger.error(f"MemoryError in VideoFrame creation: {e}, image shape: {image.shape}") | ||
| 286 | + # 进一步压缩图像作为备用方案 | ||
| 287 | + try: | ||
| 288 | + h, w = image.shape[:2] | ||
| 289 | + backup_scale = 0.5 | ||
| 290 | + backup_h, backup_w = int(h * backup_scale), int(w * backup_scale) | ||
| 291 | + image = cv2.resize(image, (backup_w, backup_h), interpolation=cv2.INTER_AREA) | ||
| 292 | + image = np.ascontiguousarray(image.astype(np.uint8)) | ||
| 293 | + new_frame = VideoFrame.from_ndarray(image, format="bgr24") | ||
| 294 | + logger.info(f"Backup resize successful: {backup_w}x{backup_h}") | ||
| 295 | + except Exception as backup_e: | ||
| 296 | + logger.error(f"Backup resize failed: {backup_e}") | ||
| 297 | + continue | ||
| 298 | + except Exception as e: | ||
| 299 | + logger.error(f"Unexpected error in VideoFrame creation: {e}") | ||
| 300 | + continue | ||
| 267 | asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop) | 301 | asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop) |
| 268 | self.record_video_data(image) | 302 | self.record_video_data(image) |
| 269 | 303 | ||
| @@ -271,11 +305,19 @@ class LipReal(BaseReal): | @@ -271,11 +305,19 @@ class LipReal(BaseReal): | ||
| 271 | frame,type,eventpoint = audio_frame | 305 | frame,type,eventpoint = audio_frame |
| 272 | frame = (frame * 32767).astype(np.int16) | 306 | frame = (frame * 32767).astype(np.int16) |
| 273 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) | 307 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) |
| 274 | - new_frame.planes[0].update(frame.tobytes()) | ||
| 275 | - new_frame.sample_rate=16000 | ||
| 276 | - # if audio_track._queue.qsize()>10: | ||
| 277 | - # time.sleep(0.1) | ||
| 278 | - asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) | 308 | + |
| 309 | + # 修复 BufferError: 强制复制数据避免内存视图冲突 | ||
| 310 | + frame_copy = frame.astype(np.int16).copy() | ||
| 311 | + frame_bytes = bytes(frame_copy.tobytes()) # Fix BufferError: memoryview has 1 exported buffer | ||
| 312 | + new_frame.planes[0].update(frame_bytes) | ||
| 313 | + | ||
| 314 | + new_frame.sample_rate = 16000 | ||
| 315 | + | ||
| 316 | + # 使用线程安全的方式提交到队列,避免闭包问题 | ||
| 317 | + def put_audio_frame(frame_obj, event_point): | ||
| 318 | + audio_track._queue.put_nowait((frame_obj, event_point)) | ||
| 319 | + | ||
| 320 | + loop.call_soon_threadsafe(put_audio_frame, new_frame, eventpoint) | ||
| 279 | self.record_audio_data(frame) | 321 | self.record_audio_data(frame) |
| 280 | #self.notify(eventpoint) | 322 | #self.notify(eventpoint) |
| 281 | logger.info('lipreal process_frames thread stop') | 323 | logger.info('lipreal process_frames thread stop') |
| 1 | import logging | 1 | import logging |
| 2 | +import os | ||
| 3 | + | ||
| 4 | +# 确保日志目录存在 | ||
| 5 | +log_dir = "logs" | ||
| 6 | +if not os.path.exists(log_dir): | ||
| 7 | + os.makedirs(log_dir) | ||
| 2 | 8 | ||
| 3 | # 配置日志器 | 9 | # 配置日志器 |
| 4 | logger = logging.getLogger(__name__) | 10 | logger = logging.getLogger(__name__) |
| @@ -14,3 +20,46 @@ logger.addHandler(fhandler) | @@ -14,3 +20,46 @@ logger.addHandler(fhandler) | ||
| 14 | # sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') | 20 | # sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') |
| 15 | # handler.setFormatter(sformatter) | 21 | # handler.setFormatter(sformatter) |
| 16 | # logger.addHandler(handler) | 22 | # logger.addHandler(handler) |
| 23 | + | ||
| 24 | +def get_logger(name: str, level: str = "INFO") -> logging.Logger: | ||
| 25 | + """获取指定名称的日志器 | ||
| 26 | + | ||
| 27 | + Args: | ||
| 28 | + name: 日志器名称 | ||
| 29 | + level: 日志级别 (DEBUG, INFO, WARNING, ERROR, CRITICAL) | ||
| 30 | + | ||
| 31 | + Returns: | ||
| 32 | + 配置好的日志器实例 | ||
| 33 | + """ | ||
| 34 | + # 创建日志器 | ||
| 35 | + logger_instance = logging.getLogger(name) | ||
| 36 | + | ||
| 37 | + # 避免重复添加处理器 | ||
| 38 | + if logger_instance.handlers: | ||
| 39 | + return logger_instance | ||
| 40 | + | ||
| 41 | + # 设置日志级别 | ||
| 42 | + log_level = getattr(logging, level.upper(), logging.INFO) | ||
| 43 | + logger_instance.setLevel(log_level) | ||
| 44 | + | ||
| 45 | + # 创建格式器 | ||
| 46 | + formatter = logging.Formatter( | ||
| 47 | + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
| 48 | + ) | ||
| 49 | + | ||
| 50 | + # 文件处理器 | ||
| 51 | + log_file = os.path.join(log_dir, f"{name.lower()}.log") | ||
| 52 | + file_handler = logging.FileHandler(log_file, encoding='utf-8') | ||
| 53 | + file_handler.setFormatter(formatter) | ||
| 54 | + file_handler.setLevel(log_level) | ||
| 55 | + logger_instance.addHandler(file_handler) | ||
| 56 | + | ||
| 57 | + # 控制台处理器(可选) | ||
| 58 | + console_handler = logging.StreamHandler() | ||
| 59 | + console_handler.setFormatter(logging.Formatter( | ||
| 60 | + '%(asctime)s - %(levelname)s - %(message)s' | ||
| 61 | + )) | ||
| 62 | + console_handler.setLevel(logging.WARNING) # 控制台只显示警告及以上级别 | ||
| 63 | + logger_instance.addHandler(console_handler) | ||
| 64 | + | ||
| 65 | + return logger_instance |
| @@ -346,7 +346,10 @@ class MuseReal(BaseReal): | @@ -346,7 +346,10 @@ class MuseReal(BaseReal): | ||
| 346 | frame,type,eventpoint = audio_frame | 346 | frame,type,eventpoint = audio_frame |
| 347 | frame = (frame * 32767).astype(np.int16) | 347 | frame = (frame * 32767).astype(np.int16) |
| 348 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) | 348 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) |
| 349 | - new_frame.planes[0].update(frame.tobytes()) | 349 | + # 修复 BufferError: memoryview has 1 exported buffer |
| 350 | + # 创建数据副本避免内存视图冲突 | ||
| 351 | + frame_bytes = bytes(frame.tobytes()) | ||
| 352 | + new_frame.planes[0].update(frame_bytes) | ||
| 350 | new_frame.sample_rate=16000 | 353 | new_frame.sample_rate=16000 |
| 351 | asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) | 354 | asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) |
| 352 | self.record_audio_data(frame) | 355 | self.record_audio_data(frame) |
| @@ -247,7 +247,10 @@ class NeRFReal(BaseReal): | @@ -247,7 +247,10 @@ class NeRFReal(BaseReal): | ||
| 247 | else: #webrtc | 247 | else: #webrtc |
| 248 | frame = (frame * 32767).astype(np.int16) | 248 | frame = (frame * 32767).astype(np.int16) |
| 249 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) | 249 | new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0]) |
| 250 | - new_frame.planes[0].update(frame.tobytes()) | 250 | + # 修复 BufferError: memoryview has 1 exported buffer |
| 251 | + # 创建数据副本避免内存视图冲突 | ||
| 252 | + frame_bytes = bytes(frame.tobytes()) | ||
| 253 | + new_frame.planes[0].update(frame_bytes) | ||
| 251 | new_frame.sample_rate=16000 | 254 | new_frame.sample_rate=16000 |
| 252 | asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) | 255 | asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop) |
| 253 | 256 |
| @@ -18,7 +18,7 @@ face_alignment | @@ -18,7 +18,7 @@ face_alignment | ||
| 18 | python_speech_features | 18 | python_speech_features |
| 19 | numba | 19 | numba |
| 20 | resampy | 20 | resampy |
| 21 | -#pyaudio | 21 | +pyaudio |
| 22 | soundfile==0.12.1 | 22 | soundfile==0.12.1 |
| 23 | einops | 23 | einops |
| 24 | configargparse | 24 | configargparse |
| @@ -30,6 +30,9 @@ transformers | @@ -30,6 +30,9 @@ transformers | ||
| 30 | edge_tts | 30 | edge_tts |
| 31 | flask | 31 | flask |
| 32 | flask_sockets | 32 | flask_sockets |
| 33 | +flask-socketio | ||
| 34 | +websockets | ||
| 35 | +websocket-client | ||
| 33 | opencv-python-headless | 36 | opencv-python-headless |
| 34 | aiortc | 37 | aiortc |
| 35 | aiohttp_cors | 38 | aiohttp_cors |
| @@ -41,5 +44,6 @@ accelerate | @@ -41,5 +44,6 @@ accelerate | ||
| 41 | 44 | ||
| 42 | librosa | 45 | librosa |
| 43 | openai | 46 | openai |
| 47 | +aiofiles | ||
| 44 | #判断音频类型的支持 | 48 | #判断音频类型的支持 |
| 45 | AudioSegment | 49 | AudioSegment |
test_doubao_integration.py
deleted
100644 → 0
| 1 | -#!/usr/bin/env python3 | ||
| 2 | -# AIfeng/2024-12-19 | ||
| 3 | -# 豆包模型集成测试脚本 | ||
| 4 | - | ||
| 5 | -import os | ||
| 6 | -import sys | ||
| 7 | -import json | ||
| 8 | -from pathlib import Path | ||
| 9 | - | ||
| 10 | -# 添加项目根目录到Python路径 | ||
| 11 | -project_root = Path(__file__).parent | ||
| 12 | -sys.path.insert(0, str(project_root)) | ||
| 13 | - | ||
| 14 | -def test_config_files(): | ||
| 15 | - """测试配置文件是否存在和格式正确""" | ||
| 16 | - print("=== 配置文件测试 ===") | ||
| 17 | - | ||
| 18 | - # 测试LLM配置文件 | ||
| 19 | - llm_config_path = project_root / "config" / "llm_config.json" | ||
| 20 | - if llm_config_path.exists(): | ||
| 21 | - try: | ||
| 22 | - with open(llm_config_path, 'r', encoding='utf-8') as f: | ||
| 23 | - llm_config = json.load(f) | ||
| 24 | - print(f"✓ LLM配置文件加载成功: {llm_config_path}") | ||
| 25 | - print(f" 当前模型类型: {llm_config.get('model_type', 'unknown')}") | ||
| 26 | - except Exception as e: | ||
| 27 | - print(f"✗ LLM配置文件格式错误: {e}") | ||
| 28 | - else: | ||
| 29 | - print(f"✗ LLM配置文件不存在: {llm_config_path}") | ||
| 30 | - | ||
| 31 | - # 测试豆包配置文件 | ||
| 32 | - doubao_config_path = project_root / "config" / "doubao_config.json" | ||
| 33 | - if doubao_config_path.exists(): | ||
| 34 | - try: | ||
| 35 | - with open(doubao_config_path, 'r', encoding='utf-8') as f: | ||
| 36 | - doubao_config = json.load(f) | ||
| 37 | - print(f"✓ 豆包配置文件加载成功: {doubao_config_path}") | ||
| 38 | - print(f" 模型名称: {doubao_config.get('model', 'unknown')}") | ||
| 39 | - print(f" 人物设定: {doubao_config.get('character', {}).get('name', 'unknown')}") | ||
| 40 | - except Exception as e: | ||
| 41 | - print(f"✗ 豆包配置文件格式错误: {e}") | ||
| 42 | - else: | ||
| 43 | - print(f"✗ 豆包配置文件不存在: {doubao_config_path}") | ||
| 44 | - | ||
| 45 | -def test_module_import(): | ||
| 46 | - """测试模块导入""" | ||
| 47 | - print("\n=== 模块导入测试 ===") | ||
| 48 | - | ||
| 49 | - try: | ||
| 50 | - from llm.Doubao import Doubao | ||
| 51 | - print("✓ 豆包模块导入成功") | ||
| 52 | - except ImportError as e: | ||
| 53 | - print(f"✗ 豆包模块导入失败: {e}") | ||
| 54 | - return False | ||
| 55 | - | ||
| 56 | - try: | ||
| 57 | - import llm | ||
| 58 | - print(f"✓ LLM包导入成功,可用模型: {llm.AVAILABLE_MODELS}") | ||
| 59 | - except ImportError as e: | ||
| 60 | - print(f"✗ LLM包导入失败: {e}") | ||
| 61 | - | ||
| 62 | - return True | ||
| 63 | - | ||
| 64 | -def test_llm_config_loading(): | ||
| 65 | - """测试LLM配置加载函数""" | ||
| 66 | - print("\n=== LLM配置加载测试 ===") | ||
| 67 | - | ||
| 68 | - try: | ||
| 69 | - # 模拟llm.py中的配置加载函数 | ||
| 70 | - config_path = project_root / "config" / "llm_config.json" | ||
| 71 | - if config_path.exists(): | ||
| 72 | - with open(config_path, 'r', encoding='utf-8') as f: | ||
| 73 | - config = json.load(f) | ||
| 74 | - print(f"✓ 配置加载成功") | ||
| 75 | - print(f" 模型类型: {config.get('model_type')}") | ||
| 76 | - print(f" 配置项: {list(config.keys())}") | ||
| 77 | - return config | ||
| 78 | - else: | ||
| 79 | - print("✗ 配置文件不存在,使用默认配置") | ||
| 80 | - return {"model_type": "qwen"} | ||
| 81 | - except Exception as e: | ||
| 82 | - print(f"✗ 配置加载失败: {e}") | ||
| 83 | - return {"model_type": "qwen"} | ||
| 84 | - | ||
| 85 | -def test_doubao_instantiation(): | ||
| 86 | - """测试豆包模型实例化(不需要真实API密钥)""" | ||
| 87 | - print("\n=== 豆包实例化测试 ===") | ||
| 88 | - | ||
| 89 | - try: | ||
| 90 | - from llm.Doubao import Doubao | ||
| 91 | - | ||
| 92 | - # 设置测试API密钥 | ||
| 93 | - os.environ['DOUBAO_API_KEY'] = 'test_key_for_validation' | ||
| 94 | - | ||
| 95 | - doubao = Doubao() | ||
| 96 | - print("✓ 豆包实例化成功") | ||
| 97 | - print(f" 配置文件路径: {doubao.config_file}") | ||
| 98 | - print(f" API基础URL: {doubao.base_url}") | ||
| 99 | - print(f" 模型名称: {doubao.model}") | ||
| 100 | - | ||
| 101 | - # 清理测试环境变量 | ||
| 102 | - if 'DOUBAO_API_KEY' in os.environ: | ||
| 103 | - del os.environ['DOUBAO_API_KEY'] | ||
| 104 | - | ||
| 105 | - return True | ||
| 106 | - except Exception as e: | ||
| 107 | - print(f"✗ 豆包实例化失败: {e}") | ||
| 108 | - return False | ||
| 109 | - | ||
| 110 | -def test_integration_flow(): | ||
| 111 | - """测试完整集成流程""" | ||
| 112 | - print("\n=== 集成流程测试 ===") | ||
| 113 | - | ||
| 114 | - try: | ||
| 115 | - # 模拟llm.py中的流程 | ||
| 116 | - config = test_llm_config_loading() | ||
| 117 | - model_type = config.get("model_type", "qwen") | ||
| 118 | - | ||
| 119 | - print(f"根据配置选择模型: {model_type}") | ||
| 120 | - | ||
| 121 | - if model_type == "doubao": | ||
| 122 | - print("✓ 将使用豆包模型处理请求") | ||
| 123 | - elif model_type == "qwen": | ||
| 124 | - print("✓ 将使用通义千问模型处理请求") | ||
| 125 | - else: | ||
| 126 | - print(f"⚠ 未知模型类型: {model_type}") | ||
| 127 | - | ||
| 128 | - return True | ||
| 129 | - except Exception as e: | ||
| 130 | - print(f"✗ 集成流程测试失败: {e}") | ||
| 131 | - return False | ||
| 132 | - | ||
| 133 | -def main(): | ||
| 134 | - """主测试函数""" | ||
| 135 | - print("豆包模型集成测试") | ||
| 136 | - print("=" * 50) | ||
| 137 | - | ||
| 138 | - # 运行所有测试 | ||
| 139 | - test_config_files() | ||
| 140 | - | ||
| 141 | - if not test_module_import(): | ||
| 142 | - print("\n模块导入失败,停止测试") | ||
| 143 | - return | ||
| 144 | - | ||
| 145 | - test_llm_config_loading() | ||
| 146 | - test_doubao_instantiation() | ||
| 147 | - test_integration_flow() | ||
| 148 | - | ||
| 149 | - print("\n=== 测试总结 ===") | ||
| 150 | - print("✓ 豆包模型已成功集成到项目中") | ||
| 151 | - print("✓ 配置文件结构正确") | ||
| 152 | - print("✓ 模块导入正常") | ||
| 153 | - print("\n使用说明:") | ||
| 154 | - print("1. 设置环境变量 DOUBAO_API_KEY 为您的豆包API密钥") | ||
| 155 | - print("2. 在 config/llm_config.json 中设置 model_type 为 'doubao'") | ||
| 156 | - print("3. 根据需要修改 config/doubao_config.json 中的人物设定") | ||
| 157 | - print("4. 重启应用即可使用豆包模型") | ||
| 158 | - | ||
| 159 | -if __name__ == "__main__": | ||
| 160 | - main() |
test_funasr_connection.py
deleted
100644 → 0
| 1 | -# AIfeng/2025-01-27 | ||
| 2 | -""" | ||
| 3 | -FunASR服务连接测试脚本 | ||
| 4 | -用于验证本地FunASR WebSocket服务是否可以正常连接 | ||
| 5 | - | ||
| 6 | -使用方法: | ||
| 7 | -1. 先启动FunASR服务:python -u web/asr/funasr/ASR_server.py --host "127.0.0.1" --port 10197 --ngpu 0 | ||
| 8 | -2. 运行此测试脚本:python test_funasr_connection.py | ||
| 9 | -""" | ||
| 10 | - | ||
| 11 | -import asyncio | ||
| 12 | -import websockets | ||
| 13 | -import json | ||
| 14 | -import os | ||
| 15 | -import wave | ||
| 16 | -import numpy as np | ||
| 17 | -from pathlib import Path | ||
| 18 | - | ||
| 19 | -class FunASRConnectionTest: | ||
| 20 | - def __init__(self, host="127.0.0.1", port=10197): | ||
| 21 | - self.host = host | ||
| 22 | - self.port = port | ||
| 23 | - self.uri = f"ws://{host}:{port}" | ||
| 24 | - | ||
| 25 | - async def test_basic_connection(self): | ||
| 26 | - """测试基本WebSocket连接""" | ||
| 27 | - print(f"🔍 测试连接到 {self.uri}") | ||
| 28 | - try: | ||
| 29 | - async with websockets.connect(self.uri) as websocket: | ||
| 30 | - print("✅ FunASR WebSocket服务连接成功") | ||
| 31 | - return True | ||
| 32 | - except ConnectionRefusedError: | ||
| 33 | - print("❌ 连接被拒绝,请确认FunASR服务已启动") | ||
| 34 | - print(" 启动命令: python -u web/asr/funasr/ASR_server.py --host \"127.0.0.1\" --port 10197 --ngpu 0") | ||
| 35 | - return False | ||
| 36 | - except Exception as e: | ||
| 37 | - print(f"❌ 连接失败: {e}") | ||
| 38 | - return False | ||
| 39 | - | ||
| 40 | - def create_test_wav(self, filename="test_audio.wav", duration=2, sample_rate=16000): | ||
| 41 | - """创建测试用的WAV文件""" | ||
| 42 | - # 生成简单的正弦波音频 | ||
| 43 | - t = np.linspace(0, duration, int(sample_rate * duration), False) | ||
| 44 | - frequency = 440 # A4音符 | ||
| 45 | - audio_data = np.sin(2 * np.pi * frequency * t) * 0.3 | ||
| 46 | - | ||
| 47 | - # 转换为16位整数 | ||
| 48 | - audio_data = (audio_data * 32767).astype(np.int16) | ||
| 49 | - | ||
| 50 | - # 保存为WAV文件 | ||
| 51 | - with wave.open(filename, 'wb') as wav_file: | ||
| 52 | - wav_file.setnchannels(1) # 单声道 | ||
| 53 | - wav_file.setsampwidth(2) # 16位 | ||
| 54 | - wav_file.setframerate(sample_rate) | ||
| 55 | - wav_file.writeframes(audio_data.tobytes()) | ||
| 56 | - | ||
| 57 | - print(f"📁 创建测试音频文件: {filename}") | ||
| 58 | - return filename | ||
| 59 | - | ||
| 60 | - async def test_audio_recognition(self): | ||
| 61 | - """测试音频识别功能""" | ||
| 62 | - print("\n🎵 测试音频识别功能") | ||
| 63 | - | ||
| 64 | - # 创建测试音频文件 | ||
| 65 | - test_file = self.create_test_wav() | ||
| 66 | - test_file_path = os.path.abspath(test_file) | ||
| 67 | - | ||
| 68 | - try: | ||
| 69 | - async with websockets.connect(self.uri) as websocket: | ||
| 70 | - print("✅ 连接成功,发送音频文件路径") | ||
| 71 | - | ||
| 72 | - # 发送音频文件路径 | ||
| 73 | - message = {"url": test_file_path} | ||
| 74 | - await websocket.send(json.dumps(message)) | ||
| 75 | - print(f"📤 发送消息: {message}") | ||
| 76 | - | ||
| 77 | - # 等待识别结果 | ||
| 78 | - try: | ||
| 79 | - response = await asyncio.wait_for(websocket.recv(), timeout=10) | ||
| 80 | - print(f"📥 收到识别结果: {response}") | ||
| 81 | - return True | ||
| 82 | - except asyncio.TimeoutError: | ||
| 83 | - print("⏰ 等待响应超时(10秒)") | ||
| 84 | - print(" 这可能是正常的,因为测试音频是纯音调,无法识别为文字") | ||
| 85 | - return True # 超时也算连接成功 | ||
| 86 | - | ||
| 87 | - except Exception as e: | ||
| 88 | - print(f"❌ 音频识别测试失败: {e}") | ||
| 89 | - return False | ||
| 90 | - finally: | ||
| 91 | - # 清理测试文件 | ||
| 92 | - if os.path.exists(test_file): | ||
| 93 | - os.remove(test_file) | ||
| 94 | - print(f"🗑️ 清理测试文件: {test_file}") | ||
| 95 | - | ||
| 96 | - async def test_real_audio_files(self): | ||
| 97 | - """测试实际音频文件的识别效果""" | ||
| 98 | - print("\n🎤 测试实际音频文件识别") | ||
| 99 | - | ||
| 100 | - # 实际音频文件列表 | ||
| 101 | - audio_files = [ | ||
| 102 | - "yunxi.mp3", | ||
| 103 | - "yunxia.mp3", | ||
| 104 | - "yunyang.mp3" | ||
| 105 | - ] | ||
| 106 | - | ||
| 107 | - results = [] | ||
| 108 | - | ||
| 109 | - for audio_file in audio_files: | ||
| 110 | - file_path = os.path.abspath(audio_file) | ||
| 111 | - | ||
| 112 | - # 检查文件是否存在 | ||
| 113 | - if not os.path.exists(file_path): | ||
| 114 | - print(f"⚠️ 音频文件不存在: {file_path}") | ||
| 115 | - continue | ||
| 116 | - | ||
| 117 | - print(f"\n🎵 测试音频文件: {audio_file}") | ||
| 118 | - | ||
| 119 | - try: | ||
| 120 | - async with websockets.connect(self.uri) as websocket: | ||
| 121 | - print(f"✅ 连接成功,发送音频文件: {audio_file}") | ||
| 122 | - | ||
| 123 | - # 发送音频文件路径 | ||
| 124 | - message = {"url": file_path} | ||
| 125 | - await websocket.send(json.dumps(message)) | ||
| 126 | - print(f"📤 发送消息: {message}") | ||
| 127 | - | ||
| 128 | - # 等待识别结果 | ||
| 129 | - try: | ||
| 130 | - response = await asyncio.wait_for(websocket.recv(), timeout=30) | ||
| 131 | - print(f"📥 识别结果: {response}") | ||
| 132 | - | ||
| 133 | - # 解析响应 | ||
| 134 | - try: | ||
| 135 | - result_data = json.loads(response) | ||
| 136 | - if isinstance(result_data, dict) and 'text' in result_data: | ||
| 137 | - recognized_text = result_data['text'] | ||
| 138 | - print(f"🎯 识别文本: {recognized_text}") | ||
| 139 | - results.append({ | ||
| 140 | - 'file': audio_file, | ||
| 141 | - 'text': recognized_text, | ||
| 142 | - 'status': 'success' | ||
| 143 | - }) | ||
| 144 | - else: | ||
| 145 | - print(f"📄 原始响应: {response}") | ||
| 146 | - results.append({ | ||
| 147 | - 'file': audio_file, | ||
| 148 | - 'response': response, | ||
| 149 | - 'status': 'received' | ||
| 150 | - }) | ||
| 151 | - except json.JSONDecodeError: | ||
| 152 | - print(f"📄 非JSON响应: {response}") | ||
| 153 | - results.append({ | ||
| 154 | - 'file': audio_file, | ||
| 155 | - 'response': response, | ||
| 156 | - 'status': 'received' | ||
| 157 | - }) | ||
| 158 | - | ||
| 159 | - except asyncio.TimeoutError: | ||
| 160 | - print(f"⏰ 等待响应超时(30秒)- {audio_file}") | ||
| 161 | - results.append({ | ||
| 162 | - 'file': audio_file, | ||
| 163 | - 'status': 'timeout' | ||
| 164 | - }) | ||
| 165 | - | ||
| 166 | - except Exception as e: | ||
| 167 | - print(f"❌ 测试 {audio_file} 失败: {e}") | ||
| 168 | - results.append({ | ||
| 169 | - 'file': audio_file, | ||
| 170 | - 'error': str(e), | ||
| 171 | - 'status': 'error' | ||
| 172 | - }) | ||
| 173 | - | ||
| 174 | - # 文件间等待,避免服务器压力 | ||
| 175 | - await asyncio.sleep(1) | ||
| 176 | - | ||
| 177 | - # 输出测试总结 | ||
| 178 | - print("\n" + "="*50) | ||
| 179 | - print("📊 实际音频文件测试总结:") | ||
| 180 | - for i, result in enumerate(results, 1): | ||
| 181 | - print(f"\n{i}. 文件: {result['file']}") | ||
| 182 | - if result['status'] == 'success': | ||
| 183 | - print(f" ✅ 识别成功: {result['text']}") | ||
| 184 | - elif result['status'] == 'received': | ||
| 185 | - print(f" 📥 收到响应: {result.get('response', 'N/A')}") | ||
| 186 | - elif result['status'] == 'timeout': | ||
| 187 | - print(f" ⏰ 响应超时") | ||
| 188 | - elif result['status'] == 'error': | ||
| 189 | - print(f" ❌ 测试失败: {result.get('error', 'N/A')}") | ||
| 190 | - | ||
| 191 | - return len(results) > 0 | ||
| 192 | - | ||
| 193 | - async def test_message_format(self): | ||
| 194 | - """测试消息格式兼容性""" | ||
| 195 | - print("\n📋 测试消息格式兼容性") | ||
| 196 | - | ||
| 197 | - try: | ||
| 198 | - async with websockets.connect(self.uri) as websocket: | ||
| 199 | - # 测试不同的消息格式 | ||
| 200 | - test_messages = [ | ||
| 201 | - {"url": "nonexistent.wav"}, | ||
| 202 | - {"test": "message"}, | ||
| 203 | - "invalid_json" | ||
| 204 | - ] | ||
| 205 | - | ||
| 206 | - for i, msg in enumerate(test_messages, 1): | ||
| 207 | - try: | ||
| 208 | - if isinstance(msg, dict): | ||
| 209 | - await websocket.send(json.dumps(msg)) | ||
| 210 | - print(f"✅ 消息 {i} 发送成功: {msg}") | ||
| 211 | - else: | ||
| 212 | - await websocket.send(msg) | ||
| 213 | - print(f"✅ 消息 {i} 发送成功: {msg}") | ||
| 214 | - | ||
| 215 | - # 短暂等待,避免消息堆积 | ||
| 216 | - await asyncio.sleep(0.5) | ||
| 217 | - | ||
| 218 | - except Exception as e: | ||
| 219 | - print(f"⚠️ 消息 {i} 发送失败: {e}") | ||
| 220 | - | ||
| 221 | - return True | ||
| 222 | - | ||
| 223 | - except Exception as e: | ||
| 224 | - print(f"❌ 消息格式测试失败: {e}") | ||
| 225 | - return False | ||
| 226 | - | ||
| 227 | - def check_dependencies(self): | ||
| 228 | - """检查依赖项""" | ||
| 229 | - print("🔍 检查依赖项...") | ||
| 230 | - | ||
| 231 | - required_modules = [ | ||
| 232 | - 'websockets', | ||
| 233 | - 'asyncio', | ||
| 234 | - 'json', | ||
| 235 | - 'wave', | ||
| 236 | - 'numpy' | ||
| 237 | - ] | ||
| 238 | - | ||
| 239 | - missing_modules = [] | ||
| 240 | - for module in required_modules: | ||
| 241 | - try: | ||
| 242 | - __import__(module) | ||
| 243 | - print(f"✅ {module}") | ||
| 244 | - except ImportError: | ||
| 245 | - print(f"❌ {module} (缺失)") | ||
| 246 | - missing_modules.append(module) | ||
| 247 | - | ||
| 248 | - if missing_modules: | ||
| 249 | - print(f"\n⚠️ 缺失依赖项: {', '.join(missing_modules)}") | ||
| 250 | - print("安装命令: pip install " + ' '.join(missing_modules)) | ||
| 251 | - return False | ||
| 252 | - | ||
| 253 | - print("✅ 所有依赖项检查通过") | ||
| 254 | - return True | ||
| 255 | - | ||
| 256 | - def check_funasr_server_file(self): | ||
| 257 | - """检查FunASR服务器文件是否存在""" | ||
| 258 | - print("\n📁 检查FunASR服务器文件...") | ||
| 259 | - | ||
| 260 | - server_path = Path("web/asr/funasr/ASR_server.py") | ||
| 261 | - if server_path.exists(): | ||
| 262 | - print(f"✅ 找到服务器文件: {server_path.absolute()}") | ||
| 263 | - return True | ||
| 264 | - else: | ||
| 265 | - print(f"❌ 未找到服务器文件: {server_path.absolute()}") | ||
| 266 | - print(" 请确认文件路径是否正确") | ||
| 267 | - return False | ||
| 268 | - | ||
| 269 | - async def run_all_tests(self): | ||
| 270 | - """运行所有测试""" | ||
| 271 | - print("🚀 开始FunASR连接测试\n") | ||
| 272 | - | ||
| 273 | - # 检查依赖 | ||
| 274 | - if not self.check_dependencies(): | ||
| 275 | - return False | ||
| 276 | - | ||
| 277 | - # 检查服务器文件 | ||
| 278 | - if not self.check_funasr_server_file(): | ||
| 279 | - return False | ||
| 280 | - | ||
| 281 | - # 基本连接测试 | ||
| 282 | - print("\n" + "="*50) | ||
| 283 | - if not await self.test_basic_connection(): | ||
| 284 | - return False | ||
| 285 | - | ||
| 286 | - # 音频识别测试 | ||
| 287 | - print("\n" + "="*50) | ||
| 288 | - if not await self.test_audio_recognition(): | ||
| 289 | - return False | ||
| 290 | - | ||
| 291 | - # 实际音频文件测试 | ||
| 292 | - print("\n" + "="*50) | ||
| 293 | - await self.test_real_audio_files() | ||
| 294 | - | ||
| 295 | - # 消息格式测试 | ||
| 296 | - print("\n" + "="*50) | ||
| 297 | - if not await self.test_message_format(): | ||
| 298 | - return False | ||
| 299 | - | ||
| 300 | - print("\n" + "="*50) | ||
| 301 | - print("🎉 所有测试完成!FunASR服务连接正常") | ||
| 302 | - print("\n💡 集成建议:") | ||
| 303 | - print(" 1. 服务使用WebSocket协议,非gRPC") | ||
| 304 | - print(" 2. 默认监听端口: 10197") | ||
| 305 | - print(" 3. 消息格式: JSON字符串,包含'url'字段指向音频文件路径") | ||
| 306 | - print(" 4. 可以集成到现有项目的ASR模块中") | ||
| 307 | - | ||
| 308 | - return True | ||
| 309 | - | ||
| 310 | -async def main(): | ||
| 311 | - """主函数""" | ||
| 312 | - tester = FunASRConnectionTest() | ||
| 313 | - success = await tester.run_all_tests() | ||
| 314 | - | ||
| 315 | - if not success: | ||
| 316 | - print("\n❌ 测试失败,请检查FunASR服务状态") | ||
| 317 | - return 1 | ||
| 318 | - | ||
| 319 | - return 0 | ||
| 320 | - | ||
| 321 | -if __name__ == "__main__": | ||
| 322 | - try: | ||
| 323 | - exit_code = asyncio.run(main()) | ||
| 324 | - exit(exit_code) | ||
| 325 | - except KeyboardInterrupt: | ||
| 326 | - print("\n⏹️ 测试被用户中断") | ||
| 327 | - exit(1) | ||
| 328 | - except Exception as e: | ||
| 329 | - print(f"\n💥 测试过程中发生错误: {e}") | ||
| 330 | - exit(1) |
test_funasr_integration.py
deleted
100644 → 0
| 1 | -# -*- coding: utf-8 -*- | ||
| 2 | -""" | ||
| 3 | -AIfeng/2025-01-27 | ||
| 4 | -FunASR集成测试脚本 | ||
| 5 | -测试新的FunASRClient与项目的集成效果 | ||
| 6 | -""" | ||
| 7 | - | ||
| 8 | -import os | ||
| 9 | -import sys | ||
| 10 | -import time | ||
| 11 | -import threading | ||
| 12 | -from pathlib import Path | ||
| 13 | - | ||
| 14 | -# 添加项目路径 | ||
| 15 | -sys.path.append(os.path.dirname(__file__)) | ||
| 16 | - | ||
| 17 | -from funasr_asr import FunASRClient | ||
| 18 | -from web.asr.funasr import FunASR | ||
| 19 | -import util | ||
| 20 | - | ||
| 21 | -class TestFunASRIntegration: | ||
| 22 | - """FunASR集成测试类""" | ||
| 23 | - | ||
| 24 | - def __init__(self): | ||
| 25 | - self.test_results = [] | ||
| 26 | - self.test_audio_files = [ | ||
| 27 | - "yunxi.mp3", | ||
| 28 | - "yunxia.mp3", | ||
| 29 | - "yunyang.mp3" | ||
| 30 | - ] | ||
| 31 | - | ||
| 32 | - def log_test_result(self, test_name: str, success: bool, message: str = ""): | ||
| 33 | - """记录测试结果""" | ||
| 34 | - status = "✓ 通过" if success else "✗ 失败" | ||
| 35 | - result = f"[{status}] {test_name}" | ||
| 36 | - if message: | ||
| 37 | - result += f" - {message}" | ||
| 38 | - | ||
| 39 | - self.test_results.append((test_name, success, message)) | ||
| 40 | - print(result) | ||
| 41 | - | ||
| 42 | - def test_funasr_client_creation(self): | ||
| 43 | - """测试FunASRClient创建""" | ||
| 44 | - try: | ||
| 45 | - class SimpleOpt: | ||
| 46 | - def __init__(self): | ||
| 47 | - self.username = "test_user" | ||
| 48 | - | ||
| 49 | - opt = SimpleOpt() | ||
| 50 | - client = FunASRClient(opt) | ||
| 51 | - | ||
| 52 | - # 检查基本属性 | ||
| 53 | - assert hasattr(client, 'server_url') | ||
| 54 | - assert hasattr(client, 'connected') | ||
| 55 | - assert hasattr(client, 'running') | ||
| 56 | - | ||
| 57 | - self.log_test_result("FunASRClient创建", True, "客户端创建成功") | ||
| 58 | - return client | ||
| 59 | - | ||
| 60 | - except Exception as e: | ||
| 61 | - self.log_test_result("FunASRClient创建", False, f"错误: {e}") | ||
| 62 | - return None | ||
| 63 | - | ||
| 64 | - def test_compatibility_wrapper(self): | ||
| 65 | - """测试兼容性包装器""" | ||
| 66 | - try: | ||
| 67 | - funasr = FunASR("test_user") | ||
| 68 | - | ||
| 69 | - # 检查兼容性方法 | ||
| 70 | - assert hasattr(funasr, 'start') | ||
| 71 | - assert hasattr(funasr, 'end') | ||
| 72 | - assert hasattr(funasr, 'send') | ||
| 73 | - assert hasattr(funasr, 'add_frame') | ||
| 74 | - assert hasattr(funasr, 'set_message_callback') | ||
| 75 | - | ||
| 76 | - self.log_test_result("兼容性包装器", True, "所有兼容性方法存在") | ||
| 77 | - return funasr | ||
| 78 | - | ||
| 79 | - except Exception as e: | ||
| 80 | - self.log_test_result("兼容性包装器", False, f"错误: {e}") | ||
| 81 | - return None | ||
| 82 | - | ||
| 83 | - def test_callback_mechanism(self): | ||
| 84 | - """测试回调机制""" | ||
| 85 | - try: | ||
| 86 | - funasr = FunASR("test_user") | ||
| 87 | - callback_called = threading.Event() | ||
| 88 | - received_message = [] | ||
| 89 | - | ||
| 90 | - def test_callback(message): | ||
| 91 | - received_message.append(message) | ||
| 92 | - callback_called.set() | ||
| 93 | - | ||
| 94 | - funasr.set_message_callback(test_callback) | ||
| 95 | - | ||
| 96 | - # 模拟接收消息 | ||
| 97 | - test_message = "测试识别结果" | ||
| 98 | - funasr._handle_result(test_message) | ||
| 99 | - | ||
| 100 | - # 等待回调 | ||
| 101 | - if callback_called.wait(timeout=1.0): | ||
| 102 | - if received_message and received_message[0] == test_message: | ||
| 103 | - self.log_test_result("回调机制", True, "回调函数正常工作") | ||
| 104 | - else: | ||
| 105 | - self.log_test_result("回调机制", False, "回调消息不匹配") | ||
| 106 | - else: | ||
| 107 | - self.log_test_result("回调机制", False, "回调超时") | ||
| 108 | - | ||
| 109 | - except Exception as e: | ||
| 110 | - self.log_test_result("回调机制", False, f"错误: {e}") | ||
| 111 | - | ||
| 112 | - def test_audio_file_existence(self): | ||
| 113 | - """测试音频文件存在性""" | ||
| 114 | - existing_files = [] | ||
| 115 | - missing_files = [] | ||
| 116 | - | ||
| 117 | - for audio_file in self.test_audio_files: | ||
| 118 | - if os.path.exists(audio_file): | ||
| 119 | - existing_files.append(audio_file) | ||
| 120 | - else: | ||
| 121 | - missing_files.append(audio_file) | ||
| 122 | - | ||
| 123 | - if existing_files: | ||
| 124 | - self.log_test_result( | ||
| 125 | - "音频文件检查", | ||
| 126 | - True, | ||
| 127 | - f"找到 {len(existing_files)} 个文件: {', '.join(existing_files)}" | ||
| 128 | - ) | ||
| 129 | - | ||
| 130 | - if missing_files: | ||
| 131 | - self.log_test_result( | ||
| 132 | - "音频文件缺失", | ||
| 133 | - False, | ||
| 134 | - f"缺少 {len(missing_files)} 个文件: {', '.join(missing_files)}" | ||
| 135 | - ) | ||
| 136 | - | ||
| 137 | - return existing_files | ||
| 138 | - | ||
| 139 | - def test_connection_simulation(self): | ||
| 140 | - """测试连接模拟""" | ||
| 141 | - try: | ||
| 142 | - client = self.test_funasr_client_creation() | ||
| 143 | - if not client: | ||
| 144 | - return | ||
| 145 | - | ||
| 146 | - # 测试启动和停止 | ||
| 147 | - client.start() | ||
| 148 | - time.sleep(0.5) # 给连接一些时间 | ||
| 149 | - | ||
| 150 | - # 检查运行状态 | ||
| 151 | - if client.running: | ||
| 152 | - self.log_test_result("客户端启动", True, "客户端成功启动") | ||
| 153 | - else: | ||
| 154 | - self.log_test_result("客户端启动", False, "客户端启动失败") | ||
| 155 | - | ||
| 156 | - # 停止客户端 | ||
| 157 | - client.stop() | ||
| 158 | - time.sleep(0.5) | ||
| 159 | - | ||
| 160 | - if not client.running: | ||
| 161 | - self.log_test_result("客户端停止", True, "客户端成功停止") | ||
| 162 | - else: | ||
| 163 | - self.log_test_result("客户端停止", False, "客户端停止失败") | ||
| 164 | - | ||
| 165 | - except Exception as e: | ||
| 166 | - self.log_test_result("连接模拟", False, f"错误: {e}") | ||
| 167 | - | ||
| 168 | - def test_message_queue(self): | ||
| 169 | - """测试消息队列""" | ||
| 170 | - try: | ||
| 171 | - client = self.test_funasr_client_creation() | ||
| 172 | - if not client: | ||
| 173 | - return | ||
| 174 | - | ||
| 175 | - # 测试消息入队 | ||
| 176 | - test_message = {"test": "message"} | ||
| 177 | - client.message_queue.put(test_message) | ||
| 178 | - | ||
| 179 | - # 检查队列 | ||
| 180 | - if not client.message_queue.empty(): | ||
| 181 | - retrieved_message = client.message_queue.get_nowait() | ||
| 182 | - if retrieved_message == test_message: | ||
| 183 | - self.log_test_result("消息队列", True, "消息队列正常工作") | ||
| 184 | - else: | ||
| 185 | - self.log_test_result("消息队列", False, "消息内容不匹配") | ||
| 186 | - else: | ||
| 187 | - self.log_test_result("消息队列", False, "消息队列为空") | ||
| 188 | - | ||
| 189 | - except Exception as e: | ||
| 190 | - self.log_test_result("消息队列", False, f"错误: {e}") | ||
| 191 | - | ||
| 192 | - def test_config_loading(self): | ||
| 193 | - """测试配置加载""" | ||
| 194 | - try: | ||
| 195 | - import config_util as cfg | ||
| 196 | - | ||
| 197 | - # 检查关键配置项 | ||
| 198 | - required_configs = [ | ||
| 199 | - 'local_asr_ip', | ||
| 200 | - 'local_asr_port', | ||
| 201 | - 'asr_timeout', | ||
| 202 | - 'asr_reconnect_delay', | ||
| 203 | - 'asr_max_reconnect_attempts' | ||
| 204 | - ] | ||
| 205 | - | ||
| 206 | - missing_configs = [] | ||
| 207 | - for config_key in required_configs: | ||
| 208 | - try: | ||
| 209 | - if hasattr(cfg, 'config'): | ||
| 210 | - value = cfg.config.get(config_key) | ||
| 211 | - else: | ||
| 212 | - value = getattr(cfg, config_key, None) | ||
| 213 | - if value is None: | ||
| 214 | - missing_configs.append(config_key) | ||
| 215 | - except: | ||
| 216 | - missing_configs.append(config_key) | ||
| 217 | - | ||
| 218 | - if not missing_configs: | ||
| 219 | - self.log_test_result("配置加载", True, "所有必需配置项存在") | ||
| 220 | - else: | ||
| 221 | - self.log_test_result( | ||
| 222 | - "配置加载", | ||
| 223 | - False, | ||
| 224 | - f"缺少配置项: {', '.join(missing_configs)}" | ||
| 225 | - ) | ||
| 226 | - | ||
| 227 | - except Exception as e: | ||
| 228 | - self.log_test_result("配置加载", False, f"错误: {e}") | ||
| 229 | - | ||
| 230 | - def run_all_tests(self): | ||
| 231 | - """运行所有测试""" | ||
| 232 | - print("\n" + "="*60) | ||
| 233 | - print("FunASR集成测试开始") | ||
| 234 | - print("="*60) | ||
| 235 | - | ||
| 236 | - # 运行各项测试 | ||
| 237 | - self.test_config_loading() | ||
| 238 | - self.test_funasr_client_creation() | ||
| 239 | - self.test_compatibility_wrapper() | ||
| 240 | - self.test_callback_mechanism() | ||
| 241 | - self.test_message_queue() | ||
| 242 | - self.test_audio_file_existence() | ||
| 243 | - self.test_connection_simulation() | ||
| 244 | - | ||
| 245 | - # 输出测试总结 | ||
| 246 | - print("\n" + "="*60) | ||
| 247 | - print("测试总结") | ||
| 248 | - print("="*60) | ||
| 249 | - | ||
| 250 | - passed_tests = sum(1 for _, success, _ in self.test_results if success) | ||
| 251 | - total_tests = len(self.test_results) | ||
| 252 | - | ||
| 253 | - print(f"总测试数: {total_tests}") | ||
| 254 | - print(f"通过测试: {passed_tests}") | ||
| 255 | - print(f"失败测试: {total_tests - passed_tests}") | ||
| 256 | - print(f"成功率: {passed_tests/total_tests*100:.1f}%") | ||
| 257 | - | ||
| 258 | - # 显示失败的测试 | ||
| 259 | - failed_tests = [(name, msg) for name, success, msg in self.test_results if not success] | ||
| 260 | - if failed_tests: | ||
| 261 | - print("\n失败的测试:") | ||
| 262 | - for name, msg in failed_tests: | ||
| 263 | - print(f" - {name}: {msg}") | ||
| 264 | - | ||
| 265 | - print("\n" + "="*60) | ||
| 266 | - | ||
| 267 | - return passed_tests == total_tests | ||
| 268 | - | ||
| 269 | -def main(): | ||
| 270 | - """主函数""" | ||
| 271 | - tester = TestFunASRIntegration() | ||
| 272 | - success = tester.run_all_tests() | ||
| 273 | - | ||
| 274 | - if success: | ||
| 275 | - print("\n🎉 所有测试通过!FunASR集成准备就绪。") | ||
| 276 | - else: | ||
| 277 | - print("\n⚠️ 部分测试失败,请检查相关配置和依赖。") | ||
| 278 | - | ||
| 279 | - return 0 if success else 1 | ||
| 280 | - | ||
| 281 | -if __name__ == "__main__": | ||
| 282 | - exit(main()) |
test_websocket_server.py
deleted
100644 → 0
| 1 | -#!/usr/bin/env python3 | ||
| 2 | -# AIfeng/2024-12-19 | ||
| 3 | -# WebSocket通信测试服务器 | ||
| 4 | - | ||
| 5 | -import asyncio | ||
| 6 | -import json | ||
| 7 | -import time | ||
| 8 | -import weakref | ||
| 9 | -from aiohttp import web, WSMsgType | ||
| 10 | -import aiohttp_cors | ||
| 11 | -from typing import Dict | ||
| 12 | - | ||
| 13 | -# 全局变量 | ||
| 14 | -websocket_connections: Dict[int, weakref.WeakSet] = {} # sessionid:websocket_connections | ||
| 15 | - | ||
| 16 | -# WebSocket消息推送函数 | ||
| 17 | -async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "测试服务器"): | ||
| 18 | - """向指定会话的所有WebSocket连接推送消息""" | ||
| 19 | - if sessionid not in websocket_connections: | ||
| 20 | - print(f'[SessionID:{sessionid}] No WebSocket connections found') | ||
| 21 | - return | ||
| 22 | - | ||
| 23 | - message = { | ||
| 24 | - "type": "chat_message", | ||
| 25 | - "data": { | ||
| 26 | - "sessionid": sessionid, | ||
| 27 | - "message_type": message_type, | ||
| 28 | - "content": content, | ||
| 29 | - "source": source, | ||
| 30 | - "timestamp": time.time() | ||
| 31 | - } | ||
| 32 | - } | ||
| 33 | - | ||
| 34 | - # 获取该会话的所有WebSocket连接 | ||
| 35 | - connections = list(websocket_connections[sessionid]) | ||
| 36 | - print(f'[SessionID:{sessionid}] Broadcasting to {len(connections)} connections') | ||
| 37 | - | ||
| 38 | - # 向所有连接发送消息 | ||
| 39 | - for ws in connections: | ||
| 40 | - try: | ||
| 41 | - if not ws.closed: | ||
| 42 | - await ws.send_str(json.dumps(message)) | ||
| 43 | - print(f'[SessionID:{sessionid}] Message sent to WebSocket: {message_type}') | ||
| 44 | - except Exception as e: | ||
| 45 | - print(f'[SessionID:{sessionid}] Failed to send WebSocket message: {e}') | ||
| 46 | - | ||
| 47 | -# WebSocket处理器 | ||
| 48 | -async def websocket_handler(request): | ||
| 49 | - """处理WebSocket连接""" | ||
| 50 | - ws = web.WebSocketResponse() | ||
| 51 | - await ws.prepare(request) | ||
| 52 | - | ||
| 53 | - sessionid = None | ||
| 54 | - print('New WebSocket connection established') | ||
| 55 | - | ||
| 56 | - try: | ||
| 57 | - async for msg in ws: | ||
| 58 | - if msg.type == WSMsgType.TEXT: | ||
| 59 | - try: | ||
| 60 | - data = json.loads(msg.data) | ||
| 61 | - print(f'Received WebSocket message: {data}') | ||
| 62 | - | ||
| 63 | - if data.get('type') == 'login': | ||
| 64 | - sessionid = data.get('sessionid', 0) | ||
| 65 | - | ||
| 66 | - # 初始化该会话的WebSocket连接集合 | ||
| 67 | - if sessionid not in websocket_connections: | ||
| 68 | - websocket_connections[sessionid] = weakref.WeakSet() | ||
| 69 | - | ||
| 70 | - # 添加当前连接到会话 | ||
| 71 | - websocket_connections[sessionid].add(ws) | ||
| 72 | - | ||
| 73 | - print(f'[SessionID:{sessionid}] WebSocket client logged in') | ||
| 74 | - | ||
| 75 | - # 发送登录确认 | ||
| 76 | - await ws.send_str(json.dumps({ | ||
| 77 | - "type": "login_success", | ||
| 78 | - "sessionid": sessionid, | ||
| 79 | - "message": "WebSocket连接成功" | ||
| 80 | - })) | ||
| 81 | - | ||
| 82 | - elif data.get('type') == 'ping': | ||
| 83 | - # 心跳检测 | ||
| 84 | - await ws.send_str(json.dumps({"type": "pong"})) | ||
| 85 | - print('Sent pong response') | ||
| 86 | - | ||
| 87 | - except json.JSONDecodeError: | ||
| 88 | - print('Invalid JSON received from WebSocket') | ||
| 89 | - except Exception as e: | ||
| 90 | - print(f'Error processing WebSocket message: {e}') | ||
| 91 | - | ||
| 92 | - elif msg.type == WSMsgType.ERROR: | ||
| 93 | - print(f'WebSocket error: {ws.exception()}') | ||
| 94 | - break | ||
| 95 | - | ||
| 96 | - except Exception as e: | ||
| 97 | - print(f'WebSocket connection error: {e}') | ||
| 98 | - finally: | ||
| 99 | - if sessionid is not None: | ||
| 100 | - print(f'[SessionID:{sessionid}] WebSocket connection closed') | ||
| 101 | - else: | ||
| 102 | - print('WebSocket connection closed') | ||
| 103 | - | ||
| 104 | - return ws | ||
| 105 | - | ||
| 106 | -# 模拟human接口 | ||
| 107 | -async def human(request): | ||
| 108 | - try: | ||
| 109 | - params = await request.json() | ||
| 110 | - sessionid = params.get('sessionid', 0) | ||
| 111 | - user_message = params.get('text', '') | ||
| 112 | - message_type = params.get('type', 'echo') | ||
| 113 | - | ||
| 114 | - print(f'[SessionID:{sessionid}] Received {message_type} message: {user_message}') | ||
| 115 | - | ||
| 116 | - # 推送用户消息到WebSocket | ||
| 117 | - await broadcast_message_to_session(sessionid, message_type, user_message, "用户") | ||
| 118 | - | ||
| 119 | - if message_type == 'echo': | ||
| 120 | - # 推送回音消息到WebSocket | ||
| 121 | - await broadcast_message_to_session(sessionid, 'echo', user_message, "回音") | ||
| 122 | - | ||
| 123 | - elif message_type == 'chat': | ||
| 124 | - # 模拟AI回复 | ||
| 125 | - ai_response = f"这是对 '{user_message}' 的AI回复" | ||
| 126 | - await broadcast_message_to_session(sessionid, 'chat', ai_response, "AI助手") | ||
| 127 | - | ||
| 128 | - return web.Response( | ||
| 129 | - content_type="application/json", | ||
| 130 | - text=json.dumps( | ||
| 131 | - {"code": 0, "data": "ok", "message": "消息已处理并推送"} | ||
| 132 | - ), | ||
| 133 | - ) | ||
| 134 | - except Exception as e: | ||
| 135 | - print(f'Error in human endpoint: {e}') | ||
| 136 | - return web.Response( | ||
| 137 | - content_type="application/json", | ||
| 138 | - text=json.dumps( | ||
| 139 | - {"code": -1, "msg": str(e)} | ||
| 140 | - ), | ||
| 141 | - ) | ||
| 142 | - | ||
| 143 | -# 创建应用 | ||
| 144 | -def create_app(): | ||
| 145 | - app = web.Application() | ||
| 146 | - | ||
| 147 | - # 添加路由 | ||
| 148 | - app.router.add_post("/human", human) | ||
| 149 | - app.router.add_get("/ws", websocket_handler) | ||
| 150 | - app.router.add_static('/', path='web') | ||
| 151 | - | ||
| 152 | - # 配置CORS | ||
| 153 | - cors = aiohttp_cors.setup(app, defaults={ | ||
| 154 | - "*": aiohttp_cors.ResourceOptions( | ||
| 155 | - allow_credentials=True, | ||
| 156 | - expose_headers="*", | ||
| 157 | - allow_headers="*", | ||
| 158 | - ) | ||
| 159 | - }) | ||
| 160 | - | ||
| 161 | - # 为所有路由配置CORS | ||
| 162 | - for route in list(app.router.routes()): | ||
| 163 | - cors.add(route) | ||
| 164 | - | ||
| 165 | - return app | ||
| 166 | - | ||
| 167 | -if __name__ == '__main__': | ||
| 168 | - app = create_app() | ||
| 169 | - print('Starting WebSocket test server on http://localhost:8000') | ||
| 170 | - print('WebSocket endpoint: ws://localhost:8000/ws') | ||
| 171 | - print('HTTP endpoint: http://localhost:8000/human') | ||
| 172 | - print('Test page: http://localhost:8000/websocket_test.html') | ||
| 173 | - | ||
| 174 | - web.run_app(app, host='0.0.0.0', port=8000) |
| @@ -9,7 +9,7 @@ import _thread as thread | @@ -9,7 +9,7 @@ import _thread as thread | ||
| 9 | from aliyunsdkcore.client import AcsClient | 9 | from aliyunsdkcore.client import AcsClient |
| 10 | from aliyunsdkcore.request import CommonRequest | 10 | from aliyunsdkcore.request import CommonRequest |
| 11 | 11 | ||
| 12 | -from core import wsa_server | 12 | +from core import get_web_instance, get_instance |
| 13 | from scheduler.thread_manager import MyThread | 13 | from scheduler.thread_manager import MyThread |
| 14 | from utils import util | 14 | from utils import util |
| 15 | from utils import config_util as cfg | 15 | from utils import config_util as cfg |
| @@ -92,19 +92,37 @@ class ALiNls: | @@ -92,19 +92,37 @@ class ALiNls: | ||
| 92 | if name == 'SentenceEnd': | 92 | if name == 'SentenceEnd': |
| 93 | self.done = True | 93 | self.done = True |
| 94 | self.finalResults = data['payload']['result'] | 94 | self.finalResults = data['payload']['result'] |
| 95 | - if wsa_server.get_web_instance().is_connected(self.username): | ||
| 96 | - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username}) | ||
| 97 | - if wsa_server.get_instance().is_connected_human(self.username): | ||
| 98 | - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username} | ||
| 99 | - wsa_server.get_instance().add_cmd(content) | 95 | + if get_web_instance().is_connected(self.username): |
| 96 | + import asyncio | ||
| 97 | + # 创建chat_message直接推送 | ||
| 98 | + chat_message = { | ||
| 99 | + "type": "chat_message", | ||
| 100 | + "sender": "回音", | ||
| 101 | + "content": self.finalResults, # 修复字段名:panelMsg -> content | ||
| 102 | + "Username": self.username, | ||
| 103 | + "model_info": "ALiNls" | ||
| 104 | + } | ||
| 105 | + # 使用直接发送方法,避免wsa_command封装 | ||
| 106 | + asyncio.create_task(get_web_instance().send_direct_message(chat_message)) | ||
| 107 | + # Human客户端通知改为日志记录(避免重复通知当前服务) | ||
| 108 | + util.log(1, f"ALiNls识别结果[{self.username}]: {self.finalResults}") | ||
| 100 | ws.close()#TODO | 109 | ws.close()#TODO |
| 101 | elif name == 'TranscriptionResultChanged': | 110 | elif name == 'TranscriptionResultChanged': |
| 102 | self.finalResults = data['payload']['result'] | 111 | self.finalResults = data['payload']['result'] |
| 103 | - if wsa_server.get_web_instance().is_connected(self.username): | ||
| 104 | - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username}) | ||
| 105 | - if wsa_server.get_instance().is_connected_human(self.username): | ||
| 106 | - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username} | ||
| 107 | - wsa_server.get_instance().add_cmd(content) | 112 | + if get_web_instance().is_connected(self.username): |
| 113 | + import asyncio | ||
| 114 | + # 创建chat_message直接推送 | ||
| 115 | + chat_message = { | ||
| 116 | + "type": "chat_message", | ||
| 117 | + "sender": "回音", | ||
| 118 | + "content": self.finalResults, # 修复字段名:panelMsg -> content | ||
| 119 | + "Username": self.username, | ||
| 120 | + "model_info": "ALiNls" | ||
| 121 | + } | ||
| 122 | + # 使用直接发送方法,避免wsa_command封装 | ||
| 123 | + asyncio.create_task(get_web_instance().send_direct_message(chat_message)) | ||
| 124 | + # Human客户端通知改为日志记录(避免重复通知当前服务) | ||
| 125 | + util.log(1, f"ALiNls识别变化[{self.username}]: {self.finalResults}") | ||
| 108 | 126 | ||
| 109 | except Exception as e: | 127 | except Exception as e: |
| 110 | print(e) | 128 | print(e) |
| @@ -14,13 +14,13 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | @@ -14,13 +14,13 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | ||
| 14 | from funasr_asr import FunASRClient | 14 | from funasr_asr import FunASRClient |
| 15 | # 修复导入路径 | 15 | # 修复导入路径 |
| 16 | try: | 16 | try: |
| 17 | - from core import wsa_server | 17 | + from core import get_web_instance, get_instance |
| 18 | except ImportError: | 18 | except ImportError: |
| 19 | - # 如果core模块不存在,创建一个模拟的wsa_server | ||
| 20 | - class MockWSAServer: | ||
| 21 | - def get_web_instance(self): | 19 | + # 如果core模块不存在,创建一个模拟的函数 |
| 20 | + def get_web_instance(): | ||
| 22 | return MockWebInstance() | 21 | return MockWebInstance() |
| 23 | - def get_instance(self): | 22 | + |
| 23 | + def get_instance(): | ||
| 24 | return MockInstance() | 24 | return MockInstance() |
| 25 | 25 | ||
| 26 | class MockWebInstance: | 26 | class MockWebInstance: |
| @@ -35,8 +35,6 @@ except ImportError: | @@ -35,8 +35,6 @@ except ImportError: | ||
| 35 | def add_cmd(self, cmd): | 35 | def add_cmd(self, cmd): |
| 36 | print(f"Mock Human: {cmd}") | 36 | print(f"Mock Human: {cmd}") |
| 37 | 37 | ||
| 38 | - wsa_server = MockWSAServer() | ||
| 39 | - | ||
| 40 | try: | 38 | try: |
| 41 | from utils import config_util as cfg | 39 | from utils import config_util as cfg |
| 42 | except ImportError: | 40 | except ImportError: |
| @@ -92,11 +90,20 @@ class FunASR: | @@ -92,11 +90,20 @@ class FunASR: | ||
| 92 | if self.on_message_callback: | 90 | if self.on_message_callback: |
| 93 | self.on_message_callback(message) | 91 | self.on_message_callback(message) |
| 94 | 92 | ||
| 95 | - if wsa_server.get_web_instance().is_connected(self.username): | ||
| 96 | - wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username}) | ||
| 97 | - if wsa_server.get_instance().is_connected_human(self.username): | ||
| 98 | - content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username} | ||
| 99 | - wsa_server.get_instance().add_cmd(content) | 93 | + if get_web_instance().is_connected(self.username): |
| 94 | + import asyncio | ||
| 95 | + # 创建chat_message直接推送 | ||
| 96 | + chat_message = { | ||
| 97 | + "type": "chat_message", | ||
| 98 | + "sender": "回音", | ||
| 99 | + "content": self.finalResults, # 修复字段名:panelMsg -> content | ||
| 100 | + "Username": self.username, | ||
| 101 | + "model_info": "FunASR" | ||
| 102 | + } | ||
| 103 | + # 使用直接发送方法,避免wsa_command封装 | ||
| 104 | + asyncio.create_task(get_web_instance().send_direct_message(chat_message)) | ||
| 105 | + # Human客户端通知改为日志记录(避免重复通知当前服务) | ||
| 106 | + util.log(1, f"FunASR识别结果[{self.username}]: {self.finalResults}") | ||
| 100 | except Exception as e: | 107 | except Exception as e: |
| 101 | print(e) | 108 | print(e) |
| 102 | 109 |
| @@ -20,25 +20,42 @@ args = parser.parse_args() | @@ -20,25 +20,42 @@ args = parser.parse_args() | ||
| 20 | 20 | ||
| 21 | # 初始化模型 | 21 | # 初始化模型 |
| 22 | print("model loading") | 22 | print("model loading") |
| 23 | -asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4", | 23 | +try: |
| 24 | + asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4", | ||
| 24 | vad_model="fsmn-vad", vad_model_revision="v2.0.4", | 25 | vad_model="fsmn-vad", vad_model_revision="v2.0.4", |
| 25 | punc_model="ct-punc-c", punc_model_revision="v2.0.4", | 26 | punc_model="ct-punc-c", punc_model_revision="v2.0.4", |
| 26 | device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True) | 27 | device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True) |
| 27 | # ,disable_update=True | 28 | # ,disable_update=True |
| 28 | -print("model loaded") | 29 | + print("model loaded") |
| 30 | +except Exception as e: | ||
| 31 | + print(f"模型加载失败: {e}") | ||
| 32 | + import traceback | ||
| 33 | + traceback.print_exc() | ||
| 34 | + exit(1) | ||
| 29 | websocket_users = {} | 35 | websocket_users = {} |
| 30 | task_queue = asyncio.Queue() | 36 | task_queue = asyncio.Queue() |
| 37 | +# 分块会话管理 | ||
| 38 | +chunk_sessions = {} # {user_id: {filename, chunks, total_chunks, received_chunks, temp_file}} | ||
| 31 | 39 | ||
| 32 | async def ws_serve(websocket, path): | 40 | async def ws_serve(websocket, path): |
| 33 | - global websocket_users | 41 | + global websocket_users, chunk_sessions |
| 34 | user_id = id(websocket) | 42 | user_id = id(websocket) |
| 35 | websocket_users[user_id] = websocket | 43 | websocket_users[user_id] = websocket |
| 36 | try: | 44 | try: |
| 37 | async for message in websocket: | 45 | async for message in websocket: |
| 38 | if isinstance(message, str): | 46 | if isinstance(message, str): |
| 39 | data = json.loads(message) | 47 | data = json.loads(message) |
| 40 | - if 'url' in data: | ||
| 41 | - await task_queue.put((websocket, data['url'])) | 48 | + |
| 49 | + # 处理分块协议 | ||
| 50 | + if 'type' in data: | ||
| 51 | + await handle_chunked_protocol(websocket, data, user_id) | ||
| 52 | + # 处理传统协议 | ||
| 53 | + elif 'url' in data: | ||
| 54 | + # 处理文件URL | ||
| 55 | + await task_queue.put((websocket, data['url'], 'url')) | ||
| 56 | + elif 'audio_data' in data: | ||
| 57 | + # 处理音频数据 | ||
| 58 | + await task_queue.put((websocket, data, 'audio_data')) | ||
| 42 | except websockets.exceptions.ConnectionClosed as e: | 59 | except websockets.exceptions.ConnectionClosed as e: |
| 43 | logger.info(f"Connection closed: {e.reason}") | 60 | logger.info(f"Connection closed: {e.reason}") |
| 44 | except Exception as e: | 61 | except Exception as e: |
| @@ -47,14 +64,28 @@ async def ws_serve(websocket, path): | @@ -47,14 +64,28 @@ async def ws_serve(websocket, path): | ||
| 47 | logger.info(f"Cleaning up connection for user {user_id}") | 64 | logger.info(f"Cleaning up connection for user {user_id}") |
| 48 | if user_id in websocket_users: | 65 | if user_id in websocket_users: |
| 49 | del websocket_users[user_id] | 66 | del websocket_users[user_id] |
| 67 | + # 清理分块会话 | ||
| 68 | + if user_id in chunk_sessions: | ||
| 69 | + await cleanup_chunk_session(user_id) | ||
| 50 | await websocket.close() | 70 | await websocket.close() |
| 51 | logger.info("WebSocket closed") | 71 | logger.info("WebSocket closed") |
| 52 | 72 | ||
| 53 | async def worker(): | 73 | async def worker(): |
| 54 | while True: | 74 | while True: |
| 55 | - websocket, url = await task_queue.get() | 75 | + task_data = await task_queue.get() |
| 76 | + websocket = task_data[0] | ||
| 77 | + | ||
| 56 | if websocket.open: | 78 | if websocket.open: |
| 57 | - await process_wav_file(websocket, url) | 79 | + if len(task_data) == 3: # 新格式: (websocket, data, type) |
| 80 | + data, data_type = task_data[1], task_data[2] | ||
| 81 | + if data_type == 'url': | ||
| 82 | + await process_wav_file(websocket, data) | ||
| 83 | + elif data_type == 'audio_data': | ||
| 84 | + await process_audio_data(websocket, data) | ||
| 85 | + elif data_type == 'chunked_audio': | ||
| 86 | + await process_chunked_audio(websocket, data) | ||
| 87 | + else: # 兼容旧格式: (websocket, url) | ||
| 88 | + await process_wav_file(websocket, task_data[1]) | ||
| 58 | else: | 89 | else: |
| 59 | logger.info("WebSocket connection is already closed when trying to process file") | 90 | logger.info("WebSocket connection is already closed when trying to process file") |
| 60 | task_queue.task_done() | 91 | task_queue.task_done() |
| @@ -77,8 +108,226 @@ async def process_wav_file(websocket, url): | @@ -77,8 +108,226 @@ async def process_wav_file(websocket, url): | ||
| 77 | except Exception as e: | 108 | except Exception as e: |
| 78 | print(f"Error during model.generate: {e}") | 109 | print(f"Error during model.generate: {e}") |
| 79 | finally: | 110 | finally: |
| 80 | - if os.path.exists(wav_path): | ||
| 81 | - os.remove(wav_path) | 111 | + # 注释掉文件删除操作,保留缓存文件用于测试 |
| 112 | + # if os.path.exists(wav_path): | ||
| 113 | + # os.remove(wav_path) | ||
| 114 | + print(f"保留音频文件用于测试: {wav_path}") | ||
| 115 | + | ||
| 116 | +async def handle_chunked_protocol(websocket, data, user_id): | ||
| 117 | + """处理分块协议消息""" | ||
| 118 | + global chunk_sessions | ||
| 119 | + | ||
| 120 | + try: | ||
| 121 | + msg_type = data.get('type') | ||
| 122 | + filename = data.get('filename', 'unknown.wav') | ||
| 123 | + | ||
| 124 | + if msg_type == 'audio_start': | ||
| 125 | + # 开始新的分块会话 | ||
| 126 | + total_chunks = data.get('total_chunks', 0) | ||
| 127 | + total_size = data.get('total_size', 0) | ||
| 128 | + | ||
| 129 | + print(f"开始接收分块音频: {filename}, 总分块数: {total_chunks}, 总大小: {total_size} bytes") | ||
| 130 | + | ||
| 131 | + # 创建临时文件 | ||
| 132 | + import tempfile | ||
| 133 | + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav') | ||
| 134 | + | ||
| 135 | + chunk_sessions[user_id] = { | ||
| 136 | + 'filename': filename, | ||
| 137 | + 'total_chunks': total_chunks, | ||
| 138 | + 'total_size': total_size, | ||
| 139 | + 'received_chunks': 0, | ||
| 140 | + 'temp_file': temp_file, | ||
| 141 | + 'temp_path': temp_file.name, | ||
| 142 | + 'chunks_data': {} # {chunk_index: chunk_data} | ||
| 143 | + } | ||
| 144 | + | ||
| 145 | + await websocket.send(json.dumps({"status": "ready", "message": f"准备接收 {total_chunks} 个分块"})) | ||
| 146 | + | ||
| 147 | + elif msg_type == 'audio_chunk': | ||
| 148 | + # 接收音频分块 | ||
| 149 | + if user_id not in chunk_sessions: | ||
| 150 | + await websocket.send(json.dumps({"error": "未找到分块会话,请先发送audio_start"})) | ||
| 151 | + return | ||
| 152 | + | ||
| 153 | + session = chunk_sessions[user_id] | ||
| 154 | + chunk_index = data.get('chunk_index', -1) | ||
| 155 | + chunk_data = data.get('chunk_data', '') | ||
| 156 | + is_last = data.get('is_last', False) | ||
| 157 | + | ||
| 158 | + if chunk_index >= 0 and chunk_data: | ||
| 159 | + # 解码并存储分块数据 | ||
| 160 | + import base64 | ||
| 161 | + chunk_bytes = base64.b64decode(chunk_data) | ||
| 162 | + session['chunks_data'][chunk_index] = chunk_bytes | ||
| 163 | + session['received_chunks'] += 1 | ||
| 164 | + | ||
| 165 | + # 进度反馈 | ||
| 166 | + progress = (session['received_chunks'] / session['total_chunks']) * 100 | ||
| 167 | + if session['received_chunks'] % 10 == 0 or is_last: | ||
| 168 | + print(f"接收进度: {progress:.1f}% ({session['received_chunks']}/{session['total_chunks']})") | ||
| 169 | + | ||
| 170 | + elif msg_type == 'audio_end': | ||
| 171 | + # 完成分块接收,重组音频 | ||
| 172 | + if user_id not in chunk_sessions: | ||
| 173 | + await websocket.send(json.dumps({"error": "未找到分块会话"})) | ||
| 174 | + return | ||
| 175 | + | ||
| 176 | + session = chunk_sessions[user_id] | ||
| 177 | + | ||
| 178 | + # 检查是否接收完整 | ||
| 179 | + if session['received_chunks'] != session['total_chunks']: | ||
| 180 | + await websocket.send(json.dumps({ | ||
| 181 | + "error": f"分块不完整: 期望{session['total_chunks']}, 实际{session['received_chunks']}" | ||
| 182 | + })) | ||
| 183 | + await cleanup_chunk_session(user_id) | ||
| 184 | + return | ||
| 185 | + | ||
| 186 | + # 按顺序重组音频数据 | ||
| 187 | + print(f"重组音频文件: {session['filename']}") | ||
| 188 | + with open(session['temp_path'], 'wb') as f: | ||
| 189 | + for i in range(session['total_chunks']): | ||
| 190 | + if i in session['chunks_data']: | ||
| 191 | + f.write(session['chunks_data'][i]) | ||
| 192 | + else: | ||
| 193 | + print(f"警告: 分块 {i} 缺失") | ||
| 194 | + | ||
| 195 | + # 提交到处理队列 | ||
| 196 | + reconstructed_data = { | ||
| 197 | + 'audio_file_path': session['temp_path'], | ||
| 198 | + 'filename': session['filename'] | ||
| 199 | + } | ||
| 200 | + await task_queue.put((websocket, reconstructed_data, 'chunked_audio')) | ||
| 201 | + | ||
| 202 | + # 清理会话(保留临时文件给处理函数) | ||
| 203 | + del chunk_sessions[user_id] | ||
| 204 | + print(f"分块音频重组完成: {session['filename']}") | ||
| 205 | + | ||
| 206 | + except Exception as e: | ||
| 207 | + print(f"处理分块协议时出错: {e}") | ||
| 208 | + await websocket.send(json.dumps({"error": f"分块处理错误: {str(e)}"})) | ||
| 209 | + if user_id in chunk_sessions: | ||
| 210 | + await cleanup_chunk_session(user_id) | ||
| 211 | + | ||
| 212 | +async def cleanup_chunk_session(user_id): | ||
| 213 | + """清理分块会话""" | ||
| 214 | + global chunk_sessions | ||
| 215 | + | ||
| 216 | + if user_id in chunk_sessions: | ||
| 217 | + session = chunk_sessions[user_id] | ||
| 218 | + try: | ||
| 219 | + # 关闭并删除临时文件 | ||
| 220 | + if 'temp_file' in session: | ||
| 221 | + session['temp_file'].close() | ||
| 222 | + if 'temp_path' in session and os.path.exists(session['temp_path']): | ||
| 223 | + os.remove(session['temp_path']) | ||
| 224 | + print(f"清理临时文件: {session['temp_path']}") | ||
| 225 | + except Exception as e: | ||
| 226 | + print(f"清理分块会话时出错: {e}") | ||
| 227 | + finally: | ||
| 228 | + del chunk_sessions[user_id] | ||
| 229 | + | ||
| 230 | +async def process_chunked_audio(websocket, data): | ||
| 231 | + """处理分块重组后的音频文件""" | ||
| 232 | + try: | ||
| 233 | + audio_file_path = data.get('audio_file_path') | ||
| 234 | + filename = data.get('filename', 'chunked_audio.wav') | ||
| 235 | + | ||
| 236 | + if not audio_file_path or not os.path.exists(audio_file_path): | ||
| 237 | + await websocket.send(json.dumps({"error": "重组音频文件不存在"})) | ||
| 238 | + return | ||
| 239 | + | ||
| 240 | + print(f"处理分块重组音频: {filename}, 文件路径: {audio_file_path}") | ||
| 241 | + | ||
| 242 | + # 热词配置 | ||
| 243 | + param_dict = {"sentence_timestamp": False} | ||
| 244 | + try: | ||
| 245 | + with open("data/hotword.txt", "r", encoding="utf-8") as f: | ||
| 246 | + lines = f.readlines() | ||
| 247 | + lines = [line.strip() for line in lines] | ||
| 248 | + hotword = " ".join(lines) | ||
| 249 | + print(f"热词:{hotword}") | ||
| 250 | + param_dict["hotword"] = hotword | ||
| 251 | + except FileNotFoundError: | ||
| 252 | + print("热词文件不存在,跳过热词配置") | ||
| 253 | + | ||
| 254 | + # 进行语音识别 | ||
| 255 | + res = asr_model.generate(input=audio_file_path, is_final=True, **param_dict) | ||
| 256 | + if res and websocket.open: | ||
| 257 | + if 'text' in res[0]: | ||
| 258 | + result_text = res[0]['text'] | ||
| 259 | + print(f"分块音频识别结果: {result_text}") | ||
| 260 | + await websocket.send(result_text) | ||
| 261 | + else: | ||
| 262 | + await websocket.send("识别失败:无法获取文本结果") | ||
| 263 | + | ||
| 264 | + except Exception as e: | ||
| 265 | + print(f"处理分块音频时出错: {e}") | ||
| 266 | + if websocket.open: | ||
| 267 | + await websocket.send(f"分块音频识别错误: {str(e)}") | ||
| 268 | + finally: | ||
| 269 | + # 注释掉临时文件删除操作,保留用于测试 | ||
| 270 | + # if 'audio_file_path' in locals() and os.path.exists(audio_file_path): | ||
| 271 | + # os.remove(audio_file_path) | ||
| 272 | + if 'audio_file_path' in locals(): | ||
| 273 | + print(f"保留分块重组音频文件用于测试: {audio_file_path}") | ||
| 274 | + | ||
| 275 | +async def process_audio_data(websocket, data): | ||
| 276 | + """处理音频数据""" | ||
| 277 | + import base64 | ||
| 278 | + import tempfile | ||
| 279 | + | ||
| 280 | + try: | ||
| 281 | + # 获取音频数据 | ||
| 282 | + audio_data = data.get('audio_data') | ||
| 283 | + filename = data.get('filename', 'audio.wav') | ||
| 284 | + | ||
| 285 | + if not audio_data: | ||
| 286 | + await websocket.send(json.dumps({"error": "No audio data provided"})) | ||
| 287 | + return | ||
| 288 | + | ||
| 289 | + # 解码Base64音频数据 | ||
| 290 | + audio_bytes = base64.b64decode(audio_data) | ||
| 291 | + | ||
| 292 | + # 创建临时文件 | ||
| 293 | + with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file: | ||
| 294 | + temp_file.write(audio_bytes) | ||
| 295 | + temp_path = temp_file.name | ||
| 296 | + | ||
| 297 | + print(f"处理音频文件: {filename}, 临时路径: {temp_path}") | ||
| 298 | + | ||
| 299 | + # 热词配置 | ||
| 300 | + param_dict = {"sentence_timestamp": False} | ||
| 301 | + try: | ||
| 302 | + with open("data/hotword.txt", "r", encoding="utf-8") as f: | ||
| 303 | + lines = f.readlines() | ||
| 304 | + lines = [line.strip() for line in lines] | ||
| 305 | + hotword = " ".join(lines) | ||
| 306 | + print(f"热词:{hotword}") | ||
| 307 | + param_dict["hotword"] = hotword | ||
| 308 | + except FileNotFoundError: | ||
| 309 | + print("热词文件不存在,跳过热词配置") | ||
| 310 | + | ||
| 311 | + # 进行语音识别 | ||
| 312 | + res = asr_model.generate(input=temp_path, is_final=True, **param_dict) | ||
| 313 | + if res and websocket.open: | ||
| 314 | + if 'text' in res[0]: | ||
| 315 | + result_text = res[0]['text'] | ||
| 316 | + print(f"识别结果: {result_text}") | ||
| 317 | + await websocket.send(result_text) | ||
| 318 | + else: | ||
| 319 | + await websocket.send("识别失败:无法获取文本结果") | ||
| 320 | + | ||
| 321 | + except Exception as e: | ||
| 322 | + print(f"处理音频数据时出错: {e}") | ||
| 323 | + if websocket.open: | ||
| 324 | + await websocket.send(f"识别错误: {str(e)}") | ||
| 325 | + finally: | ||
| 326 | + # 注释掉临时文件删除操作,保留用于测试 | ||
| 327 | + # if 'temp_path' in locals() and os.path.exists(temp_path): | ||
| 328 | + # os.remove(temp_path) | ||
| 329 | + if 'temp_path' in locals(): | ||
| 330 | + print(f"保留临时音频文件用于测试: {temp_path}") | ||
| 82 | 331 | ||
| 83 | async def main(): | 332 | async def main(): |
| 84 | server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10) | 333 | server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10) |
| @@ -87,6 +336,7 @@ async def main(): | @@ -87,6 +336,7 @@ async def main(): | ||
| 87 | try: | 336 | try: |
| 88 | # 保持服务器运行,直到被手动中断 | 337 | # 保持服务器运行,直到被手动中断 |
| 89 | print(f"ASR服务器已启动,监听地址: {args.host}:{args.port}") | 338 | print(f"ASR服务器已启动,监听地址: {args.host}:{args.port}") |
| 339 | + print("注意:此版本已禁用文件自动删除功能,用于测试分析") | ||
| 90 | await asyncio.Future() # 永久等待,直到程序被中断 | 340 | await asyncio.Future() # 永久等待,直到程序被中断 |
| 91 | except asyncio.CancelledError: | 341 | except asyncio.CancelledError: |
| 92 | print("服务器正在关闭...") | 342 | print("服务器正在关闭...") |
| @@ -101,4 +351,11 @@ async def main(): | @@ -101,4 +351,11 @@ async def main(): | ||
| 101 | await server.wait_closed() | 351 | await server.wait_closed() |
| 102 | 352 | ||
| 103 | # 使用 asyncio 运行主函数 | 353 | # 使用 asyncio 运行主函数 |
| 104 | -asyncio.run(main()) | 354 | +try: |
| 355 | + asyncio.run(main()) | ||
| 356 | +except KeyboardInterrupt: | ||
| 357 | + logging.info("服务器已关闭") | ||
| 358 | +except Exception as e: | ||
| 359 | + logging.error(f"服务器启动失败: {e}") | ||
| 360 | + import traceback | ||
| 361 | + traceback.print_exc() |
-
Please register or login to post a comment