冯杨

ASR WebSocket服务实现:1.FunASR本地方案 2.豆包

1.已实现音频文件的处理,包括小文件直接转换以及大文件分割识别。
2.豆包接入流式识别,但封装仍需要修改

Too many changes to show.

To preserve performance only 28 of 28+ files are displayed.

... ... @@ -19,4 +19,5 @@ workspace/log_ngp.txt
models/
*.log
dist
\ No newline at end of file
dist
.vscode/launch.json
... ...
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar7 --fullbody_height 1722 --fullbody_width 1080
\ No newline at end of file
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar10 --fullbody_height 1920 --fullbody_width 1080
\ No newline at end of file
... ...
... ... @@ -49,120 +49,42 @@ import gc
import weakref
import time
# 注意:server_recording_api模块已移除,相关功能已迁移到其他模块
# 导入新的统一WebSocket管理架构
from core.app_websocket_migration import (
get_app_websocket_migration,
initialize_app_websocket_migration,
setup_app_websocket_routes,
broadcast_message_to_session,
handle_asr_audio_data,
handle_start_asr_recognition,
handle_stop_asr_recognition,
send_asr_result,
send_normal_asr_result
)
app = Flask(__name__)
#sockets = Sockets(app)
nerfreals:Dict[int, BaseReal] = {} #sessionid:BaseReal
websocket_connections:Dict[int, weakref.WeakSet] = {} #sessionid:websocket_connections
# WebSocket连接管理已迁移到统一架构
# websocket_connections和asr_connections现在通过迁移层管理
# 全局事件循环引用,用于跨线程异步调用
main_event_loop = None
opt = None
model = None
avatar = None
# WebSocket迁移实例
websocket_migration = None
#####webrtc###############################
pcs = set()
# WebSocket消息推送函数
async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "数字人回复", model_info: str = None, request_source: str = "页面"):
    """Push a ``chat_message`` payload to every open WebSocket of one session.

    Args:
        sessionid: Key into the module-level ``websocket_connections`` dict.
        message_type: Value stored under ``data.message_type``.
        content: Message text; the first 50 characters are echoed to the log.
        source: Value stored under ``data.source``.
        model_info: Value stored under ``data.model_info``.
        request_source: Value stored under ``data.request_source``.

    Returns:
        None. Unknown sessions are logged and skipped; per-connection send
        failures are logged and do not stop delivery to the remaining sockets.
    """
    tag = f'[SessionID:{sessionid}]'
    logger.info(f'{tag} 开始推送消息: {message_type}, source: {source}, content: {content[:50]}...')
    logger.info(f'{tag} 当前websocket_connections keys: {list(websocket_connections.keys())}')

    if sessionid not in websocket_connections:
        logger.warning(f'{tag} 会话不存在于websocket_connections中')
        return

    logger.info(f'{tag} 找到会话,连接数量: {len(websocket_connections[sessionid])}')

    payload_fields = {
        "sessionid": sessionid,
        "message_type": message_type,
        "content": content,
        "source": source,
        "model_info": model_info,
        "request_source": request_source,
        "timestamp": time.time(),
    }
    message = {"type": "chat_message", "data": payload_fields}

    # Snapshot the set so iteration is not affected if members disappear.
    connections = list(websocket_connections[sessionid])
    logger.info(f'{tag} 准备向{len(connections)}个连接发送消息')

    for position, ws in enumerate(connections, start=1):
        try:
            is_closed = ws.closed
            logger.info(f'{tag} 检查连接{position}: closed={is_closed}')
            if is_closed:
                logger.warning(f'{tag} 连接{position}已关闭,跳过发送')
                continue
            # Serialise inside the try so an encoding failure is logged per connection.
            encoded = json.dumps(message)
            logger.info(f'{tag} 向连接{position}发送消息: {encoded}')
            await ws.send_str(encoded)
            logger.info(f'{tag} 连接{position}消息发送成功: {message_type} from {request_source}')
        except Exception as e:
            logger.error(f'{tag} 连接{position}发送失败: {e}')
# WebSocket消息推送函数已迁移到统一架构
# 通过 core.app_websocket_migration 模块提供兼容性接口
# WebSocket处理器
async def websocket_handler(request):
    """Serve one client WebSocket, handling ``login`` and ``ping`` text frames.

    A ``login`` frame adds this socket to ``websocket_connections`` under the
    supplied ``sessionid`` (default 0) and is answered with ``login_success``.
    A ``ping`` frame is answered with ``pong``. An ERROR frame ends the loop.

    Args:
        request: The incoming aiohttp request to upgrade.

    Returns:
        The ``web.WebSocketResponse`` used for the connection.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)

    sessionid = None
    logger.info('New WebSocket connection established')

    try:
        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                logger.error(f'WebSocket error: {ws.exception()}')
                break
            if msg.type != WSMsgType.TEXT:
                # Other frame types are ignored.
                continue

            try:
                data = json.loads(msg.data)
                frame_kind = data.get('type')

                if frame_kind == 'login':
                    sessionid = data.get('sessionid', 0)
                    logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(websocket_connections.keys())}')

                    # Create the per-session connection set on first login.
                    if sessionid not in websocket_connections:
                        websocket_connections[sessionid] = weakref.WeakSet()
                        logger.info(f'[SessionID:{sessionid}] 创建新的连接集合')

                    session_sockets = websocket_connections[sessionid]
                    session_sockets.add(ws)
                    logger.info(f'[SessionID:{sessionid}] 连接已添加,当前会话连接数: {len(session_sockets)}')
                    logger.info(f'[SessionID:{sessionid}] WebSocket client logged in')

                    login_ack = {
                        "type": "login_success",
                        "sessionid": sessionid,
                        "message": "WebSocket连接成功"
                    }
                    await ws.send_str(json.dumps(login_ack))

                elif frame_kind == 'ping':
                    # Heartbeat reply.
                    await ws.send_str(json.dumps({"type": "pong"}))

            except json.JSONDecodeError:
                logger.error('Invalid JSON received from WebSocket')
            except Exception as e:
                logger.error(f'Error processing WebSocket message: {e}')

    except Exception as e:
        logger.error(f'WebSocket connection error: {e}')
    finally:
        if sessionid is None:
            logger.info('WebSocket connection closed')
        else:
            logger.info(f'[SessionID:{sessionid}] WebSocket connection closed')

    return ws
# WebSocket处理器已迁移到统一架构
# 通过 core.app_websocket_migration 模块提供
def randN(N)->int:
'''生成长度为 N的随机数 '''
... ... @@ -456,41 +378,187 @@ async def interrupt_talk(request):
)
from pydub import AudioSegment
from io import BytesIO
async def humanaudio(request):
async def ensure_asr_connection(sessionid: int) -> bool:
    """Make sure a connected ASR client exists for ``sessionid``.

    If no client is stored on the migration instance, one is created through
    ``create_asr_connection``. If a stored client reports it is disconnected,
    its blocking ``connect`` is retried in the default executor; on failure the
    stale entry is removed.

    Args:
        sessionid: Session whose ASR client is looked up in
            ``migration.asr_connections``.

    Returns:
        True when a client is available and connected, False otherwise.
    """
    migration = get_app_websocket_migration()

    if sessionid not in migration.asr_connections:
        return await create_asr_connection(sessionid)

    asr_client = migration.asr_connections[sessionid]
    if asr_client.is_connected():
        return True

    logger.warning(f"[SessionID:{sessionid}] ASR连接已断开,尝试重连")
    try:
        # connect() is synchronous, so run it off the event loop thread.
        success = await asyncio.get_event_loop().run_in_executor(
            None, asr_client.connect
        )
    except Exception as e:
        logger.error(f"[SessionID:{sessionid}] ASR重连异常: {e}")
        # The connections live on the migration instance; the bare name
        # `asr_connections` used here before is not defined in this module.
        # pop() also tolerates the entry having been removed meanwhile.
        get_app_websocket_migration().asr_connections.pop(sessionid, None)
        return False

    if success:
        logger.info(f"[SessionID:{sessionid}] ASR重连成功")
        return True

    logger.error(f"[SessionID:{sessionid}] ASR重连失败")
    # Drop the dead client so the next call creates a fresh connection.
    get_app_websocket_migration().asr_connections.pop(sessionid, None)
    return False
async def create_asr_connection(sessionid: int) -> bool:
"""创建新的ASR连接"""
try:
params = await request.json()
sessionid = int(params.get('sessionid', 0))
fileobj = params.get('file_url')
# 获取音频文件数据
if isinstance(fileobj, str) and fileobj.startswith("http"):
async with aiohttp.ClientSession() as session:
async with session.get(fileobj) as response:
if response.status == 200:
filebytes = await response.read()
else:
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": "Error downloading file"})
)
# 根据 URL 后缀判断是否为 MP3 文件
is_mp3 = fileobj.lower().endswith('.mp3')
from funasr_asr_sync import FunASRSync
username = f'User_{sessionid}' # 修复大小写不一致:user_ -> User_
asr_client = FunASRSync(username)
# 设置结果回调
def on_asr_result(result):
if isinstance(result, str):
result_data = {
'text': result,
'is_final': True,
'confidence': 1.0
}
else:
result_data = result
# 线程安全地调度异步任务
try:
# 优先使用全局事件循环引用
if main_event_loop is not None and not main_event_loop.is_closed():
# 使用全局事件循环进行跨线程调用
asyncio.run_coroutine_threadsafe(
# send_asr_result(sessionid, result_data), main_event_loop
send_normal_asr_result(sessionid, result_data), main_event_loop
)
logger.debug(f"[SessionID:{sessionid}] 使用全局事件循环发送ASR结果")
else:
# 降级处理:尝试获取当前线程的事件循环
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.call_soon_threadsafe(
lambda: asyncio.create_task(send_normal_asr_result(sessionid, result_data))
)
else:
asyncio.create_task(send_normal_asr_result(sessionid, result_data))
except RuntimeError:
# 最终降级:仅记录日志
logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}")
logger.warning(f"[SessionID:{sessionid}] 无法发送ASR结果到客户端,事件循环不可用")
except Exception as e:
logger.error(f"[SessionID:{sessionid}] ASR结果处理异常: {e}")
# 至少记录识别结果
logger.info(f"[SessionID:{sessionid}] ASR识别结果: {result_data.get('text', 'N/A')}")
asr_client.set_result_callback(on_asr_result)
# 异步连接
success = await asyncio.get_event_loop().run_in_executor(
None, asr_client.connect
)
if success:
# 通过迁移实例存储ASR连接
migration = get_app_websocket_migration()
migration.asr_connections[sessionid] = asr_client
logger.info(f"[SessionID:{sessionid}] ASR连接创建成功")
return True
else:
filename = fileobj.filename
filebytes = fileobj.file.read()
is_mp3 = filename.lower().endswith('.mp3')
logger.error(f"[SessionID:{sessionid}] ASR连接创建失败")
return False
except Exception as e:
logger.error(f"[SessionID:{sessionid}] 创建ASR连接异常: {e}")
return False
if is_mp3:
audio = AudioSegment.from_file(BytesIO(filebytes), format="mp3")
out_io = BytesIO()
audio.export(out_io, format="wav")
filebytes = out_io.getvalue()
async def humanaudio(request):
try:
# 检查请求内容类型,支持FormData和JSON两种格式
content_type = request.headers.get('content-type', '')
# 处理FormData格式(文件上传)
reader = await request.multipart()
sessionid = 0
fileobj = None
# 默认启用语音本地服务
asr_service = "funasr"
# 读取FormData字段
async for field in reader:
if field.name == 'sessionid':
sessionid = int(await field.text())
logger.info(f'Parsed sessionid: {sessionid}')
elif field.name == 'audio':
fileobj = field
filename = field.filename
filebytes = await field.read()
# 输出文件大小信息
logger.info(f'Audio file content size: {len(filebytes)} bytes')
if not fileobj:
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": "No audio file provided"})
)
elif field.name == 'asr_service':
asr_service = (await field.text()).strip().lower()
# 根据文件名判断是否为 MP3 文件
is_mp3 = filename.lower().endswith('.mp3') if filename else False
# 处理MP3转WAV
if is_mp3:
try:
with BytesIO(filebytes) as audio_buffer:
audio = AudioSegment.from_file(audio_buffer, format="mp3")
out_io = BytesIO()
audio.export(out_io, format="wav")
filebytes = out_io.getvalue()
except Exception as e:
logger.error(f"[SessionID:{sessionid}] 音频处理失败: {e}")
raise
# 获取WebSocket迁移实例来访问连接信息
migration = get_app_websocket_migration()
active_sessions = migration.get_websocket_connections()
logger.info(f'[SessionID:{sessionid}] 收到登录请求,当前连接池: {list(active_sessions.keys())}')
# 验证sessionid是否存在
if sessionid not in nerfreals:
return web.Response(
content_type="application/json",
text=json.dumps({"code": -1, "msg": f"Session {sessionid} not found. Please establish WebRTC connection first."})
)
# 发送音频数据进行处理 数字人播报
nerfreals[sessionid].put_audio_file(filebytes)
# ---------- ASR 分流 ----------
if asr_service == 'funasr':
await handle_funasr(sessionid, filebytes)
elif asr_service == 'doubao':
await handle_doubao(sessionid, filebytes)
else:
logger.warning(f'[SessionID:{sessionid}] 未指定或未知 asr_service,跳过 ASR')
# 通过迁移实例检查ASR连接状态
migration = get_app_websocket_migration()
asr_enabled = sessionid in migration.asr_connections
return web.Response(
content_type="application/json",
text=json.dumps({"code": 0, "msg": "ok"})
text=json.dumps({"code": 0, "msg": "ok", "asr_enabled": asr_enabled})
)
except Exception as e:
... ... @@ -500,6 +568,66 @@ async def humanaudio(request):
text=json.dumps( {"code": -1, "msg": str(e)})
)
async def handle_funasr(sessionid: int, audio_bytes: bytes):
    """Send ``audio_bytes`` to the session's ASR client stored on the migration instance.

    Best effort: every failure is logged and swallowed so that ASR problems do
    not affect the caller's response.

    Args:
        sessionid: Session whose ASR client should receive the audio.
        audio_bytes: Audio payload passed to ``send_audio_data``.
    """
    try:
        if not await ensure_asr_connection(sessionid):
            logger.warning(f'[SessionID:{sessionid}] ASR连接不可用,跳过语音识别')
            return

        asr_client = get_app_websocket_migration().asr_connections[sessionid]

        if not hasattr(asr_client, 'send_audio_data'):
            logger.warning(f'[SessionID:{sessionid}] ASR客户端不支持send_audio_data方法')
            return

        asr_client.send_audio_data(audio_bytes)
        logger.info(f'[SessionID:{sessionid}] 音频数据已发送到ASR服务进行识别')
    except Exception as asr_error:
        # Deliberately not re-raised.
        logger.error(f'[SessionID:{sessionid}] ASR处理错误: {asr_error}')
# 导入 Doubao ASR 服务
from asr.doubao.service_factory import recognize_audio_data
import os
import json
async def handle_doubao(sessionid: int, audio_bytes: bytes):
    """Recognise ``audio_bytes`` through the Doubao ASR service.

    Credentials are read on every call from ``asr/doubao/config.json`` next to
    this module (``auth_config.app_key`` / ``auth_config.access_key``).

    Args:
        sessionid: Session id, used only for log prefixes.
        audio_bytes: Audio payload handed to ``recognize_audio_data``.

    Returns:
        Whatever ``recognize_audio_data`` returns.

    Raises:
        ValueError: If either credential is missing or empty.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    def _log_result(res):
        # The callback only logs the result.
        return logger.info(f"[SessionID:{sessionid}] Doubao 识别结果: {res}")

    try:
        logger.info(f"[SessionID:{sessionid}] 使用云端 Doubao 识别")

        module_dir = os.path.dirname(__file__)
        config_path = os.path.join(module_dir, 'asr', 'doubao', 'config.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        auth_config = config.get('auth_config', {})
        app_key = auth_config.get('app_key')
        access_key = auth_config.get('access_key')
        if not (app_key and access_key):
            raise ValueError("豆包ASR认证配置缺失:app_key 或 access_key 未配置")

        return await recognize_audio_data(
            audio_data=audio_bytes,
            app_key=app_key,
            access_key=access_key,
            streaming=True,
            result_callback=_log_result
        )
    except Exception as e:
        logger.error(f"[SessionID:{sessionid}] Doubao 错误: {e}")
        raise
async def set_audiotype(request):
try:
params = await request.json()
... ... @@ -787,6 +915,10 @@ if __name__ == '__main__':
rendthrd.start()
#############################################################################
# ASR处理函数已迁移到统一架构
# 通过 core.app_websocket_migration 模块提供
#############################################################################
appasync = web.Application()
appasync.on_shutdown.append(on_shutdown)
appasync.router.add_post("/offer", offer)
... ... @@ -796,8 +928,26 @@ if __name__ == '__main__':
appasync.router.add_post("/record", record)
appasync.router.add_post("/interrupt_talk", interrupt_talk)
appasync.router.add_post("/is_speaking", is_speaking)
appasync.router.add_get("/ws", websocket_handler)
# 初始化统一WebSocket管理架构
websocket_migration = get_app_websocket_migration()
# 注册WebSocket接口 - 使用新的统一架构
setup_app_websocket_routes(appasync)
# 异步初始化将在服务器启动时进行
async def init_websocket_migration():
await initialize_app_websocket_migration()
logger.info("WebSocket迁移架构初始化完成")
# 添加启动时初始化
appasync.on_startup.append(lambda app: init_websocket_migration())
appasync.router.add_static('/',path='web')
# 服务端录音WebSocket接口已集成到统一架构中
# 通过 /ws 路由和消息类型区分访问:wsa_register_web, wsa_register_human 等
logger.info("主应用路由配置完成,WebSocket接口已统一到 /ws 路由")
# Configure default CORS settings.
cors = aiohttp_cors.setup(appasync, defaults={
... ... @@ -819,8 +969,13 @@ if __name__ == '__main__':
logger.info('start http server; http://<serverip>:'+str(opt.listenport)+'/'+pagename)
logger.info('如果使用webrtc,推荐访问webrtc集成前端: http://<serverip>:'+str(opt.listenport)+'/dashboard.html')
def run_server(runner):
global main_event_loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 设置全局事件循环引用,用于跨线程异步调用
main_event_loop = loop
logger.info("全局事件循环引用已设置")
loop.run_until_complete(runner.setup())
site = web.TCPSite(runner, '0.0.0.0', opt.listenport)
loop.run_until_complete(site.start())
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
AIfeng/2025-06-30
配置管理工具模块
统一管理项目配置参数
"""
... ... @@ -24,6 +24,9 @@ class ConfigManager:
'local_asr_ip': '127.0.0.1',
'local_asr_port': 10197,
# 音频设备配置
'local_audio_ip': '127.0.0.1',
# 阿里云NLS配置
'key_ali_nls_key_id': '',
'key_ali_nls_key_secret': '',
... ... @@ -82,6 +85,7 @@ _config_manager = ConfigManager()
# 兼容原有的属性访问方式
local_asr_ip = _config_manager.local_asr_ip
local_asr_port = _config_manager.local_asr_port
local_audio_ip = _config_manager.local_audio_ip
key_ali_nls_key_id = _config_manager.key_ali_nls_key_id
key_ali_nls_key_secret = _config_manager.key_ali_nls_key_secret
key_ali_nls_app_key = _config_manager.key_ali_nls_app_key
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
Core模块初始化文件
AIfeng/2025-07-17 14:15:27
Core模块初始化
已迁移到统一WebSocket架构,移除旧的wsa_server依赖
"""
from .wsa_server import get_web_instance, get_instance
# 统一WebSocket架构导入
from .wsa_websocket_service import get_web_instance, get_instance
__all__ = ['get_web_instance', 'get_instance']
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
app.py WebSocket功能迁移脚本
将app.py中的WebSocket功能迁移到统一架构
"""
import asyncio
import json
import weakref
from typing import Dict, Any, Optional
from aiohttp import web
from logger import logger
from .websocket_router import get_websocket_router, get_websocket_compatibility_api
from .asr_websocket_service import get_asr_service
from .digital_human_websocket_service import get_digital_human_service
class AppWebSocketMigration:
    """Facade exposing the former app.py WebSocket helpers on top of the unified router.

    Each method forwards to the router, the ASR service or the compatibility
    API obtained in ``__init__``.
    """

    def __init__(self):
        self.router = get_websocket_router()
        self.compatibility_api = get_websocket_compatibility_api()
        self.asr_service = get_asr_service()
        self.digital_human_service = get_digital_human_service()

        # Compatibility attributes mirroring the old app.py globals.
        # Nothing in this class writes to websocket_connections.
        self.websocket_connections = {}
        # Filled by callers; this class only removes entries (cleanup_session).
        self.asr_connections = {}

    async def initialize(self):
        """Initialise the underlying router."""
        await self.router.initialize()
        logger.info('WebSocket迁移组件初始化完成')

    async def shutdown(self):
        """Shut down the underlying router."""
        await self.router.shutdown()
        logger.info('WebSocket迁移组件已关闭')

    def setup_routes(self, app: web.Application):
        """Register the unified handler on ``/ws`` and a forwarding route on ``/ws_legacy``."""
        self.router.setup_routes(app, '/ws')
        app.router.add_get('/ws_legacy', self._legacy_websocket_handler)

    async def _legacy_websocket_handler(self, request: web.Request):
        """Forward a ``/ws_legacy`` request to the router's WebSocket handler."""
        return await self.router.websocket_handler(request)

    async def broadcast_message_to_session(self, sessionid: int, message_type: str,
                                           content: str, source: str = "数字人回复",
                                           model_info: str = None, request_source: str = "页面"):
        """Send a ``chat_message`` to a session through the router.

        Args:
            sessionid: Target session; passed to the router as ``str(sessionid)``.
            message_type: Stored under ``message_type`` in the payload.
            content: Stored under ``content``.
            source: Stored under ``source``.
            model_info: Stored under ``model_info``.
            request_source: Stored under ``request_source``.

        Returns:
            The result of ``router.send_to_session``.
        """
        # Local import: this module does not import `time` at the top.
        import time

        message_data = {
            "sessionid": sessionid,
            "message_type": message_type,
            "content": content,
            "source": source,
            "model_info": model_info,
            "request_source": request_source,
            # Wall-clock epoch seconds, as the original app.py sent.
            # loop.time() is a monotonic clock and is not a usable timestamp.
            "timestamp": time.time()
        }
        return await self.router.send_to_session(str(sessionid), 'chat_message', message_data)

    async def handle_asr_audio_data(self, data: Dict[str, Any], sessionid: int, ws):
        """Forward audio data to the ASR service if ``ws`` has a registered session."""
        message_data = {
            'audio_data': data.get('audio_data'),
            'sessionid': sessionid
        }
        session = self.router.manager.get_session(ws)
        if session:
            await self.asr_service._handle_asr_audio_data(ws, message_data)

    async def handle_start_asr_recognition(self, sessionid: int, ws):
        """Ask the ASR service to start recognition if ``ws`` has a registered session."""
        session = self.router.manager.get_session(ws)
        if session:
            await self.asr_service._handle_start_asr_recognition(ws, {'sessionid': sessionid})

    async def handle_stop_asr_recognition(self, sessionid: int, ws):
        """Ask the ASR service to stop recognition if ``ws`` has a registered session."""
        session = self.router.manager.get_session(ws)
        if session:
            await self.asr_service._handle_stop_asr_recognition(ws, {'sessionid': sessionid})

    async def send_asr_result(self, sessionid: int, result: Dict[str, Any]):
        """Send an ``asr_result`` message built from ``text``, ``is_final`` and ``confidence``."""
        return await self.router.send_to_session(str(sessionid), 'asr_result', {
            "text": result.get('text', ''),
            "is_final": result.get('is_final', False),
            "confidence": result.get('confidence', 0.0)
        })

    async def send_normal_asr_result(self, sessionid: int, result: Dict[str, Any]):
        """Send ``result`` as-is via ``router.send_raw_to_session``; the caller decides its structure."""
        return await self.router.send_raw_to_session(str(sessionid), result)

    def get_websocket_connections(self):
        """Return a ``{session_id: websocket}`` dict built from the manager's sessions.

        Reads ``router.manager._sessions`` as a mapping of session id to a
        collection of session objects and takes the first session of each
        non-empty collection.
        """
        sessions_dict = self.router.manager._sessions
        result = {}
        for session_id, session_set in sessions_dict.items():
            if session_set:
                session = next(iter(session_set))
                result[session_id] = session.websocket
        return result

    def get_session_count(self):
        """Return the session count reported by the compatibility API."""
        return self.compatibility_api.get_session_count()

    async def cleanup_session(self, sessionid: int):
        """Drop the session's ASR connection entry and remove its session from the manager."""
        if sessionid in self.asr_connections:
            del self.asr_connections[sessionid]

        # NOTE(review): this loop treats manager._sessions as {websocket: session},
        # while get_websocket_connections treats it as {session_id: collection of
        # sessions}. Only one can match the manager - confirm and align.
        sessions = self.router.manager._sessions
        session_id_str = str(sessionid)
        for ws, session in list(sessions.items()):
            if session.session_id == session_id_str:
                await self.router.manager.remove_session(ws)
                break

    def get_migration_stats(self) -> Dict[str, Any]:
        """Collect statistics from the router, the ASR service and the digital-human service."""
        return {
            "router_stats": self.router.get_router_stats(),
            "asr_stats": self.asr_service.get_asr_stats(),
            "digital_human_stats": self.digital_human_service.get_digital_human_stats(),
            # NOTE(review): websocket_connections is never written by this class,
            # so this count reflects only what external code stores there.
            "compatibility_sessions": len(self.websocket_connections),
            "compatibility_asr_connections": len(self.asr_connections)
        }
# 全局迁移实例
_migration_instance = None
def get_app_websocket_migration() -> AppWebSocketMigration:
    """Return the module-wide ``AppWebSocketMigration``, creating it on first use."""
    global _migration_instance
    instance = _migration_instance
    if instance is None:
        instance = AppWebSocketMigration()
        _migration_instance = instance
    return instance
async def initialize_app_websocket_migration():
    """Initialise the shared migration instance and return it."""
    instance = get_app_websocket_migration()
    await instance.initialize()
    return instance
async def shutdown_app_websocket_migration():
    """Shut down the shared migration instance, if any, and clear the module reference."""
    global _migration_instance
    if not _migration_instance:
        return
    await _migration_instance.shutdown()
    _migration_instance = None
def setup_app_websocket_routes(app: web.Application):
    """Register the migration's WebSocket routes on ``app`` and return the migration instance."""
    instance = get_app_websocket_migration()
    instance.setup_routes(app)
    return instance
# 兼容性函数(保持与原app.py的接口一致)
async def broadcast_message_to_session(sessionid: int, message_type: str, content: str,
                                       source: str = "数字人回复", model_info: str = None,
                                       request_source: str = "页面"):
    """Module-level shortcut for ``AppWebSocketMigration.broadcast_message_to_session``."""
    return await get_app_websocket_migration().broadcast_message_to_session(
        sessionid=sessionid,
        message_type=message_type,
        content=content,
        source=source,
        model_info=model_info,
        request_source=request_source,
    )
async def handle_asr_audio_data(data: Dict[str, Any], sessionid: int, ws):
    """Module-level shortcut for ``AppWebSocketMigration.handle_asr_audio_data``."""
    return await get_app_websocket_migration().handle_asr_audio_data(
        data=data, sessionid=sessionid, ws=ws
    )
async def handle_start_asr_recognition(sessionid: int, ws):
    """Module-level shortcut for ``AppWebSocketMigration.handle_start_asr_recognition``."""
    return await get_app_websocket_migration().handle_start_asr_recognition(
        sessionid=sessionid, ws=ws
    )
async def handle_stop_asr_recognition(sessionid: int, ws):
    """Module-level shortcut for ``AppWebSocketMigration.handle_stop_asr_recognition``."""
    return await get_app_websocket_migration().handle_stop_asr_recognition(
        sessionid=sessionid, ws=ws
    )
async def send_asr_result(sessionid: int, result: Dict[str, Any]):
    """Module-level shortcut for ``AppWebSocketMigration.send_asr_result``."""
    return await get_app_websocket_migration().send_asr_result(
        sessionid=sessionid, result=result
    )
async def send_normal_asr_result(sessionid: int, result: Dict[str, Any]):
    """Module-level shortcut for ``AppWebSocketMigration.send_normal_asr_result``.

    Unlike ``send_asr_result``, the migration method forwards ``result`` to the
    router's raw send without reshaping it.
    """
    return await get_app_websocket_migration().send_normal_asr_result(
        sessionid=sessionid, result=result
    )
# 全局变量兼容性接口
def get_websocket_connections():
    """Return the ``websocket_connections`` dict attribute of the shared migration instance.

    This is the attribute, not the result of the instance method
    ``get_websocket_connections()``.
    """
    return get_app_websocket_migration().websocket_connections
def get_asr_connections():
    """Return the ``asr_connections`` dict of the shared migration instance."""
    return get_app_websocket_migration().asr_connections
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
ASR WebSocket服务实现
从app.py中抽离的ASR相关WebSocket功能
"""
import asyncio
import json
import weakref
from typing import Dict, Any, Optional
from aiohttp import web
from logger import logger
from .websocket_service_base import WebSocketServiceBase
from .unified_websocket_manager import WebSocketSession
class ASRWebSocketService(WebSocketServiceBase):
    """WebSocket service for login, heartbeat and ASR start/stop/audio messages.

    One ASR connection object is kept per session id in ``asr_connections``.
    """

    def __init__(self):
        super().__init__("asr_service")

        # session id -> ASR connection object
        self.asr_connections: Dict[str, Any] = {}
        # Value returned by add_background_task for the heartbeat monitor
        self._heartbeat_task = None

    @staticmethod
    async def _close_asr_connection(asr_conn):
        """Await ``asr_conn.close()`` if the object has a ``close`` attribute."""
        if hasattr(asr_conn, 'close'):
            await asr_conn.close()

    async def _register_message_handlers(self):
        """Register this service's handlers with the manager, keyed by message type."""
        self.manager.register_message_handler('login', self._handle_login)
        self.manager.register_message_handler('heartbeat', self._handle_heartbeat)
        self.manager.register_message_handler('asr_audio_data', self._handle_asr_audio_data)
        self.manager.register_message_handler('start_asr_recognition', self._handle_start_asr_recognition)
        self.manager.register_message_handler('stop_asr_recognition', self._handle_stop_asr_recognition)

    async def _start_background_tasks(self):
        """Start the heartbeat monitor as a background task."""
        self._heartbeat_task = self.add_background_task(self._heartbeat_monitor())

    async def _cleanup(self):
        """Close every stored ASR connection, then empty the registry."""
        for session_id, asr_conn in list(self.asr_connections.items()):
            try:
                await self._close_asr_connection(asr_conn)
            except Exception as e:
                logger.error(f'关闭ASR连接失败 {session_id}: {e}')

        self.asr_connections.clear()

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Close and forget the ASR connection of a session that disconnected."""
        await super()._on_session_disconnected(session)

        if session.session_id in self.asr_connections:
            asr_conn = self.asr_connections.pop(session.session_id)
            try:
                await self._close_asr_connection(asr_conn)
                logger.info(f'已清理ASR连接: {session.session_id}')
            except Exception as e:
                logger.error(f'清理ASR连接失败 {session.session_id}: {e}')

    async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Bind the socket's session to the ``sessionid`` from the login message.

        Replies with ``login_response``: ``success`` once the id is applied,
        ``error`` when ``sessionid`` is absent or an empty string.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = data.get('sessionid')
        # Test for absence explicitly: 0 is a legitimate id (app.py uses it as
        # the default) and a plain truthiness check would reject it.
        if session_id is None or session_id == '':
            await session.send_message({
                "type": "login_response",
                "data": {
                    "status": "error",
                    "message": "缺少sessionid"
                }
            })
            return

        # NOTE(review): the id is stored as received (may be an int from JSON),
        # while AppWebSocketMigration addresses sessions with str(sessionid) -
        # confirm the manager normalises the key type.
        old_session_id = session.session_id
        session.session_id = session_id

        # Keep the manager's own session-id mapping in step.
        self.manager._update_session_id(websocket, old_session_id, session_id)

        await session.send_message({
            "type": "login_response",
            "data": {
                "status": "success",
                "sessionid": session_id,
                "message": "登录成功"
            }
        })

        logger.info(f'用户登录成功: {session_id}')

    async def _handle_heartbeat(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Refresh the session's heartbeat time and acknowledge."""
        session = self.manager.get_session(websocket)
        if session:
            session.update_last_heartbeat()
            await session.send_message({
                "type": "heartbeat_response",
                "data": {"status": "ok"}
            })

    async def _handle_asr_audio_data(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Forward ``data['audio_data']`` to the session's ASR connection.

        Sends an ``error`` message when the audio is missing, when no ASR
        connection exists for the session, or when forwarding fails.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id
        audio_data = data.get('audio_data')

        if not audio_data:
            await session.send_message({
                "type": "error",
                "data": {"message": "缺少音频数据"}
            })
            return

        # A connection must already exist (see _handle_start_asr_recognition);
        # none is created here.
        asr_conn = self.asr_connections.get(session_id)
        if not asr_conn:
            logger.warning(f'ASR连接不存在: {session_id}')
            await session.send_message({
                "type": "error",
                "data": {"message": "ASR连接未建立"}
            })
            return

        try:
            await self._forward_audio_to_asr(asr_conn, audio_data)
        except Exception as e:
            logger.error(f'转发音频数据失败 {session_id}: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"音频处理失败: {str(e)}"}
            })

    async def _handle_start_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Create an ASR connection for the session and confirm to the client.

        A connection already stored for the session is replaced and closed, so
        repeated start requests do not leave the old one open.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id

        try:
            asr_conn = await self._create_asr_connection(session_id)
            if asr_conn:
                previous_conn = self.asr_connections.get(session_id)
                self.asr_connections[session_id] = asr_conn

                # Close the replaced connection only after the new one is in
                # place; a failure here must not fail the start request.
                if previous_conn is not None and previous_conn is not asr_conn:
                    try:
                        await self._close_asr_connection(previous_conn)
                    except Exception as close_error:
                        logger.error(f'关闭ASR连接失败 {session_id}: {close_error}')

                await session.send_message({
                    "type": "asr_recognition_started",
                    "data": {
                        "status": "success",
                        "message": "ASR识别已开始"
                    }
                })

                logger.info(f'ASR识别已开始: {session_id}')
            else:
                await session.send_message({
                    "type": "error",
                    "data": {"message": "创建ASR连接失败"}
                })

        except Exception as e:
            logger.error(f'开始ASR识别失败 {session_id}: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"开始识别失败: {str(e)}"}
            })

    async def _handle_stop_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Close and forget the session's ASR connection and confirm to the client.

        Also answers ``asr_recognition_stopped`` when nothing was running.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id

        if session_id in self.asr_connections:
            asr_conn = self.asr_connections.pop(session_id)
            try:
                await self._close_asr_connection(asr_conn)

                await session.send_message({
                    "type": "asr_recognition_stopped",
                    "data": {
                        "status": "success",
                        "message": "ASR识别已停止"
                    }
                })

                logger.info(f'ASR识别已停止: {session_id}')

            except Exception as e:
                logger.error(f'停止ASR识别失败 {session_id}: {e}')
                await session.send_message({
                    "type": "error",
                    "data": {"message": f"停止识别失败: {str(e)}"}
                })
        else:
            await session.send_message({
                "type": "asr_recognition_stopped",
                "data": {
                    "status": "success",
                    "message": "ASR识别未在运行"
                }
            })

    async def _create_asr_connection(self, session_id: str):
        """Create the ASR connection object for ``session_id``.

        TODO: replace the mock with a real backend connection (e.g. FunASR).

        Returns:
            A ``MockASRConnection`` wired to ``_on_asr_result``, or None if
            construction raises.
        """
        logger.info(f'创建ASR连接: {session_id}')

        try:
            asr_conn = MockASRConnection(session_id, self._on_asr_result)
            return asr_conn
        except Exception as e:
            logger.error(f'创建ASR连接失败 {session_id}: {e}')
            return None

    async def _forward_audio_to_asr(self, asr_conn, audio_data):
        """Pass ``audio_data`` to ``asr_conn.send_audio`` when that method exists."""
        if hasattr(asr_conn, 'send_audio'):
            await asr_conn.send_audio(audio_data)
        else:
            logger.warning('ASR连接不支持发送音频数据')

    async def _on_asr_result(self, session_id: str, result: Dict[str, Any]):
        """Result callback: broadcast ``result`` to the session as ``asr_result``."""
        try:
            await self.broadcast_to_session(session_id, 'asr_result', result)
            logger.debug(f'ASR结果已发送: {session_id}')
        except Exception as e:
            logger.error(f'发送ASR结果失败 {session_id}: {e}')

    async def _heartbeat_monitor(self):
        """Every 40 s, close the sessions the manager reports as expired (timeout=60).

        Runs until cancelled; other errors are logged and retried after 5 s.
        """
        while True:
            try:
                await asyncio.sleep(40)

                expired_sessions = self.manager.get_expired_sessions(timeout=60)
                for session in expired_sessions:
                    logger.info(f'会话心跳超时,断开连接: {session.session_id}')
                    await session.close()

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f'心跳监控异常: {e}')
                await asyncio.sleep(5)

    def get_asr_stats(self) -> Dict[str, Any]:
        """Return the number of stored ASR connections and their session ids."""
        return {
            "active_asr_connections": len(self.asr_connections),
            "asr_sessions": list(self.asr_connections.keys())
        }
class MockASRConnection:
    """Stand-in ASR connection for testing: answers every audio chunk with a fixed result."""

    def __init__(self, session_id: str, result_callback):
        self.is_closed = False
        self.result_callback = result_callback
        self.session_id = session_id

    async def send_audio(self, audio_data):
        """Wait 0.1 s, then pass a canned result to the callback.

        Does nothing once the connection is closed. ``audio_data`` is not inspected.
        """
        if not self.is_closed:
            # Simulated processing delay.
            await asyncio.sleep(0.1)

            fake_result = {
                "text": "模拟识别结果",
                "confidence": 0.95,
                "timestamp": asyncio.get_event_loop().time()
            }

            callback = self.result_callback
            if callback:
                await callback(self.session_id, fake_result)

    async def close(self):
        """Mark the connection closed."""
        self.is_closed = True
        logger.info(f'模拟ASR连接已关闭: {self.session_id}')
# 创建ASR服务实例
asr_service = ASRWebSocketService()
def get_asr_service() -> ASRWebSocketService:
    """Return the module-level ``asr_service`` instance."""
    service = asr_service
    return service
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
数字人WebSocket服务实现
处理数字人相关的WebSocket通信和状态管理
"""
import asyncio
import json
from typing import Dict, Any, Optional, List
from aiohttp import web
from logger import logger
from .websocket_service_base import WebSocketServiceBase
from .unified_websocket_manager import WebSocketSession
class DigitalHumanWebSocketService(WebSocketServiceBase):
"""数字人WebSocket服务"""
def __init__(self):
    """Set up empty registries for digital humans and their owning sessions."""
    super().__init__("digital_human_service")

    # session_id -> list of human_ids registered by that session
    self.session_humans: Dict[str, List[str]] = {}
    # session_id -> human_id
    self.human_sessions: Dict[str, str] = {}
    # human_id -> info dict
    self.digital_humans: Dict[str, Dict[str, Any]] = {}
async def _register_message_handlers(self):
    """Register the digital-human handlers with the manager, keyed by message type."""
    handlers = (
        ('register_digital_human', self._handle_register_digital_human),
        ('unregister_digital_human', self._handle_unregister_digital_human),
        ('digital_human_status', self._handle_digital_human_status),
        ('digital_human_action', self._handle_digital_human_action),
        ('digital_human_speak', self._handle_digital_human_speak),
        ('digital_human_emotion', self._handle_digital_human_emotion),
        ('get_digital_humans', self._handle_get_digital_humans),
    )
    for message_type, handler in handlers:
        self.manager.register_message_handler(message_type, handler)
async def _on_session_disconnected(self, session: WebSocketSession):
    """Unregister every digital human owned by a session that disconnected.

    Each removed human is announced to all sessions with a
    ``digital_human_offline`` broadcast.
    """
    await super()._on_session_disconnected(session)

    session_id = session.session_id

    for human_id in self.session_humans.pop(session_id, []):
        if human_id not in self.digital_humans:
            continue
        human_info = self.digital_humans.pop(human_id)
        logger.info(f'数字人已注销: {human_id} (会话断开)')

        offline_notice = {
            'human_id': human_id,
            'name': human_info.get('name', ''),
            'reason': 'session_disconnected'
        }
        await self.broadcast_to_all('digital_human_offline', offline_notice)

    # Drop the session -> human mapping as well.
    self.human_sessions.pop(session_id, None)
async def _handle_register_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人注册"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
human_name = data.get('name', '')
human_type = data.get('type', 'default')
capabilities = data.get('capabilities', [])
if not human_id:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id"}
})
return
# 检查数字人是否已存在
if human_id in self.digital_humans:
await session.send_message({
"type": "register_digital_human_response",
"data": {
"status": "error",
"message": f"数字人已存在: {human_id}"
}
})
return
# 注册数字人
human_info = {
'human_id': human_id,
'name': human_name,
'type': human_type,
'capabilities': capabilities,
'session_id': session.session_id,
'status': 'online',
'registered_at': asyncio.get_event_loop().time(),
'last_activity': asyncio.get_event_loop().time()
}
self.digital_humans[human_id] = human_info
self.human_sessions[session.session_id] = human_id
# 记录会话注册的数字人
if session.session_id not in self.session_humans:
self.session_humans[session.session_id] = []
self.session_humans[session.session_id].append(human_id)
# 发送注册成功响应
await session.send_message({
"type": "register_digital_human_response",
"data": {
"status": "success",
"human_id": human_id,
"message": "数字人注册成功"
}
})
# 通知其他会话新数字人上线
await self.broadcast_to_all('digital_human_online', {
'human_id': human_id,
'name': human_name,
'type': human_type,
'capabilities': capabilities
}, metadata={'exclude_session': session.session_id})
logger.info(f'数字人已注册: {human_id} ({human_name})')
async def _handle_unregister_digital_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人注销"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
if not human_id:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id"}
})
return
# 检查数字人是否存在且属于当前会话
if human_id not in self.digital_humans:
await session.send_message({
"type": "unregister_digital_human_response",
"data": {
"status": "error",
"message": f"数字人不存在: {human_id}"
}
})
return
human_info = self.digital_humans[human_id]
if human_info['session_id'] != session.session_id:
await session.send_message({
"type": "unregister_digital_human_response",
"data": {
"status": "error",
"message": "无权注销该数字人"
}
})
return
# 注销数字人
del self.digital_humans[human_id]
if session.session_id in self.human_sessions:
del self.human_sessions[session.session_id]
if session.session_id in self.session_humans:
self.session_humans[session.session_id].remove(human_id)
if not self.session_humans[session.session_id]:
del self.session_humans[session.session_id]
# 发送注销成功响应
await session.send_message({
"type": "unregister_digital_human_response",
"data": {
"status": "success",
"human_id": human_id,
"message": "数字人注销成功"
}
})
# 通知其他会话数字人已离线
await self.broadcast_to_all('digital_human_offline', {
'human_id': human_id,
'name': human_info.get('name', ''),
'reason': 'manual_unregister'
}, metadata={'exclude_session': session.session_id})
logger.info(f'数字人已注销: {human_id}')
async def _handle_digital_human_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人状态更新"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
status = data.get('status')
if not human_id or not status:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id或status"}
})
return
if human_id not in self.digital_humans:
await session.send_message({
"type": "error",
"data": {"message": f"数字人不存在: {human_id}"}
})
return
human_info = self.digital_humans[human_id]
if human_info['session_id'] != session.session_id:
await session.send_message({
"type": "error",
"data": {"message": "无权更新该数字人状态"}
})
return
# 更新状态
old_status = human_info['status']
human_info['status'] = status
human_info['last_activity'] = asyncio.get_event_loop().time()
# 广播状态变化
await self.broadcast_to_all('digital_human_status_changed', {
'human_id': human_id,
'old_status': old_status,
'new_status': status,
'name': human_info.get('name', '')
})
logger.info(f'数字人状态更新: {human_id} {old_status} -> {status}')
async def _handle_digital_human_action(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人动作指令"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
action = data.get('action')
params = data.get('params', {})
if not human_id or not action:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id或action"}
})
return
if human_id not in self.digital_humans:
await session.send_message({
"type": "error",
"data": {"message": f"数字人不存在: {human_id}"}
})
return
human_info = self.digital_humans[human_id]
# 更新活动时间
human_info['last_activity'] = asyncio.get_event_loop().time()
# 转发动作指令到数字人会话
target_session_id = human_info['session_id']
await self.broadcast_to_session(target_session_id, 'digital_human_action_command', {
'human_id': human_id,
'action': action,
'params': params,
'from_session': session.session_id
})
logger.info(f'数字人动作指令: {human_id} -> {action}')
async def _handle_digital_human_speak(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人说话指令"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
text = data.get('text')
voice_config = data.get('voice_config', {})
if not human_id or not text:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id或text"}
})
return
if human_id not in self.digital_humans:
await session.send_message({
"type": "error",
"data": {"message": f"数字人不存在: {human_id}"}
})
return
human_info = self.digital_humans[human_id]
# 更新活动时间
human_info['last_activity'] = asyncio.get_event_loop().time()
# 转发说话指令到数字人会话
target_session_id = human_info['session_id']
await self.broadcast_to_session(target_session_id, 'digital_human_speak_command', {
'human_id': human_id,
'text': text,
'voice_config': voice_config,
'from_session': session.session_id
})
# 广播数字人说话事件
await self.broadcast_to_all('digital_human_speaking', {
'human_id': human_id,
'name': human_info.get('name', ''),
'text': text
}, metadata={'exclude_session': target_session_id})
logger.info(f'数字人说话指令: {human_id} -> "{text}"')
async def _handle_digital_human_emotion(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理数字人情感状态"""
session = self.manager.get_session(websocket)
if not session:
return
human_id = data.get('human_id')
emotion = data.get('emotion')
intensity = data.get('intensity', 1.0)
if not human_id or not emotion:
await session.send_message({
"type": "error",
"data": {"message": "缺少human_id或emotion"}
})
return
if human_id not in self.digital_humans:
await session.send_message({
"type": "error",
"data": {"message": f"数字人不存在: {human_id}"}
})
return
human_info = self.digital_humans[human_id]
# 更新情感状态
human_info['emotion'] = emotion
human_info['emotion_intensity'] = intensity
human_info['last_activity'] = asyncio.get_event_loop().time()
# 广播情感变化
await self.broadcast_to_all('digital_human_emotion_changed', {
'human_id': human_id,
'name': human_info.get('name', ''),
'emotion': emotion,
'intensity': intensity
})
logger.info(f'数字人情感更新: {human_id} -> {emotion} ({intensity})')
async def _handle_get_digital_humans(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
"""处理获取数字人列表请求"""
session = self.manager.get_session(websocket)
if not session:
return
# 返回所有在线数字人信息
humans_list = []
for human_id, human_info in self.digital_humans.items():
humans_list.append({
'human_id': human_id,
'name': human_info.get('name', ''),
'type': human_info.get('type', 'default'),
'status': human_info.get('status', 'unknown'),
'capabilities': human_info.get('capabilities', []),
'emotion': human_info.get('emotion', 'neutral'),
'emotion_intensity': human_info.get('emotion_intensity', 1.0)
})
await session.send_message({
"type": "digital_humans_list",
"data": {
"humans": humans_list,
"total": len(humans_list)
}
})
def get_digital_human_stats(self) -> Dict[str, Any]:
"""获取数字人统计信息"""
online_count = len([h for h in self.digital_humans.values() if h.get('status') == 'online'])
return {
"total_digital_humans": len(self.digital_humans),
"online_digital_humans": online_count,
"active_sessions": len(self.session_humans),
"human_types": list(set(h.get('type', 'default') for h in self.digital_humans.values()))
}
async def send_to_digital_human(self, human_id: str, message_type: str, content: Any):
"""向指定数字人发送消息"""
if human_id not in self.digital_humans:
logger.warning(f'数字人不存在: {human_id}')
return False
human_info = self.digital_humans[human_id]
target_session_id = human_info['session_id']
return await self.broadcast_to_session(target_session_id, message_type, content)
async def broadcast_to_digital_humans(self, message_type: str, content: Any,
human_filter: Optional[callable] = None):
"""向数字人广播消息"""
sent_count = 0
for human_id, human_info in self.digital_humans.items():
if human_filter and not human_filter(human_info):
continue
target_session_id = human_info['session_id']
success = await self.broadcast_to_session(target_session_id, message_type, content)
if success:
sent_count += 1
return sent_count
# Module-level singleton of the digital human service, created at import time.
digital_human_service = DigitalHumanWebSocketService()


def get_digital_human_service() -> DigitalHumanWebSocketService:
    """Return the module-level digital human service instance."""
    return digital_human_service
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
统一WebSocket管理模块
提供统一的WebSocket连接管理、消息推送和事件处理功能
"""
import json
import time
import asyncio
import weakref
from typing import Dict, Any, Optional, Callable, List, Set
from threading import Lock
from aiohttp import web, WSMsgType
from logger import logger
class WebSocketSession:
    """A single WebSocket connection tagged with a session id.

    Two sessions compare equal only when they wrap the very same websocket
    object, and the hash is derived from that object's ``id``.
    """

    def __init__(self, session_id: str, websocket: web.WebSocketResponse):
        self.websocket = websocket
        self.session_id = session_id
        self.metadata = {}
        self.created_at = time.time()
        self.last_ping = time.time()

    def __eq__(self, other):
        """Compare by identity of the wrapped websocket."""
        return isinstance(other, WebSocketSession) and self.websocket is other.websocket

    def __hash__(self):
        """Hash on the ``id`` of the wrapped websocket."""
        return hash(id(self.websocket))

    def is_alive(self) -> bool:
        """Return True while the websocket has not been closed."""
        return not self.websocket.closed

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Serialize ``message`` as JSON and send it.

        Returns True on success and False when the connection is closed or any
        error occurs; errors are logged, never raised.
        """
        try:
            if self.is_alive():
                payload = json.dumps(message)
                await self.websocket.send_str(payload)
                return True
            return False
        except ConnectionResetError:
            logger.warning(f'[Session:{self.session_id}] 客户端连接已重置')
        except ConnectionAbortedError:
            logger.warning(f'[Session:{self.session_id}] 客户端连接已中止')
        except Exception as e:
            logger.error(f'[Session:{self.session_id}] 发送消息失败: {e}')
        return False

    def update_ping(self):
        """Record the current time as the last heartbeat."""
        self.last_ping = time.time()

    def set_metadata(self, key: str, value: Any):
        """Store ``value`` under ``key`` in the session metadata."""
        self.metadata[key] = value

    def get_metadata(self, key: str, default=None):
        """Return the metadata stored under ``key``, or ``default``."""
        return self.metadata.get(key, default)

    async def close(self):
        """Close the websocket if it is still open; failures are only logged."""
        try:
            if self.websocket.closed:
                return
            await self.websocket.close()
            logger.info(f'[Session:{self.session_id}] WebSocket连接已关闭')
        except ConnectionResetError:
            logger.warning(f'[Session:{self.session_id}] 连接已被远程主机重置,无需关闭')
        except ConnectionAbortedError:
            logger.warning(f'[Session:{self.session_id}] 连接已被中止,无需关闭')
        except Exception as e:
            logger.error(f'[Session:{self.session_id}] 关闭WebSocket连接失败: {e}')
class UnifiedWebSocketManager:
    """Central registry of WebSocket connections, message handlers and event handlers.

    Connections are indexed both by websocket object and by session id; one
    session id may own several connections. Incoming messages are dispatched
    on their ``type`` field to the registered handler.
    """

    def __init__(self):
        self._sessions: Dict[str, Set[WebSocketSession]] = {}  # session_id -> set of WebSocketSession
        self._websockets: Dict[web.WebSocketResponse, WebSocketSession] = {}  # websocket -> session
        self._message_handlers: Dict[str, Callable] = {}  # message type -> handler
        self._event_handlers: Dict[str, List[Callable]] = {}  # event type -> handlers
        # Guards _sessions/_websockets; it is never held across an await.
        self._lock = Lock()
        # Built-in message handlers
        self.register_message_handler('ping', self._handle_ping)
        self.register_message_handler('login', self._handle_login)

    @staticmethod
    def _normalize_session_id(session_id) -> str:
        """Return ``session_id`` as a string so lookups use one key type."""
        return session_id if isinstance(session_id, str) else str(session_id)

    def register_message_handler(self, message_type: str, handler: Callable):
        """Register (or replace) the handler for ``message_type``."""
        self._message_handlers[message_type] = handler
        logger.info(f'注册消息处理器: {message_type}')

    def register_event_handler(self, event_type: str, handler: Callable):
        """Append ``handler`` to the handlers called for ``event_type``."""
        self._event_handlers.setdefault(event_type, []).append(handler)
        logger.info(f'注册事件处理器: {event_type}')

    async def _emit_event(self, event_type: str, **kwargs):
        """Call every handler of ``event_type`` with ``kwargs``.

        Coroutine functions are awaited, plain callables are called directly.
        A failing handler is logged and does not stop the others.
        """
        for handler in self._event_handlers.get(event_type, ()):
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(**kwargs)
                else:
                    handler(**kwargs)
            except Exception as e:
                logger.error(f'事件处理器执行失败 {event_type}: {e}')

    def add_session(self, session_id: str, websocket: web.WebSocketResponse) -> WebSocketSession:
        """Track ``websocket`` under ``session_id`` and return its session.

        If the websocket is already tracked, the existing session is returned
        unchanged, even when it carries a different session id.
        """
        with self._lock:
            if websocket in self._websockets:
                existing_session = self._websockets[websocket]
                logger.warning(f'[Session:{session_id}] WebSocket连接已存在 (WebSocket={id(websocket)}, 原Session={existing_session.session_id})')
                return existing_session
            session = WebSocketSession(session_id, websocket)
            bucket = self._sessions.setdefault(session_id, set())
            # Compare the set size before and after to detect duplicate adds.
            before_count = len(bucket)
            bucket.add(session)
            after_count = len(bucket)
            self._websockets[websocket] = session
            logger.info(f'[Session:{session_id}] 添加WebSocket会话 (WebSocket={id(websocket)}), 连接数变化: {before_count} -> {after_count}')
            if before_count == after_count:
                logger.warning(f'[Session:{session_id}] 检测到可能的重复会话添加!Set大小未变化')
            return session

    def remove_session(self, websocket: web.WebSocketResponse):
        """Stop tracking ``websocket``; return its session, or None if unknown."""
        with self._lock:
            session = self._websockets.get(websocket)
            if session is None:
                return None
            session_id = session.session_id
            bucket = self._sessions.get(session_id)
            if bucket is not None:
                bucket.discard(session)
                if not bucket:  # last connection of this session id
                    del self._sessions[session_id]
            del self._websockets[websocket]
            logger.info(f'[Session:{session_id}] 移除WebSocket会话')
            return session

    def get_session(self, websocket: web.WebSocketResponse) -> Optional[WebSocketSession]:
        """Return the session tracked for ``websocket``, or None."""
        return self._websockets.get(websocket)

    def get_sessions_by_id(self, session_id: str) -> Set[WebSocketSession]:
        """Return a copy of the connections registered under ``session_id``.

        If the id as given is not found, the lookup is retried with the other
        representation (digit string -> int, int -> str).
        """
        with self._lock:
            sessions = self._sessions.get(session_id, set())
            if sessions:
                return sessions.copy()
            if isinstance(session_id, str) and session_id.isdigit():
                sessions = self._sessions.get(int(session_id), set())
                if sessions:
                    return sessions.copy()
            elif isinstance(session_id, int):
                sessions = self._sessions.get(str(session_id), set())
                if sessions:
                    return sessions.copy()
            return set()

    def _update_session_id(self, websocket: web.WebSocketResponse, old_session_id: str, new_session_id: str):
        """Move the session of ``websocket`` from ``old_session_id`` to ``new_session_id``.

        Returns True when the websocket is tracked, False otherwise.
        """
        with self._lock:
            session = self._websockets.get(websocket)
            if session is None:
                return False
            old_bucket = self._sessions.get(old_session_id)
            if old_bucket is not None:
                old_bucket.discard(session)
                if not old_bucket:
                    del self._sessions[old_session_id]
            session.session_id = new_session_id
            self._sessions.setdefault(new_session_id, set()).add(session)
            logger.info(f'[Session] 更新会话ID: {old_session_id} -> {new_session_id}')
            return True

    async def broadcast_raw_message_to_session(self, session_id: str, message: Dict,source: str = "原数据") -> int:
        """Send ``message`` unchanged to every connection of ``session_id``.

        ``source`` is accepted but not used by this method. Connections whose
        send fails are dropped. Returns the number of successful sends.
        """
        session_id = self._normalize_session_id(session_id)
        sessions = self.get_sessions_by_id(session_id)
        if not sessions:
            logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
            return 0
        # Verbose per-connection diagnostics
        logger.info(f'[Session:{session_id}] 开始广播消息,找到 {len(sessions)} 个连接')
        for i, session in enumerate(sessions):
            logger.info(f'[Session:{session_id}] 连接{i+1}: WebSocket={id(session.websocket)}, 创建时间={session.created_at}, 存活状态={session.is_alive()}')
        success_count = 0
        failed_sessions = []
        for i, session in enumerate(sessions):
            logger.info(f'[Session:{session_id}] 正在向连接{i+1}发送消息 (WebSocket={id(session.websocket)})')
            if await session.send_message(message):
                success_count += 1
                logger.info(f'[Session:{session_id}] 连接{i+1}发送成功')
            else:
                failed_sessions.append(session)
                logger.warning(f'[Session:{session_id}] 连接{i+1}发送失败')
        # NOTE(review): connections dropped here do not trigger
        # 'session_disconnected'; confirm whether services rely on that event.
        for session in failed_sessions:
            self.remove_session(session.websocket)
        logger.info(f'[Session:{session_id}] 广播原始消息完成: 成功{success_count}/总计{len(sessions)}, 失败{len(failed_sessions)}')
        return success_count

    async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
                                   source: str = "系统", metadata: Dict = None) -> int:
        """Wrap ``content`` in a message envelope and send it to ``session_id``.

        The envelope has ``type``, ``session_id``, ``content``, ``source`` and
        ``timestamp``; entries of ``metadata`` are merged in at the top level
        and override envelope keys of the same name. Connections whose send
        fails are dropped. Returns the number of successful sends.
        """
        session_id = self._normalize_session_id(session_id)
        sessions = self.get_sessions_by_id(session_id)
        if not sessions:
            logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
            return 0
        message = {
            "type": message_type,
            "session_id": session_id,
            "content": content,
            "source": source,
            "timestamp": time.time(),
            **(metadata or {})
        }
        success_count = 0
        failed_sessions = []
        for session in sessions:
            if await session.send_message(message):
                success_count += 1
            else:
                failed_sessions.append(session)
        # NOTE(review): as above, no 'session_disconnected' is emitted here.
        for session in failed_sessions:
            self.remove_session(session.websocket)
        logger.info(f'[Session:{session_id}] 广播消息成功: {success_count}/{len(sessions)}')
        return success_count

    async def broadcast_to_all(self, message_type: str, content: Any,
                               source: str = "系统", metadata: Dict = None) -> int:
        """Broadcast to every session id known when the call starts.

        Returns the total number of successful sends.
        """
        total_sent = 0
        # Snapshot the ids under the lock; sending happens outside it.
        with self._lock:
            session_ids = list(self._sessions.keys())
        for session_id in session_ids:
            total_sent += await self.broadcast_to_session(session_id, message_type, content, source, metadata)
        logger.info(f'全局广播消息完成,总发送数: {total_sent}')
        return total_sent

    def get_session_count(self) -> int:
        """Return the number of distinct session ids."""
        with self._lock:
            return len(self._sessions)

    def get_connection_count(self) -> int:
        """Return the number of tracked websocket connections."""
        with self._lock:
            return len(self._websockets)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return totals plus per-session connection details."""
        with self._lock:
            stats = {
                "total_sessions": len(self._sessions),
                "total_connections": len(self._websockets),
                "session_details": {}
            }
            for session_id, sessions in self._sessions.items():
                stats["session_details"][session_id] = {
                    "connection_count": len(sessions),
                    "connections": [
                        {
                            "created_at": session.created_at,
                            "last_ping": session.last_ping,
                            "is_alive": session.is_alive(),
                            "metadata": session.metadata
                        } for session in sessions
                    ]
                }
            return stats

    async def _handle_ping(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Handle ``ping``: refresh the heartbeat and answer with ``pong``."""
        session = self.get_session(websocket)
        if session:
            session.update_ping()
            await session.send_message({"type": "pong", "timestamp": time.time()})

    async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Handle ``login``: bind the websocket to the requested session id.

        The id comes from ``session_id`` or ``sessionid`` in ``data`` and
        defaults to the current Unix time in seconds. A websocket that is
        already tracked under another id is re-keyed, so that the id reported
        in ``login_success`` is the one broadcasts will actually reach.
        """
        session_id = data.get('session_id', data.get('sessionid', str(int(time.time()))))
        session_id = self._normalize_session_id(session_id)
        session = self.get_session(websocket)
        if session is None:
            session = self.add_session(session_id, websocket)
        elif session.session_id != session_id:
            self._update_session_id(websocket, session.session_id, session_id)
        await self._emit_event('session_connected', session=session, data=data)
        await session.send_message({
            "type": "login_success",
            "data": {
                "session_id": session_id,
                "message": "WebSocket连接成功",
                "timestamp": time.time()
            }
        })

    async def handle_websocket_message(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Dispatch ``data`` to the handler registered for its ``type``.

        Handler exceptions are logged and reported to the client as an
        ``error`` message; unknown types are only logged.
        """
        message_type = data.get('type')
        handler = self._message_handlers.get(message_type)
        if handler is None:
            logger.warning(f'未知消息类型: {message_type}')
            return
        try:
            await handler(websocket, data)
        except Exception as e:
            logger.error(f'消息处理器执行失败 {message_type}: {e}')
            session = self.get_session(websocket)
            if session:
                await session.send_message({
                    "type": "error",
                    "data": {
                        "message": f"消息处理失败: {str(e)}",
                        "original_type": message_type
                    }
                })

    async def websocket_handler(self, request) -> web.WebSocketResponse:
        """aiohttp handler: accept the connection and pump its text messages.

        The session is created by the ``login`` message, not on connect. On
        exit the session is removed and ``session_disconnected`` is emitted.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        logger.info('新的WebSocket连接建立')
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        await self.handle_websocket_message(ws, data)
                    except json.JSONDecodeError:
                        logger.error('收到无效的JSON数据')
                    except Exception as e:
                        logger.error(f'处理WebSocket消息时出错: {e}')
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f'WebSocket错误: {ws.exception()}')
                    break
                elif msg.type == WSMsgType.CLOSE:
                    logger.info('WebSocket连接正常关闭')
                    break
        except ConnectionResetError:
            logger.warning('WebSocket连接被远程主机重置')
        except ConnectionAbortedError:
            logger.warning('WebSocket连接被中止')
        except Exception as e:
            logger.error(f'WebSocket连接错误: {e}')
        finally:
            session = self.remove_session(ws)
            if session:
                await self._emit_event('session_disconnected', session=session)
            logger.info('WebSocket连接已关闭')
        return ws

    def get_expired_sessions(self, timeout: int = 60) -> List[WebSocketSession]:
        """Return sessions whose last heartbeat is more than ``timeout`` seconds old."""
        current_time = time.time()
        with self._lock:
            return [
                session for session in self._websockets.values()
                if current_time - session.last_ping > timeout
            ]

    async def cleanup_dead_connections(self):
        """Remove sessions whose websocket is closed; return how many were removed."""
        with self._lock:
            dead_websockets = [
                websocket for websocket, session in self._websockets.items()
                if not session.is_alive()
            ]
        # remove_session takes the (non-reentrant) lock itself, so it must be
        # called outside the block above.
        for websocket in dead_websockets:
            self.remove_session(websocket)
        if dead_websockets:
            logger.info(f'清理了 {len(dead_websockets)} 个死连接')
        return len(dead_websockets)

    async def shutdown(self):
        """Close every tracked connection and forget all sessions.

        Registered message and event handlers are kept. No
        ``session_disconnected`` event is emitted for the closed sessions.
        """
        with self._lock:
            sessions = list(self._websockets.values())
        for session in sessions:
            await session.close()  # logs and swallows its own errors
            self.remove_session(session.websocket)
        logger.info(f'统一WebSocket管理器已关闭,共关闭 {len(sessions)} 个连接')
# Process-wide UnifiedWebSocketManager instance, created at import time.
_unified_manager = UnifiedWebSocketManager()


def get_unified_manager() -> UnifiedWebSocketManager:
    """Return the process-wide UnifiedWebSocketManager instance."""
    return _unified_manager
# Compatibility shim that keeps the legacy module-level call signature.
async def broadcast_message_to_session(session_id: str, message_type: str, content: str,
                                       source: str = "数字人回复", model_info: str = None,
                                       request_source: str = "页面"):
    """Legacy entry point: broadcast a message to one session via the global manager.

    ``model_info`` and ``request_source`` are forwarded as metadata only when
    they are truthy. Returns the manager's ``broadcast_to_session`` result.
    """
    optional_fields = (
        ('model_info', model_info),
        ('request_source', request_source),
    )
    metadata = {key: value for key, value in optional_fields if value}
    return await _unified_manager.broadcast_to_session(
        session_id, message_type, content, source, metadata
    )
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
WebSocket路由管理器
统一管理所有WebSocket服务的路由和初始化
"""
import asyncio
import time
from typing import Dict, Any, Optional
from aiohttp import web, WSMsgType
import json
from logger import logger
from .unified_websocket_manager import get_unified_manager
from .websocket_service_base import get_service_registry
from .asr_websocket_service import get_asr_service
from .digital_human_websocket_service import get_digital_human_service
class WebSocketRouter:
    """Entry point that wires WebSocket services to the unified manager.

    It registers and initializes the services, exposes the aiohttp handler
    that feeds incoming messages to the manager, and offers convenience
    send/stat helpers.
    """

    def __init__(self):
        self.manager = get_unified_manager()
        self.service_registry = get_service_registry()
        self.is_initialized = False

    async def initialize(self):
        """Register and initialize all services; a no-op when already initialized."""
        if self.is_initialized:
            return
        logger.info('初始化WebSocket路由器...')
        await self._register_services()
        await self.service_registry.initialize_all()
        self.is_initialized = True
        logger.info('WebSocket路由器初始化完成')

    async def shutdown(self):
        """Shut down all services, then close the manager's connections."""
        if not self.is_initialized:
            return
        logger.info('关闭WebSocket路由器...')
        await self.service_registry.shutdown_all()
        # The manager may not provide a shutdown() hook; calling it blindly
        # would raise AttributeError and leave the router half shut down.
        manager_shutdown = getattr(self.manager, 'shutdown', None)
        if manager_shutdown is not None:
            await manager_shutdown()
        else:
            # An infinitely negative timeout makes every tracked session count
            # as expired, which yields a snapshot of all sessions.
            for session in self.manager.get_expired_sessions(timeout=float('-inf')):
                await session.close()
                self.manager.remove_session(session.websocket)
        self.is_initialized = False
        logger.info('WebSocket路由器已关闭')

    async def _register_services(self):
        """Add the ASR, digital human and WSA services to the registry."""
        logger.info('注册WebSocket服务...')
        asr_service = get_asr_service()
        self.service_registry.register_service(asr_service)
        digital_human_service = get_digital_human_service()
        self.service_registry.register_service(digital_human_service)
        # The WSA service is imported here, at call time, not at module import.
        from .wsa_websocket_service import WSAWebSocketService, initialize_wsa_service
        wsa_service = WSAWebSocketService(self.manager)
        self.service_registry.register_service(wsa_service)
        # Set up the WSA compatibility interface.
        initialize_wsa_service(wsa_service)
        logger.info(f'已注册 {len(self.service_registry.list_services())} 个WebSocket服务')

    async def websocket_handler(self, request: web.Request) -> web.WebSocketResponse:
        """aiohttp handler: track the connection and pump its text messages.

        The session id is taken from the ``X-Session-ID`` header and defaults
        to the current Unix time in seconds. When the connection ends the
        session is removed and ``session_disconnected`` is emitted so that
        services can release per-session state.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        # NOTE(review): the fallback id has one-second resolution, so two
        # header-less connections in the same second share a session id.
        session_id = request.headers.get('X-Session-ID', str(int(time.time())))
        session = self.manager.add_session(session_id, ws)
        logger.info(f'WebSocket连接建立: {session.session_id}')
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        await self._handle_message(ws, data)
                    except json.JSONDecodeError as e:
                        logger.error(f'JSON解析失败: {e}')
                        await session.send_message({
                            "type": "error",
                            "data": {"message": "消息格式错误"}
                        })
                    except Exception as e:
                        logger.error(f'消息处理失败: {e}')
                        await session.send_message({
                            "type": "error",
                            "data": {"message": f"处理失败: {str(e)}"}
                        })
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f'WebSocket错误: {ws.exception()}')
                    break
                elif msg.type == WSMsgType.CLOSE:
                    logger.info(f'WebSocket连接关闭: {session.session_id}')
                    break
        except ConnectionResetError:
            logger.warning(f'WebSocket连接被远程主机重置: {session.session_id}')
        except ConnectionAbortedError:
            logger.warning(f'WebSocket连接被中止: {session.session_id}')
        except Exception as e:
            logger.error(f'WebSocket处理异常: {e}')
        finally:
            # Services register their cleanup on 'session_disconnected', so
            # the event has to be emitted when this handler drops the session.
            removed = self.manager.remove_session(ws)
            if removed is not None:
                await self.manager._emit_event('session_disconnected', session=removed)
        return ws

    async def _handle_message(self, ws: web.WebSocketResponse, data: Dict[str, Any]):
        """Reject messages without a ``type``; hand everything else to the manager."""
        message_type = data.get('type')
        if not message_type:
            session = self.manager.get_session(ws)
            if session:
                await session.send_message({
                    "type": "error",
                    "data": {"message": "缺少消息类型"}
                })
            return
        await self.manager.handle_websocket_message(ws, data)

    def get_router_stats(self) -> Dict[str, Any]:
        """Return router state plus manager, registry and per-service statistics."""
        stats = {
            "initialized": self.is_initialized,
            "manager_stats": self.manager.get_session_stats(),
            "service_stats": self.service_registry.get_all_stats()
        }
        asr_service = self.service_registry.get_service("asr_service")
        if asr_service:
            stats["asr_stats"] = asr_service.get_asr_stats()
        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            stats["digital_human_stats"] = digital_human_service.get_digital_human_stats()
        return stats

    def setup_routes(self, app: web.Application, path: str = '/ws'):
        """Mount ``websocket_handler`` as a GET route at ``path``."""
        app.router.add_get(path, self.websocket_handler)
        logger.info(f'WebSocket路由已设置: {path}')

    async def broadcast_system_message(self, message: str, level: str = 'info'):
        """Broadcast a ``system_message`` with text, level and wall-clock timestamp."""
        await self.manager.broadcast_to_all('system_message', {
            'message': message,
            'level': level,
            # Wall-clock time, matching the manager's message envelopes; the
            # event loop clock is monotonic and meaningless to clients.
            'timestamp': time.time()
        }, source='system')

    async def send_to_session(self, session_id: str, message_type: str, content: Any):
        """Send an enveloped message to one session with source ``router``."""
        return await self.manager.broadcast_to_session(session_id, message_type, content, source='router')

    async def send_raw_to_session(self, session_id: str, message: Dict):
        """Send ``message`` unchanged to one session; the id is coerced to ``str``."""
        return await self.manager.broadcast_raw_message_to_session(str(session_id), message)

    async def send_to_digital_human(self, human_id: str, message_type: str, content: Any):
        """Send to a digital human through its service; False if the service is missing."""
        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            return await digital_human_service.send_to_digital_human(human_id, message_type, content)
        return False

    async def get_asr_stats(self) -> Optional[Dict[str, Any]]:
        """Return ASR service statistics, or None when the service is not registered."""
        asr_service = self.service_registry.get_service("asr_service")
        if asr_service:
            return asr_service.get_asr_stats()
        return None

    async def get_digital_human_stats(self) -> Optional[Dict[str, Any]]:
        """Return digital human statistics, or None when the service is not registered."""
        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            return digital_human_service.get_digital_human_stats()
        return None
# Lazily created process-wide router instance.
_websocket_router = None


def get_websocket_router() -> WebSocketRouter:
    """Return the process-wide WebSocketRouter, creating it on first use."""
    global _websocket_router
    router = _websocket_router
    if router is None:
        router = _websocket_router = WebSocketRouter()
    return router
async def initialize_websocket_router():
    """Fetch the global router, run its ``initialize()`` and return it."""
    instance = get_websocket_router()
    await instance.initialize()
    return instance
async def shutdown_websocket_router():
    """Shut down the global router, if one exists, and drop the reference to it."""
    global _websocket_router
    if not _websocket_router:
        return
    await _websocket_router.shutdown()
    _websocket_router = None
def setup_websocket_routes(app: web.Application, path: str = '/ws'):
    """Convenience wrapper: mount the global router on ``app`` at ``path`` and return it."""
    instance = get_websocket_router()
    instance.setup_routes(app, path)
    return instance
# Compatibility layer
class WebSocketCompatibilityAPI:
    """Simplified facade over the router, kept for callers of the older API.

    Messages are plain dicts: ``type`` selects the message type (default
    ``'message'``) and ``data`` is the payload (default: the whole dict).
    """

    def __init__(self):
        self.router = get_websocket_router()

    async def broadcast_message_to_session(self, session_id: str, message: Dict[str, Any]):
        """Send ``message`` to one session (app.py-compatible signature)."""
        return await self.router.send_to_session(
            session_id,
            message.get('type', 'message'),
            message.get('data', message),
        )

    async def broadcast_to_all_sessions(self, message: Dict[str, Any]):
        """Send ``message`` to every session with source ``compatibility``."""
        kind = message.get('type', 'message')
        payload = message.get('data', message)
        manager = self.router.manager
        return await manager.broadcast_to_all(kind, payload, source='compatibility')

    def get_active_sessions(self):
        """Return the list of session ids currently known to the manager."""
        return [session_id for session_id in self.router.manager._sessions.keys()]

    def get_session_count(self):
        """Return how many session ids the manager currently tracks."""
        tracked = self.router.manager._sessions
        return len(tracked)

    async def send_asr_result(self, session_id: str, result: Dict[str, Any]):
        """Send ``result`` to one session as an ``asr_result`` message (app.py-compatible)."""
        return await self.router.send_to_session(session_id, 'asr_result', result)
# Lazily created process-wide compatibility API instance.
_compatibility_api = None


def get_websocket_compatibility_api() -> WebSocketCompatibilityAPI:
    """Return the process-wide WebSocketCompatibilityAPI, creating it on first use."""
    global _compatibility_api
    api = _compatibility_api
    if api is None:
        api = _compatibility_api = WebSocketCompatibilityAPI()
    return api
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
WebSocket服务抽象基类
为不同类型的WebSocket服务提供统一接口和生命周期管理
"""
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from aiohttp import web
from logger import logger
from .unified_websocket_manager import get_unified_manager, WebSocketSession
class WebSocketServiceBase(ABC):
    """Base class for WebSocket services built on the unified manager.

    Subclasses implement ``_register_message_handlers`` and may override the
    other hooks. ``initialize``/``shutdown`` drive the lifecycle and are
    guarded by ``is_initialized``.
    """

    def __init__(self, service_name: str):
        self.service_name = service_name
        self.manager = get_unified_manager()
        self.is_initialized = False
        self._background_tasks: List[asyncio.Task] = []

    async def initialize(self):
        """Run the registration hooks and start background tasks, once."""
        if self.is_initialized:
            return
        logger.info(f'初始化WebSocket服务: {self.service_name}')
        # Order: message handlers, then event handlers, then background tasks.
        await self._register_message_handlers()
        await self._register_event_handlers()
        await self._start_background_tasks()
        self.is_initialized = True
        logger.info(f'WebSocket服务初始化完成: {self.service_name}')

    async def shutdown(self):
        """Cancel unfinished background tasks, run ``_cleanup`` and mark the service stopped."""
        if not self.is_initialized:
            return
        logger.info(f'关闭WebSocket服务: {self.service_name}')
        for pending in self._background_tasks:
            if pending.done():
                continue
            pending.cancel()
            try:
                await pending
            except asyncio.CancelledError:
                # Expected result of cancelling the task just above.
                pass
        self._background_tasks.clear()
        await self._cleanup()
        self.is_initialized = False
        logger.info(f'WebSocket服务已关闭: {self.service_name}')

    @abstractmethod
    async def _register_message_handlers(self):
        """Register this service's message handlers (implemented by subclasses)."""

    async def _register_event_handlers(self):
        """Subscribe to the session connect/disconnect events (may be overridden)."""
        subscribe = self.manager.register_event_handler
        subscribe('session_connected', self._on_session_connected)
        subscribe('session_disconnected', self._on_session_disconnected)

    async def _start_background_tasks(self):
        """Start background tasks (may be overridden; does nothing by default)."""

    async def _cleanup(self):
        """Release service resources (may be overridden; does nothing by default)."""

    async def _on_session_connected(self, session: WebSocketSession, data: Dict[str, Any]):
        """Session-connected hook; the default only logs the session id."""
        logger.info(f'[{self.service_name}] 会话连接: {session.session_id}')

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Session-disconnected hook; the default only logs the session id."""
        logger.info(f'[{self.service_name}] 会话断开: {session.session_id}')

    def add_background_task(self, coro):
        """Schedule ``coro`` as a task, remember it for shutdown and return it."""
        scheduled = asyncio.create_task(coro)
        self._background_tasks.append(scheduled)
        return scheduled

    async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
                                   source: str = None, metadata: Dict = None):
        """Broadcast to one session; ``source`` defaults to this service's name."""
        origin = self.service_name if source is None else source
        return await self.manager.broadcast_to_session(session_id, message_type, content, origin, metadata)

    async def broadcast_to_all(self, message_type: str, content: Any,
                               source: str = None, metadata: Dict = None):
        """Broadcast to all sessions; ``source`` defaults to this service's name."""
        origin = self.service_name if source is None else source
        return await self.manager.broadcast_to_all(message_type, content, origin, metadata)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return the manager's session statistics."""
        return self.manager.get_session_stats()

    def create_message_handler(self, message_type: str):
        """Decorator factory that registers a session-based handler for ``message_type``.

        The decorated function is called as ``func(session, data)``. Its
        exceptions are logged and reported to the client as an ``error``
        message. The decorator returns the original function unchanged.
        """
        def decorator(func):
            async def wrapper(websocket: web.WebSocketResponse, data: Dict[str, Any]):
                session = self.manager.get_session(websocket)
                if not session:
                    logger.warning(f'[{self.service_name}] 未找到会话,无法处理消息: {message_type}')
                    return
                try:
                    await func(session, data)
                except Exception as e:
                    logger.error(f'[{self.service_name}] 消息处理器 {message_type} 执行失败: {e}')
                    await session.send_message({
                        "type": "error",
                        "data": {
                            "message": f"处理 {message_type} 消息失败: {str(e)}",
                            "service": self.service_name
                        }
                    })
            self.manager.register_message_handler(message_type, wrapper)
            return func
        return decorator

    def create_event_handler(self, event_type: str):
        """Decorator factory that registers the function as a handler for ``event_type``."""
        def decorator(func):
            self.manager.register_event_handler(event_type, func)
            return func
        return decorator
class WebSocketServiceRegistry:
    """Name-indexed registry of WebSocket services with bulk lifecycle helpers."""

    def __init__(self):
        """Create an empty registry."""
        self._services: Dict[str, WebSocketServiceBase] = {}

    def register_service(self, service: WebSocketServiceBase):
        """Register ``service`` under its ``service_name``.

        An existing entry with the same name is replaced after a warning.
        """
        if service.service_name in self._services:
            logger.warning(f'服务已存在,将被覆盖: {service.service_name}')
        self._services[service.service_name] = service
        logger.info(f'注册WebSocket服务: {service.service_name}')

    def get_service(self, service_name: str) -> Optional[WebSocketServiceBase]:
        """Return the service registered as ``service_name``, or None."""
        return self._services.get(service_name)

    def list_services(self) -> List[str]:
        """Return the names of all registered services."""
        return list(self._services.keys())

    async def initialize_all(self):
        """Initialize every registered service in registration order."""
        logger.info('初始化所有WebSocket服务...')
        # Iterate over a snapshot: a service may register another one while
        # it is being initialized, which would otherwise mutate the dict
        # during iteration.
        for service in list(self._services.values()):
            await service.initialize()
        logger.info('所有WebSocket服务初始化完成')

    async def shutdown_all(self):
        """Shut down every registered service.

        A failure in one service is logged and does not prevent the remaining
        services from being shut down.
        """
        logger.info('关闭所有WebSocket服务...')
        for service in list(self._services.values()):
            try:
                await service.shutdown()
            except Exception as e:
                logger.error(f'关闭WebSocket服务失败: {service.service_name}: {e}')
        logger.info('所有WebSocket服务已关闭')

    def get_all_stats(self) -> Dict[str, Any]:
        """Return per-service status plus the total number of services.

        Each service entry reports whether it is initialized and how many
        background tasks it currently tracks.
        """
        stats = {
            "services": {},
            "total_services": len(self._services)
        }
        for name, service in self._services.items():
            stats["services"][name] = {
                "initialized": service.is_initialized,
                "background_tasks": len(service._background_tasks)
            }
        return stats
# 全局服务注册表
_service_registry = WebSocketServiceRegistry()
def get_service_registry() -> WebSocketServiceRegistry:
    """Return the module-level service registry."""
    return _service_registry
def register_websocket_service(service: WebSocketServiceBase):
    """Register ``service`` in the module-level service registry."""
    _service_registry.register_service(service)
def get_websocket_service(service_name: str) -> Optional[WebSocketServiceBase]:
    """Look up a service by name in the module-level registry (None if absent)."""
    return _service_registry.get_service(service_name)
\ No newline at end of file
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
WebSocket服务器管理模块
提供Web和Human连接的管理功能
"""
import queue
from typing import Dict, Any, Optional
from threading import Lock
class WebSocketManager:
    """Thread-safe map of username -> connection plus a command queue."""

    def __init__(self, max_queue_size: int = 0):
        """Create an empty manager.

        Args:
            max_queue_size: Upper bound on queued commands. ``0`` (the default)
                keeps the queue unbounded, in which case ``add_cmd`` never
                drops a command.
        """
        self._connections: Dict[str, Any] = {}
        self._command_queue = queue.Queue(maxsize=max_queue_size)
        self._lock = Lock()

    def is_connected(self, username: str) -> bool:
        """Return True if a connection is registered for ``username``."""
        with self._lock:
            return username in self._connections

    def is_connected_human(self, username: str) -> bool:
        """Return True if a connection is registered for ``username``.

        Simplified implementation: identical to ``is_connected``.
        """
        return self.is_connected(username)

    def add_connection(self, username: str, connection: Any):
        """Register ``connection`` for ``username``, replacing any previous one."""
        with self._lock:
            self._connections[username] = connection

    def remove_connection(self, username: str):
        """Remove the connection for ``username``; unknown names are ignored."""
        with self._lock:
            self._connections.pop(username, None)

    def add_cmd(self, command: Dict[str, Any], timeout: float = 1.0):
        """Append ``command`` to the command queue.

        Args:
            command: Command dictionary to enqueue.
            timeout: Seconds to wait for free space when the queue is bounded
                and full. When the wait expires the command is dropped and a
                warning is printed.
        """
        try:
            self._command_queue.put(command, timeout=timeout)
        except queue.Full:
            print(f"警告: 命令队列已满,丢弃命令: {command}")

    def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Pop the next command, waiting up to ``timeout`` seconds.

        Returns:
            The command dictionary, or None if nothing arrived in time.
        """
        try:
            return self._command_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        with self._lock:
            return len(self._connections)

    def get_usernames(self) -> list:
        """Return a list of all usernames with a registered connection."""
        with self._lock:
            return list(self._connections.keys())
# 全局实例
_web_instance = WebSocketManager()
_human_instance = WebSocketManager()
def get_web_instance() -> WebSocketManager:
    """Return the module-level manager used for Web connections."""
    return _web_instance
def get_instance() -> WebSocketManager:
    """Return the module-level manager used for Human connections."""
    return _human_instance
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27 16:30:00
WSA WebSocket服务
将原有wsa_server功能集成到统一WebSocket架构中
"""
import asyncio
import json
import queue
from typing import Dict, Any, Optional, Set
from threading import Lock
from aiohttp import web
from .websocket_service_base import WebSocketServiceBase
from .unified_websocket_manager import WebSocketSession
class WSAWebSocketService(WebSocketServiceBase):
    """WSA WebSocket service.

    Provides the features of the former wsa_server on top of the unified
    WebSocket architecture:
    - Web connection registry
    - Human connection registry
    - command queues drained by a background task
    - message forwarding to registered sessions
    """

    def __init__(self, manager):
        """Set up connection registries and command queues.

        Args:
            manager: Accepted for call-site compatibility. It is not stored
                here; ``self.manager`` is assigned by the base class.
        """
        super().__init__("wsa")
        # Every error path below logs through self.logger. Nothing else in
        # this class assigns it, so create one unless a base class already did.
        if not hasattr(self, "logger"):
            import logging  # local import: this module has no top-level logging import
            self.logger = logging.getLogger(__name__)
        # Connection registries: username -> set of sessions.
        self._web_connections: Dict[str, Set[WebSocketSession]] = {}
        self._human_connections: Dict[str, Set[WebSocketSession]] = {}
        self._connection_lock = Lock()
        # Command queues.
        self._web_command_queue = queue.Queue()
        self._human_command_queue = queue.Queue()
        # Background queue processor.
        self._queue_processor_task: Optional[asyncio.Task] = None

    # ------------------------------------------------------------------
    # Lifecycle hooks
    # ------------------------------------------------------------------
    async def _register_message_handlers(self):
        """Register the wsa_* message handlers with the manager."""
        self.manager.register_message_handler("wsa_register_web", self._handle_register_web)
        self.manager.register_message_handler("wsa_register_human", self._handle_register_human)
        self.manager.register_message_handler("wsa_unregister", self._handle_unregister)
        self.manager.register_message_handler("wsa_get_status", self._handle_get_status)

    async def _start_background_tasks(self):
        """Start the task that drains both command queues."""
        self._queue_processor_task = asyncio.create_task(self._process_command_queues())

    async def _cleanup(self):
        """Cancel the queue processor task and wait for it to finish."""
        if self._queue_processor_task:
            self._queue_processor_task.cancel()
            try:
                await self._queue_processor_task
            except asyncio.CancelledError:
                pass

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Remove a disconnected session from both registries."""
        with self._connection_lock:
            for connections in (self._web_connections, self._human_connections):
                # Iterate over a key snapshot because entries may be deleted.
                for username in list(connections):
                    self._discard_session(connections, username, session)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _discard_session(connections, username, session):
        """Drop ``session`` from ``connections[username]``; delete empty entries.

        The caller must hold ``_connection_lock``.
        """
        sessions = connections.get(username)
        if sessions is None:
            return
        sessions.discard(session)
        if not sessions:
            del connections[username]

    def _snapshot_sessions(self, connections, username):
        """Return a list copy of the sessions for ``username``, taken under the lock."""
        with self._connection_lock:
            return list(connections.get(username, ()))

    async def _register(self, websocket, data, connections, connection_type):
        """Shared implementation of the web/human registration handlers."""
        username = data.get('username')
        if not username:
            await websocket.send_str(json.dumps({
                "type": "wsa_error",
                "message": "用户名不能为空"
            }))
            return
        session = self.manager.get_session(websocket)
        if not session:
            await websocket.send_str(json.dumps({
                "type": "wsa_error",
                "message": "会话未找到"
            }))
            return
        with self._connection_lock:
            connections.setdefault(username, set()).add(session)
        await websocket.send_str(json.dumps({
            "type": "wsa_registered",
            "connection_type": connection_type,
            "username": username
        }))

    async def _drain_queue(self, command_queue, forward, label):
        """Forward every command currently queued in ``command_queue``.

        An error raised while forwarding is logged and ends this drain pass.
        """
        try:
            while True:
                try:
                    command = command_queue.get_nowait()
                except queue.Empty:
                    break
                await forward(command)
        except Exception as e:
            self.logger.error(f"{label}命令处理错误: {e}")

    async def _forward_command(self, command, connections, source, label):
        """Wrap ``command`` as a wsa_command and send it to the user's sessions."""
        username = command.get('Username')
        if not username:
            return
        sessions = self._snapshot_sessions(connections, username)
        if not sessions:
            return
        message = {
            "type": "wsa_command",
            "source": source,
            "data": command
        }
        for session in sessions:
            try:
                await session.send_message(message)
            except Exception as e:
                self.logger.error(f"发送{label}命令失败 [{username}]: {e}")

    # ------------------------------------------------------------------
    # Message handlers
    # ------------------------------------------------------------------
    async def _handle_register_web(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Register the sending session as a Web connection for data['username']."""
        await self._register(websocket, data, self._web_connections, "web")

    async def _handle_register_human(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Register the sending session as a Human connection for data['username']."""
        await self._register(websocket, data, self._human_connections, "human")

    async def _handle_unregister(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Unregister the sending session.

        ``data['connection_type']`` selects 'web', 'human' or 'both' (default).
        Nothing is sent back when the websocket has no session.
        """
        username = data.get('username')
        connection_type = data.get('connection_type', 'both')
        session = self.manager.get_session(websocket)
        if not session:
            return
        with self._connection_lock:
            if connection_type in ['web', 'both']:
                self._discard_session(self._web_connections, username, session)
            if connection_type in ['human', 'both']:
                self._discard_session(self._human_connections, username, session)
        await websocket.send_str(json.dumps({
            "type": "wsa_unregistered",
            "username": username,
            "connection_type": connection_type
        }))

    async def _handle_get_status(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Send a wsa_status message with connection and queue statistics."""
        # Names and counts come from the same locked snapshot so they agree.
        with self._connection_lock:
            web_users = list(self._web_connections.keys())
            human_users = list(self._human_connections.keys())
        await websocket.send_str(json.dumps({
            "type": "wsa_status",
            "data": {
                "web_connections": len(web_users),
                "human_connections": len(human_users),
                "web_users": web_users,
                "human_users": human_users,
                "web_queue_size": self._web_command_queue.qsize(),
                "human_queue_size": self._human_command_queue.qsize()
            }
        }))

    # ------------------------------------------------------------------
    # Queue processing
    # ------------------------------------------------------------------
    async def _process_command_queues(self):
        """Drain both command queues in a loop until the task is cancelled."""
        while True:
            try:
                await self._process_web_commands()
                await self._process_human_commands()
                # Short sleep so the loop does not spin.
                await asyncio.sleep(0.01)
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"命令队列处理错误: {e}")
                await asyncio.sleep(0.1)

    async def _process_web_commands(self):
        """Forward all queued Web commands."""
        await self._drain_queue(self._web_command_queue, self._forward_web_command, "Web")

    async def _process_human_commands(self):
        """Forward all queued Human commands."""
        await self._drain_queue(self._human_command_queue, self._forward_human_command, "Human")

    async def _forward_web_command(self, command: Dict[str, Any]):
        """Send ``command`` to the Web sessions of command['Username']."""
        await self._forward_command(command, self._web_connections, "web", "Web")

    async def _forward_human_command(self, command: Dict[str, Any]):
        """Send ``command`` to the Human sessions of command['Username']."""
        await self._forward_command(command, self._human_connections, "human", "Human")

    # ------------------------------------------------------------------
    # Compatibility interface
    # ------------------------------------------------------------------
    def is_connected(self, username: str) -> bool:
        """Return True if ``username`` has at least one Web session."""
        with self._connection_lock:
            return username in self._web_connections and bool(self._web_connections[username])

    def is_connected_human(self, username: str) -> bool:
        """Return True if ``username`` has at least one Human session."""
        with self._connection_lock:
            return username in self._human_connections and bool(self._human_connections[username])

    def add_connection(self, username: str, connection: Any):
        """Deprecated compatibility stub; only logs a warning."""
        self.logger.warning("add_connection方法已废弃,请使用消息注册机制")

    def remove_connection(self, username: str):
        """Deprecated compatibility stub; only logs a warning."""
        self.logger.warning("remove_connection方法已废弃,连接会自动清理")

    def add_cmd(self, command: Dict[str, Any], target: str = "web"):
        """Queue ``command`` for the 'web' or 'human' side.

        An unknown ``target`` is logged and the command is not queued.
        """
        try:
            if target == "web":
                self._web_command_queue.put(command, timeout=1.0)
            elif target == "human":
                self._human_command_queue.put(command, timeout=1.0)
            else:
                self.logger.warning(f"未知的目标类型: {target}")
        except queue.Full:
            self.logger.warning(f"命令队列已满,丢弃命令: {command}")

    async def send_direct_message(self, message: Dict[str, Any], target: str = "web"):
        """Send ``message`` as-is (not wrapped as wsa_command) to message['Username'].

        ``target`` selects the 'web' or 'human' sessions of that user.
        """
        username = message.get('Username')
        if not username:
            self.logger.warning("消息缺少Username字段")
            return
        if target == "web":
            connections = self._web_connections
        elif target == "human":
            connections = self._human_connections
        else:
            self.logger.warning(f"未知的目标类型: {target}")
            return
        sessions = self._snapshot_sessions(connections, username)
        if sessions:
            for session in sessions:
                try:
                    await session.send_message(message)
                except Exception as e:
                    self.logger.error(f"直接发送消息失败 [{username}]: {e}")
        else:
            self.logger.debug(f"用户 {username} 未连接,无法发送直接消息")

    def get_cmd(self, timeout: float = 1.0, target: str = "web") -> Optional[Dict[str, Any]]:
        """Pop the next queued command for ``target``, waiting up to ``timeout`` seconds.

        Returns None on timeout or for an unknown ``target``. This is a
        blocking call on a thread-safe queue.
        """
        try:
            if target == "web":
                return self._web_command_queue.get(timeout=timeout)
            elif target == "human":
                return self._human_command_queue.get(timeout=timeout)
            else:
                self.logger.warning(f"未知的目标类型: {target}")
                return None
        except queue.Empty:
            return None

    def get_connection_count(self, target: str = "web") -> int:
        """Return the number of registered users for ``target``.

        Any value other than 'web' or 'human' returns the sum of both sides.
        """
        with self._connection_lock:
            if target == "web":
                return len(self._web_connections)
            elif target == "human":
                return len(self._human_connections)
            else:
                return len(self._web_connections) + len(self._human_connections)

    def get_usernames(self, target: str = "web") -> list:
        """Return registered usernames for ``target``.

        Any value other than 'web' or 'human' returns the de-duplicated union.
        """
        with self._connection_lock:
            if target == "web":
                return list(self._web_connections.keys())
            elif target == "human":
                return list(self._human_connections.keys())
            else:
                return list(set(self._web_connections) | set(self._human_connections))
# 兼容性包装器
class WSAWebSocketManager:
    """Compatibility wrapper exposing the legacy manager API on top of the service.

    Each wrapper is bound to one side of the service ('web' or 'human') so the
    Web and Human instances address their own queue and connection registry.
    """

    def __init__(self, service: "WSAWebSocketService", target: str = "web"):
        """Wrap ``service``.

        Args:
            service: The WSA service that holds connections and queues.
            target: Side this wrapper addresses, 'web' (default) or 'human'.
        """
        self.service = service
        self.target = target

    def is_connected(self, username: str) -> bool:
        """Return True if ``username`` is connected on this wrapper's side."""
        if self.target == "human":
            return self.service.is_connected_human(username)
        return self.service.is_connected(username)

    def is_connected_human(self, username: str) -> bool:
        """Return True if ``username`` has a Human connection."""
        return self.service.is_connected_human(username)

    def add_connection(self, username: str, connection: Any):
        """Delegate to the service's ``add_connection``."""
        self.service.add_connection(username, connection)

    def remove_connection(self, username: str):
        """Delegate to the service's ``remove_connection``."""
        self.service.remove_connection(username)

    def add_cmd(self, command: Dict[str, Any]):
        """Queue ``command`` on this wrapper's side."""
        self.service.add_cmd(command, self.target)

    async def send_direct_message(self, message: Dict[str, Any]):
        """Send ``message`` directly (not wrapped as wsa_command) on this side."""
        await self.service.send_direct_message(message, self.target)

    def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """Pop the next command queued on this wrapper's side."""
        return self.service.get_cmd(timeout, self.target)

    def get_connection_count(self) -> int:
        """Return the number of connected users on this wrapper's side."""
        return self.service.get_connection_count(self.target)

    def get_usernames(self) -> list:
        """Return the usernames connected on this wrapper's side."""
        return self.service.get_usernames(self.target)


# Global instances (compatibility)
_wsa_service: Optional["WSAWebSocketService"] = None
_web_instance: Optional[WSAWebSocketManager] = None
_human_instance: Optional[WSAWebSocketManager] = None


def initialize_wsa_service(service: "WSAWebSocketService"):
    """Install ``service`` and build the Web and Human compatibility wrappers."""
    global _wsa_service, _web_instance, _human_instance
    _wsa_service = service
    _web_instance = WSAWebSocketManager(service, "web")
    # The Human wrapper must address the human side; without an explicit
    # target it would behave exactly like the Web wrapper.
    _human_instance = WSAWebSocketManager(service, "human")
def get_web_instance() -> WSAWebSocketManager:
    """Return the Web compatibility manager.

    Raises:
        RuntimeError: If the module-level Web instance has not been set yet.
    """
    if _web_instance is not None:
        return _web_instance
    raise RuntimeError("WSA服务未初始化")
def get_instance() -> WSAWebSocketManager:
    """Return the Human compatibility manager.

    Raises:
        RuntimeError: If the module-level Human instance has not been set yet.
    """
    if _human_instance is not None:
        return _human_instance
    raise RuntimeError("WSA服务未初始化")
\ No newline at end of file
... ...
... ... @@ -29,7 +29,7 @@ def _play_frame(stream, exit_event, queue, chunk):
print(f'[INFO] play frame thread ends')
break
frame = queue.get()
frame = (frame * 32767).astype(np.int16).tobytes()
frame = bytes((frame * 32767).astype(np.int16).tobytes()) # Fix BufferError: memoryview has 1 exported buffer
stream.write(frame, chunk)
class ASR:
... ...
... ... @@ -71,18 +71,45 @@ class FunASRClient(BaseASR):
async def _connect_websocket(self):
"""连接WebSocket服务器"""
try:
self.websocket = await websockets.connect(
self.server_url,
timeout=getattr(cfg, 'asr_timeout', 30)
# 修复: websockets新版本不支持timeout参数,使用asyncio.wait_for包装
timeout_seconds = getattr(cfg, 'asr_timeout', 30)
self.websocket = await asyncio.wait_for(
websockets.connect(self.server_url),
timeout=timeout_seconds
)
self.connected = True
util.log(1, f"FunASR WebSocket连接成功: {self.server_url}")
# 发送初始化配置消息(参考funasr_client_api.py)
await self._send_init_message()
return True
except Exception as e:
util.log(3, f"FunASR WebSocket连接失败: {e}")
self.connected = False
return False
async def _send_init_message(self):
    """Send the FunASR initial configuration message over the websocket.

    Any failure is logged and then re-raised to the caller.
    """
    try:
        # Field layout follows the reference funasr_client_api.py.
        # NOTE(review): the original comment labels chunk_size as
        # [vad_need, chunk_size, chunk_interval] - confirm against FunASR docs.
        handshake = {
            "mode": "2pass",
            "chunk_size": [0, 10, 5],
            "encoder_chunk_look_back": 4,
            "decoder_chunk_look_back": 1,
            "chunk_interval": 10,
            "wav_name": self.username,
            "is_speaking": True
        }
        payload = json.dumps(handshake)
        await self.websocket.send(payload)
        util.log(1, f"发送FunASR初始化消息: {handshake}")
    except Exception as exc:
        util.log(3, f"发送初始化消息失败: {exc}")
        raise exc
async def _disconnect_websocket(self):
"""断开WebSocket连接"""
if self.websocket:
... ... @@ -141,11 +168,12 @@ class FunASRClient(BaseASR):
message = self.message_queue.get_nowait()
if isinstance(message, dict):
# JSON消息
# JSON消息(配置消息或结束信号)
await self.websocket.send(json.dumps(message))
util.log(1, f"发送JSON消息: {message}")
elif isinstance(message, bytes):
# 二进制音频数据
# 二进制音频数据(参考funasr_client_api.py的feed_chunk方法)
# 确保音频数据以二进制格式发送
await self.websocket.send(message)
util.log(1, f"发送音频数据: {len(message)} bytes")
else:
... ... @@ -203,23 +231,24 @@ class FunASRClient(BaseASR):
text: 识别文本
"""
try:
from core import wsa_server
from core import get_web_instance, get_instance
# 发送到Web客户端
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({
"panelMsg": text,
"Username": self.username
})
# 发送到Human客户端
if wsa_server.get_instance().is_connected_human(self.username):
content = {
'Topic': 'human',
'Data': {'Key': 'log', 'Value': text},
'Username': self.username
if get_web_instance().is_connected(self.username):
import asyncio
# 创建chat_message直接推送
chat_message = {
"type": "chat_message",
"sender": "回音",
"content": text, # 修复字段名:panelMsg -> content
"Username": self.username,
"model_info": "FunASR"
}
wsa_server.get_instance().add_cmd(content)
# 使用直接发送方法,避免wsa_command封装
asyncio.create_task(get_web_instance().send_direct_message(chat_message))
# Human客户端通知改为日志记录(避免重复通知当前服务)
util.log(1, f"FunASR识别结果[{self.username}]: {text}")
except Exception as e:
util.log(2, f"发送到Web客户端失败: {e}")
... ... @@ -333,6 +362,19 @@ class FunASRClient(BaseASR):
self.message_queue.put(audio_data)
return True
def send_end_signal(self):
    """Queue the end-of-speech message; does nothing when not connected."""
    if self.connected:
        try:
            # End message format follows the close method of funasr_client_api.py.
            self.message_queue.put({"is_speaking": False})
            util.log(1, "发送FunASR结束信号")
        except Exception as exc:
            util.log(3, f"发送结束信号失败: {exc}")
def start_recognition(self):
"""开始语音识别"""
if not self.connected:
... ... @@ -418,6 +460,99 @@ class FunASRClient(BaseASR):
# 简化实现,返回空特征
return np.zeros((1, 50), dtype=np.float32)
async def connect(self):
    """Connect to the FunASR server and start the receive/send tasks.

    Returns:
        True when already connected; otherwise the result of
        ``_connect_websocket()``, or False if an exception was raised.
    """
    if self.connected:
        util.log(1, "FunASR客户端已连接")
        return True
    try:
        established = await self._connect_websocket()
        if not established:
            return established
        # Start the message pump tasks only once the socket is up.
        self.receive_task = asyncio.create_task(self._receive_messages())
        self.send_task = asyncio.create_task(self._send_message_loop())
        util.log(1, "FunASR异步连接建立成功")
        return established
    except Exception as exc:
        util.log(3, f"FunASR异步连接失败: {exc}")
        return False
async def disconnect(self):
    """Cancel the receive/send tasks, if present, and close the websocket.

    Errors are logged and not propagated.
    """
    try:
        for task_attr in ('receive_task', 'send_task'):
            if hasattr(self, task_attr):
                getattr(self, task_attr).cancel()
        await self._disconnect_websocket()
        util.log(1, "FunASR异步连接已断开")
    except Exception as exc:
        util.log(2, f"断开FunASR连接时出错: {exc}")
async def send_audio_data(self, audio_data):
    """Normalise ``audio_data`` to bytes and queue it in chunks.

    Args:
        audio_data: A Base64 string, ``bytes``, or a NumPy array. Other types
            are passed through ``bytes(...)`` as a last resort.

    Returns:
        True if at least one chunk was queued, False if the data was empty,
        could not be converted, or an error occurred.
    """
    try:
        if isinstance(audio_data, str):
            # Base64-encoded audio must be decoded first.
            import base64
            audio_bytes = base64.b64decode(audio_data)
            util.log(1, f"解码Base64音频数据: {len(audio_bytes)} bytes")
        elif isinstance(audio_data, bytes):
            audio_bytes = audio_data
            util.log(1, f"接收字节音频数据: {len(audio_bytes)} bytes")
        elif isinstance(audio_data, np.ndarray):
            if audio_data.dtype != np.int16:
                # Float samples within [-1, 1] would all truncate to 0 under a
                # plain int16 cast, so scale them to the int16 range first.
                # NOTE(review): assumes such float arrays are normalised audio;
                # floats already in int16 range are cast unchanged.
                if (np.issubdtype(audio_data.dtype, np.floating)
                        and audio_data.size
                        and np.max(np.abs(audio_data)) <= 1.0):
                    audio_data = audio_data * 32767
                audio_data = audio_data.astype(np.int16)
            audio_bytes = audio_data.tobytes()
            util.log(1, f"转换NumPy数组为字节: {len(audio_bytes)} bytes")
        else:
            util.log(3, f"不支持的音频数据类型: {type(audio_data)},尝试转换为字节")
            try:
                audio_bytes = bytes(audio_data)
            except Exception as convert_error:
                util.log(3, f"音频数据类型转换失败: {convert_error}")
                return False

        # 16-bit samples need an even byte count; drop a trailing odd byte.
        # This runs before the emptiness check so that a 1-byte payload is
        # rejected instead of being queued as an empty chunk.
        if len(audio_bytes) % 2 != 0:
            audio_bytes = audio_bytes[:-1]
            util.log(2, f"调整音频数据长度为偶数: {len(audio_bytes)} bytes")

        if len(audio_bytes) == 0:
            util.log(2, "音频数据为空,跳过发送")
            return False

        # Chunk size follows the stride formula of funasr_client_api.py.
        chunk_interval = 10  # ms
        chunk_size = 10  # ms
        stride = int(60 * chunk_size / chunk_interval / 1000 * 16000 * 2)

        if len(audio_bytes) > stride:
            # Large payloads are queued in stride-sized pieces.
            chunk_num = (len(audio_bytes) - 1) // stride + 1
            for i in range(chunk_num):
                beg = i * stride
                chunk_data = audio_bytes[beg:beg + stride]
                self.message_queue.put(chunk_data)
                util.log(1, f"发送音频块 {i+1}/{chunk_num}: {len(chunk_data)} bytes")
        else:
            # Small payloads are queued as a single message.
            self.message_queue.put(audio_bytes)
            util.log(1, f"发送音频数据: {len(audio_bytes)} bytes")
        return True
    except Exception as e:
        util.log(3, f"发送音频数据失败: {e}")
        return False
def __del__(self):
    """Destructor: calls ``self.stop()`` when the object is finalized."""
    self.stop()
... ...
... ... @@ -293,7 +293,10 @@ class LightReal(BaseReal):
frame,type_,eventpoint = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
# 修复 BufferError: memoryview has 1 exported buffer
# 创建数据副本避免内存视图冲突
frame_bytes = bytes(frame.tobytes())
new_frame.planes[0].update(frame_bytes)
new_frame.sample_rate=16000
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
... ...
... ... @@ -263,7 +263,41 @@ class LipReal(BaseReal):
#print('blending time:',time.perf_counter()-t)
image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
# Fix MemoryError: 优化内存使用和错误处理
try:
# 检查图像尺寸,如果过大则压缩
h, w = image.shape[:2]
max_dimension = 1920 # 最大尺寸限制
if h > max_dimension or w > max_dimension:
scale = max_dimension / max(h, w)
new_h, new_w = int(h * scale), int(w * scale)
image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
logger.warning(f"Image resized from {w}x{h} to {new_w}x{new_h} to prevent MemoryError")
# 确保数据类型正确并创建连续内存布局
if not image.flags['C_CONTIGUOUS']:
image = np.ascontiguousarray(image)
image = image.astype(np.uint8)
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
except MemoryError as e:
logger.error(f"MemoryError in VideoFrame creation: {e}, image shape: {image.shape}")
# 进一步压缩图像作为备用方案
try:
h, w = image.shape[:2]
backup_scale = 0.5
backup_h, backup_w = int(h * backup_scale), int(w * backup_scale)
image = cv2.resize(image, (backup_w, backup_h), interpolation=cv2.INTER_AREA)
image = np.ascontiguousarray(image.astype(np.uint8))
new_frame = VideoFrame.from_ndarray(image, format="bgr24")
logger.info(f"Backup resize successful: {backup_w}x{backup_h}")
except Exception as backup_e:
logger.error(f"Backup resize failed: {backup_e}")
continue
except Exception as e:
logger.error(f"Unexpected error in VideoFrame creation: {e}")
continue
asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
self.record_video_data(image)
... ... @@ -271,11 +305,19 @@ class LipReal(BaseReal):
frame,type,eventpoint = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
new_frame.sample_rate=16000
# if audio_track._queue.qsize()>10:
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
# 修复 BufferError: 强制复制数据避免内存视图冲突
frame_copy = frame.astype(np.int16).copy()
frame_bytes = bytes(frame_copy.tobytes()) # Fix BufferError: memoryview has 1 exported buffer
new_frame.planes[0].update(frame_bytes)
new_frame.sample_rate = 16000
# 使用线程安全的方式提交到队列,避免闭包问题
def put_audio_frame(frame_obj, event_point):
audio_track._queue.put_nowait((frame_obj, event_point))
loop.call_soon_threadsafe(put_audio_frame, new_frame, eventpoint)
self.record_audio_data(frame)
#self.notify(eventpoint)
logger.info('lipreal process_frames thread stop')
... ...
import logging
import os
# Make sure the log directory exists. exist_ok avoids the check-then-create
# race when another process or thread creates the directory at the same time.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)
# 配置日志器
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
... ... @@ -13,4 +19,47 @@ logger.addHandler(fhandler)
# handler.setLevel(logging.DEBUG)
# sformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# handler.setFormatter(sformatter)
# logger.addHandler(handler)
\ No newline at end of file
# logger.addHandler(handler)
def get_logger(name: str, level: str = "INFO") -> logging.Logger:
    """Return a logger called ``name`` with file and console handlers.

    Args:
        name: Logger name; its lower-cased form names the log file in ``log_dir``.
        level: Level name (DEBUG, INFO, WARNING, ERROR, CRITICAL). Unknown
            names fall back to INFO.

    Returns:
        The configured logger. A logger that already has handlers is returned
        unchanged, so repeated calls do not add duplicate handlers.
    """
    named_logger = logging.getLogger(name)
    if named_logger.handlers:
        return named_logger

    resolved_level = getattr(logging, level.upper(), logging.INFO)
    named_logger.setLevel(resolved_level)

    # File handler: full format, same level as the logger.
    to_file = logging.FileHandler(
        os.path.join(log_dir, f"{name.lower()}.log"), encoding='utf-8'
    )
    to_file.setFormatter(logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    ))
    to_file.setLevel(resolved_level)

    # Console handler: shorter format, warnings and above only.
    to_console = logging.StreamHandler()
    to_console.setFormatter(logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    ))
    to_console.setLevel(logging.WARNING)

    for handler in (to_file, to_console):
        named_logger.addHandler(handler)
    return named_logger
\ No newline at end of file
... ...
... ... @@ -346,7 +346,10 @@ class MuseReal(BaseReal):
frame,type,eventpoint = audio_frame
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
# 修复 BufferError: memoryview has 1 exported buffer
# 创建数据副本避免内存视图冲突
frame_bytes = bytes(frame.tobytes())
new_frame.planes[0].update(frame_bytes)
new_frame.sample_rate=16000
asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
self.record_audio_data(frame)
... ...
... ... @@ -247,7 +247,10 @@ class NeRFReal(BaseReal):
else: #webrtc
frame = (frame * 32767).astype(np.int16)
new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
new_frame.planes[0].update(frame.tobytes())
# 修复 BufferError: memoryview has 1 exported buffer
# 创建数据副本避免内存视图冲突
frame_bytes = bytes(frame.tobytes())
new_frame.planes[0].update(frame_bytes)
new_frame.sample_rate=16000
asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
... ...
... ... @@ -18,7 +18,7 @@ face_alignment
python_speech_features
numba
resampy
#pyaudio
pyaudio
soundfile==0.12.1
einops
configargparse
... ... @@ -30,6 +30,9 @@ transformers
edge_tts
flask
flask_sockets
flask-socketio
websockets
websocket-client
opencv-python-headless
aiortc
aiohttp_cors
... ... @@ -41,5 +44,6 @@ accelerate
librosa
openai
aiofiles
#判断音频类型的支持
AudioSegment
... ...
#!/usr/bin/env python3
# AIfeng/2024-12-19
# 豆包模型集成测试脚本
import os
import sys
import json
from pathlib import Path
# 添加项目根目录到Python路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
def test_config_files():
    """Check that the LLM and Doubao config files exist and parse as JSON."""
    print("=== 配置文件测试 ===")

    def _check(path, label, report):
        # Load one JSON config file and print a status line for it.
        if not path.exists():
            print(f"✗ {label}配置文件不存在: {path}")
            return
        try:
            with open(path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            print(f"✓ {label}配置文件加载成功: {path}")
            report(loaded)
        except Exception as e:
            print(f"✗ {label}配置文件格式错误: {e}")

    def _report_llm(cfg):
        print(f"  当前模型类型: {cfg.get('model_type', 'unknown')}")

    def _report_doubao(cfg):
        print(f"  模型名称: {cfg.get('model', 'unknown')}")
        print(f"  人物设定: {cfg.get('character', {}).get('name', 'unknown')}")

    config_dir = project_root / "config"
    _check(config_dir / "llm_config.json", "LLM", _report_llm)
    _check(config_dir / "doubao_config.json", "豆包", _report_doubao)
def test_module_import():
    """Try importing the Doubao module and the llm package.

    Returns:
        False if ``llm.Doubao`` cannot be imported, True otherwise (a failing
        ``import llm`` is only reported).
    """
    print("\n=== 模块导入测试 ===")
    try:
        from llm.Doubao import Doubao
    except ImportError as e:
        print(f"✗ 豆包模块导入失败: {e}")
        return False
    else:
        print("✓ 豆包模块导入成功")

    try:
        import llm
        print(f"✓ LLM包导入成功,可用模型: {llm.AVAILABLE_MODELS}")
    except ImportError as e:
        print(f"✗ LLM包导入失败: {e}")
    return True
def test_llm_config_loading():
    """Load config/llm_config.json the way llm.py does.

    Returns:
        The parsed config dict, or ``{"model_type": "qwen"}`` when the file is
        missing or cannot be loaded.
    """
    print("\n=== LLM配置加载测试 ===")
    try:
        config_path = project_root / "config" / "llm_config.json"
        if not config_path.exists():
            print("✗ 配置文件不存在,使用默认配置")
            return {"model_type": "qwen"}
        with open(config_path, 'r', encoding='utf-8') as handle:
            config = json.load(handle)
        print("✓ 配置加载成功")
        print(f"  模型类型: {config.get('model_type')}")
        print(f"  配置项: {list(config.keys())}")
        return config
    except Exception as exc:
        print(f"✗ 配置加载失败: {exc}")
        return {"model_type": "qwen"}
def test_doubao_instantiation():
    """Instantiate the Doubao model with a placeholder API key.

    A pre-existing ``DOUBAO_API_KEY`` is saved before the placeholder is set
    and restored afterwards, whether or not instantiation succeeds.

    Returns:
        True if the instance was created, False otherwise.
    """
    print("\n=== 豆包实例化测试 ===")
    try:
        from llm.Doubao import Doubao
    except Exception as e:
        print(f"✗ 豆包实例化失败: {e}")
        return False

    previous_key = os.environ.get('DOUBAO_API_KEY')
    # Placeholder key so construction does not depend on a real credential.
    os.environ['DOUBAO_API_KEY'] = 'test_key_for_validation'
    try:
        doubao = Doubao()
        print("✓ 豆包实例化成功")
        print(f"  配置文件路径: {doubao.config_file}")
        print(f"  API基础URL: {doubao.base_url}")
        print(f"  模型名称: {doubao.model}")
        return True
    except Exception as e:
        print(f"✗ 豆包实例化失败: {e}")
        return False
    finally:
        # Put the environment back exactly as it was found.
        if previous_key is None:
            os.environ.pop('DOUBAO_API_KEY', None)
        else:
            os.environ['DOUBAO_API_KEY'] = previous_key
def test_integration_flow():
    """Walk through the model-selection flow used by llm.py.

    Returns:
        True when the flow completed, False if an exception was raised.
    """
    print("\n=== 集成流程测试 ===")
    try:
        # Same steps as llm.py: load the config, then pick the model.
        model_type = test_llm_config_loading().get("model_type", "qwen")
        print(f"根据配置选择模型: {model_type}")
        if model_type == "doubao":
            verdict = "✓ 将使用豆包模型处理请求"
        elif model_type == "qwen":
            verdict = "✓ 将使用通义千问模型处理请求"
        else:
            verdict = f"⚠ 未知模型类型: {model_type}"
        print(verdict)
        return True
    except Exception as exc:
        print(f"✗ 集成流程测试失败: {exc}")
        return False
def main():
    """Run all integration checks and print a summary plus usage notes.

    The success summary is printed only when the instantiation and
    integration-flow checks both report success; otherwise a failure line
    is printed instead.
    """
    print("豆包模型集成测试")
    print("=" * 50)
    # Run all checks.
    test_config_files()
    if not test_module_import():
        print("\n模块导入失败,停止测试")
        return
    test_llm_config_loading()
    instantiation_ok = test_doubao_instantiation()
    flow_ok = test_integration_flow()
    print("\n=== 测试总结 ===")
    if instantiation_ok and flow_ok:
        print("✓ 豆包模型已成功集成到项目中")
        print("✓ 配置文件结构正确")
        print("✓ 模块导入正常")
    else:
        print("✗ 部分测试未通过,请检查上方输出")
    print("\n使用说明:")
    print("1. 设置环境变量 DOUBAO_API_KEY 为您的豆包API密钥")
    print("2. 在 config/llm_config.json 中设置 model_type 为 'doubao'")
    print("3. 根据需要修改 config/doubao_config.json 中的人物设定")
    print("4. 重启应用即可使用豆包模型")
if __name__ == "__main__":
main()
\ No newline at end of file
# AIfeng/2025-01-27
"""
FunASR服务连接测试脚本
用于验证本地FunASR WebSocket服务是否可以正常连接
使用方法:
1. 先启动FunASR服务:python -u web/asr/funasr/ASR_server.py --host "127.0.0.1" --port 10197 --ngpu 0
2. 运行此测试脚本:python test_funasr_connection.py
"""
import asyncio
import websockets
import json
import os
import wave
import numpy as np
from pathlib import Path
class FunASRConnectionTest:
    """Connectivity smoke test for a FunASR WebSocket server.

    Every ``test_*`` coroutine opens its own connection to ``self.uri`` and
    prints a progress report; ``run_all_tests`` chains the checks and stops
    at the first mandatory one that fails.
    """

    def __init__(self, host="127.0.0.1", port=10197):
        self.host, self.port = host, port
        self.uri = f"ws://{host}:{port}"

    async def test_basic_connection(self):
        """Open and close one WebSocket connection.

        Returns:
            bool: True when the connection was established, False when it
            was refused or any other error occurred.
        """
        print(f"🔍 测试连接到 {self.uri}")
        try:
            async with websockets.connect(self.uri):
                print("✅ FunASR WebSocket服务连接成功")
                return True
        except ConnectionRefusedError:
            print("❌ 连接被拒绝,请确认FunASR服务已启动")
            print(" 启动命令: python -u web/asr/funasr/ASR_server.py --host \"127.0.0.1\" --port 10197 --ngpu 0")
        except Exception as e:
            print(f"❌ 连接失败: {e}")
        return False

    def create_test_wav(self, filename="test_audio.wav", duration=2, sample_rate=16000):
        """Write a mono 16-bit WAV file holding a 440 Hz sine tone.

        Args:
            filename: Path of the file to create.
            duration: Length of the tone in seconds.
            sample_rate: Samples per second.

        Returns:
            The ``filename`` that was written.
        """
        sample_count = int(sample_rate * duration)
        timeline = np.linspace(0, duration, sample_count, endpoint=False)
        tone_hz = 440  # concert pitch A4
        waveform = np.sin(2 * np.pi * tone_hz * timeline) * 0.3
        # Scale to the signed 16-bit range
        pcm = (waveform * 32767).astype(np.int16)
        with wave.open(filename, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono
            wav_file.setsampwidth(2)   # 2 bytes = 16 bit
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm.tobytes())
        print(f"📁 创建测试音频文件: {filename}")
        return filename

    async def test_audio_recognition(self):
        """Send the path of a generated tone file and wait up to 10 s for a reply.

        A timeout is reported but still counts as success; only a
        connection or send error makes the check fail. The generated file
        is removed afterwards.

        Returns:
            bool: False when an exception other than the receive timeout
            occurred, True otherwise.
        """
        print("\n🎵 测试音频识别功能")
        test_file = self.create_test_wav()
        absolute_path = os.path.abspath(test_file)
        try:
            async with websockets.connect(self.uri) as websocket:
                print("✅ 连接成功,发送音频文件路径")
                # The request carries only the file path
                message = {"url": absolute_path}
                await websocket.send(json.dumps(message))
                print(f"📤 发送消息: {message}")
                try:
                    reply = await asyncio.wait_for(websocket.recv(), timeout=10)
                except asyncio.TimeoutError:
                    print("⏰ 等待响应超时(10秒)")
                    print(" 这可能是正常的,因为测试音频是纯音调,无法识别为文字")
                else:
                    print(f"📥 收到识别结果: {reply}")
                # Reaching this point means the connection itself worked
                return True
        except Exception as e:
            print(f"❌ 音频识别测试失败: {e}")
            return False
        finally:
            # Remove the generated fixture
            if os.path.exists(test_file):
                os.remove(test_file)
                print(f"🗑️ 清理测试文件: {test_file}")

    async def test_real_audio_files(self):
        """Submit the bundled sample recordings and summarise the replies.

        Files that do not exist are skipped. For every file that is sent
        one result record is collected with status ``success``,
        ``received``, ``timeout`` or ``error``.

        Returns:
            bool: True when at least one result record was collected.
        """
        print("\n🎤 测试实际音频文件识别")
        audio_files = ["yunxi.mp3", "yunxia.mp3", "yunyang.mp3"]
        results = []
        for audio_file in audio_files:
            file_path = os.path.abspath(audio_file)
            if not os.path.exists(file_path):
                print(f"⚠️ 音频文件不存在: {file_path}")
                continue
            print(f"\n🎵 测试音频文件: {audio_file}")
            try:
                async with websockets.connect(self.uri) as websocket:
                    print(f"✅ 连接成功,发送音频文件: {audio_file}")
                    message = {"url": file_path}
                    await websocket.send(json.dumps(message))
                    print(f"📤 发送消息: {message}")
                    try:
                        response = await asyncio.wait_for(websocket.recv(), timeout=30)
                    except asyncio.TimeoutError:
                        print(f"⏰ 等待响应超时(30秒)- {audio_file}")
                        results.append({
                            'file': audio_file,
                            'status': 'timeout',
                        })
                    else:
                        print(f"📥 识别结果: {response}")
                        # Try to interpret the reply as JSON
                        try:
                            result_data = json.loads(response)
                        except json.JSONDecodeError:
                            print(f"📄 非JSON响应: {response}")
                            results.append({
                                'file': audio_file,
                                'response': response,
                                'status': 'received',
                            })
                        else:
                            if isinstance(result_data, dict) and 'text' in result_data:
                                recognized_text = result_data['text']
                                print(f"🎯 识别文本: {recognized_text}")
                                results.append({
                                    'file': audio_file,
                                    'text': recognized_text,
                                    'status': 'success',
                                })
                            else:
                                print(f"📄 原始响应: {response}")
                                results.append({
                                    'file': audio_file,
                                    'response': response,
                                    'status': 'received',
                                })
            except Exception as e:
                print(f"❌ 测试 {audio_file} 失败: {e}")
                results.append({
                    'file': audio_file,
                    'error': str(e),
                    'status': 'error',
                })
            # Pause between files to go easy on the server
            await asyncio.sleep(1)

        # Summary
        print("\n" + "="*50)
        print("📊 实际音频文件测试总结:")
        for position, record in enumerate(results, 1):
            print(f"\n{position}. 文件: {record['file']}")
            outcome = record['status']
            if outcome == 'success':
                print(f" ✅ 识别成功: {record['text']}")
            elif outcome == 'received':
                print(f" 📥 收到响应: {record.get('response', 'N/A')}")
            elif outcome == 'timeout':
                print(" ⏰ 响应超时")
            elif outcome == 'error':
                print(f" ❌ 测试失败: {record.get('error', 'N/A')}")
        return len(results) > 0

    async def test_message_format(self):
        """Send a few differently shaped messages over one connection.

        Dict messages are JSON-encoded, anything else is sent as-is. A
        failure to send a single message is reported and does not fail the
        check.

        Returns:
            bool: False only when the connection could not be used at all.
        """
        print("\n📋 测试消息格式兼容性")
        try:
            async with websockets.connect(self.uri) as websocket:
                # A valid-looking request, an unrelated dict, and a raw string
                test_messages = [
                    {"url": "nonexistent.wav"},
                    {"test": "message"},
                    "invalid_json",
                ]
                for i, msg in enumerate(test_messages, 1):
                    try:
                        wire_text = json.dumps(msg) if isinstance(msg, dict) else msg
                        await websocket.send(wire_text)
                        print(f"✅ 消息 {i} 发送成功: {msg}")
                        # Short pause so messages do not pile up
                        await asyncio.sleep(0.5)
                    except Exception as e:
                        print(f"⚠️ 消息 {i} 发送失败: {e}")
                return True
        except Exception as e:
            print(f"❌ 消息格式测试失败: {e}")
            return False

    def check_dependencies(self):
        """Try to import every module this script relies on.

        Returns:
            bool: True when all imports succeeded, False (after printing a
            pip hint) when at least one module is missing.
        """
        print("🔍 检查依赖项...")
        required_modules = ['websockets', 'asyncio', 'json', 'wave', 'numpy']
        missing_modules = []
        for module in required_modules:
            try:
                __import__(module)
            except ImportError:
                print(f"❌ {module} (缺失)")
                missing_modules.append(module)
            else:
                print(f"✅ {module}")
        if not missing_modules:
            print("✅ 所有依赖项检查通过")
            return True
        print(f"\n⚠️ 缺失依赖项: {', '.join(missing_modules)}")
        print("安装命令: pip install " + ' '.join(missing_modules))
        return False

    def check_funasr_server_file(self):
        """Report whether ``web/asr/funasr/ASR_server.py`` exists relative to the cwd.

        Returns:
            bool: True when the file exists.
        """
        print("\n📁 检查FunASR服务器文件...")
        server_path = Path("web/asr/funasr/ASR_server.py")
        if not server_path.exists():
            print(f"❌ 未找到服务器文件: {server_path.absolute()}")
            print(" 请确认文件路径是否正确")
            return False
        print(f"✅ 找到服务器文件: {server_path.absolute()}")
        return True

    async def run_all_tests(self):
        """Run every check in sequence.

        The dependency, server-file, basic-connection, tone-recognition and
        message-format checks are mandatory: the first one that fails ends
        the run with False. The real-audio check is informational only.

        Returns:
            bool: True when all mandatory checks passed.
        """
        rule = "\n" + "="*50
        print("🚀 开始FunASR连接测试\n")
        # Dependencies
        if not self.check_dependencies():
            return False
        # Server script
        if not self.check_funasr_server_file():
            return False
        # Basic connection
        print(rule)
        if not await self.test_basic_connection():
            return False
        # Recognition with a generated tone
        print(rule)
        if not await self.test_audio_recognition():
            return False
        # Real recordings (result intentionally ignored)
        print(rule)
        await self.test_real_audio_files()
        # Message formats
        print(rule)
        if not await self.test_message_format():
            return False
        print(rule)
        print("🎉 所有测试完成!FunASR服务连接正常")
        print("\n💡 集成建议:")
        print(" 1. 服务使用WebSocket协议,非gRPC")
        print(" 2. 默认监听端口: 10197")
        print(" 3. 消息格式: JSON字符串,包含'url'字段指向音频文件路径")
        print(" 4. 可以集成到现有项目的ASR模块中")
        return True
async def main():
    """Run the full FunASR connection test suite.

    Returns:
        int: 0 when ``run_all_tests()`` succeeded, 1 otherwise (after
        printing a failure hint).
    """
    tester = FunASRConnectionTest()
    if await tester.run_all_tests():
        return 0
    print("\n❌ 测试失败,请检查FunASR服务状态")
    return 1


if __name__ == "__main__":
    # sys.exit is used instead of the interactive exit() helper, which is
    # installed by the site module and is not guaranteed to exist.
    import sys

    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        print("\n⏹️ 测试被用户中断")
        exit_code = 1
    except Exception as e:
        print(f"\n💥 测试过程中发生错误: {e}")
        exit_code = 1
    sys.exit(exit_code)
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
FunASR集成测试脚本
测试新的FunASRClient与项目的集成效果
"""
import os
import sys
import time
import threading
from pathlib import Path
# 添加项目路径
sys.path.append(os.path.dirname(__file__))
from funasr_asr import FunASRClient
from web.asr.funasr import FunASR
import util
class TestFunASRIntegration:
    """Integration checks for FunASRClient and the FunASR compatibility wrapper.

    Each ``test_*`` method records its outcome through ``log_test_result``;
    ``run_all_tests`` runs them all and prints a summary.
    """

    def __init__(self):
        # (test name, passed, message) tuples in the order they were logged
        self.test_results = []
        self.test_audio_files = [
            "yunxi.mp3",
            "yunxia.mp3",
            "yunyang.mp3",
        ]

    def log_test_result(self, test_name: str, success: bool, message: str = ""):
        """Record one outcome and print it as ``[status] name - message``."""
        status = "✓ 通过" if success else "✗ 失败"
        result = f"[{status}] {test_name}"
        if message:
            result += f" - {message}"
        self.test_results.append((test_name, success, message))
        print(result)

    @staticmethod
    def _require_attributes(obj, names):
        """Raise AssertionError naming every attribute in *names* that *obj* lacks.

        An explicit raise is used instead of ``assert`` so the check still
        runs under ``python -O`` and the failure message says what is missing.
        """
        missing = [name for name in names if not hasattr(obj, name)]
        if missing:
            raise AssertionError(f"缺少属性: {', '.join(missing)}")

    def test_funasr_client_creation(self):
        """Construct a FunASRClient from a minimal options object.

        Returns:
            The client on success, None when construction or the attribute
            check failed.
        """
        try:
            class SimpleOpt:
                def __init__(self):
                    self.username = "test_user"

            client = FunASRClient(SimpleOpt())
            # Basic attributes the other checks rely on
            self._require_attributes(client, ('server_url', 'connected', 'running'))
            self.log_test_result("FunASRClient创建", True, "客户端创建成功")
            return client
        except Exception as e:
            self.log_test_result("FunASRClient创建", False, f"错误: {e}")
            return None

    def test_compatibility_wrapper(self):
        """Construct the FunASR wrapper and verify its legacy method names.

        Returns:
            The wrapper on success, None on failure.
        """
        try:
            funasr = FunASR("test_user")
            self._require_attributes(
                funasr,
                ('start', 'end', 'send', 'add_frame', 'set_message_callback'),
            )
            self.log_test_result("兼容性包装器", True, "所有兼容性方法存在")
            return funasr
        except Exception as e:
            self.log_test_result("兼容性包装器", False, f"错误: {e}")
            return None

    def test_callback_mechanism(self):
        """Feed a result into the wrapper and expect the callback within 1 s."""
        try:
            funasr = FunASR("test_user")
            callback_called = threading.Event()
            received_message = []

            def test_callback(message):
                received_message.append(message)
                callback_called.set()

            funasr.set_message_callback(test_callback)
            # Simulate an incoming recognition result
            test_message = "测试识别结果"
            funasr._handle_result(test_message)
            if not callback_called.wait(timeout=1.0):
                self.log_test_result("回调机制", False, "回调超时")
            elif received_message and received_message[0] == test_message:
                self.log_test_result("回调机制", True, "回调函数正常工作")
            else:
                self.log_test_result("回调机制", False, "回调消息不匹配")
        except Exception as e:
            self.log_test_result("回调机制", False, f"错误: {e}")

    def test_audio_file_existence(self):
        """Check which sample recordings exist relative to the cwd.

        Returns:
            list: The file names from ``self.test_audio_files`` that exist.
        """
        existing_files = [f for f in self.test_audio_files if os.path.exists(f)]
        missing_files = [f for f in self.test_audio_files if f not in existing_files]
        if existing_files:
            self.log_test_result(
                "音频文件检查",
                True,
                f"找到 {len(existing_files)} 个文件: {', '.join(existing_files)}"
            )
        if missing_files:
            self.log_test_result(
                "音频文件缺失",
                False,
                f"缺少 {len(missing_files)} 个文件: {', '.join(missing_files)}"
            )
        return existing_files

    def test_connection_simulation(self):
        """Start and stop a client, checking its ``running`` flag each time.

        Note: creating the client logs its own "FunASRClient创建" result.
        """
        client = None
        started = False
        try:
            client = self.test_funasr_client_creation()
            if not client:
                return
            client.start()
            started = True
            time.sleep(0.5)  # give the connection a moment
            if client.running:
                self.log_test_result("客户端启动", True, "客户端成功启动")
            else:
                self.log_test_result("客户端启动", False, "客户端启动失败")
            client.stop()
            started = False
            time.sleep(0.5)
            if not client.running:
                self.log_test_result("客户端停止", True, "客户端成功停止")
            else:
                self.log_test_result("客户端停止", False, "客户端停止失败")
        except Exception as e:
            self.log_test_result("连接模拟", False, f"错误: {e}")
        finally:
            # Never leave a started client running when the check aborted
            # between start() and stop().
            if started:
                try:
                    client.stop()
                except Exception as e:
                    print(f"停止客户端时出错: {e}")

    def test_message_queue(self):
        """Put one message on the client's queue and read it back.

        Note: creating the client logs its own "FunASRClient创建" result.
        """
        try:
            client = self.test_funasr_client_creation()
            if not client:
                return
            test_message = {"test": "message"}
            client.message_queue.put(test_message)
            if client.message_queue.empty():
                self.log_test_result("消息队列", False, "消息队列为空")
                return
            retrieved_message = client.message_queue.get_nowait()
            if retrieved_message == test_message:
                self.log_test_result("消息队列", True, "消息队列正常工作")
            else:
                self.log_test_result("消息队列", False, "消息内容不匹配")
        except Exception as e:
            self.log_test_result("消息队列", False, f"错误: {e}")

    def test_config_loading(self):
        """Verify that ``config_util`` exposes the ASR connection settings.

        Values are read from ``cfg.config`` when the module has that
        attribute, otherwise from module attributes. A key whose value is
        None, or whose lookup raises, counts as missing.
        """
        try:
            import config_util as cfg
            required_configs = [
                'local_asr_ip',
                'local_asr_port',
                'asr_timeout',
                'asr_reconnect_delay',
                'asr_max_reconnect_attempts',
            ]
            missing_configs = []
            for config_key in required_configs:
                try:
                    if hasattr(cfg, 'config'):
                        value = cfg.config.get(config_key)
                    else:
                        value = getattr(cfg, config_key, None)
                    if value is None:
                        missing_configs.append(config_key)
                except Exception:
                    # Narrowed from a bare except so Ctrl-C still interrupts.
                    missing_configs.append(config_key)
            if not missing_configs:
                self.log_test_result("配置加载", True, "所有必需配置项存在")
            else:
                self.log_test_result(
                    "配置加载",
                    False,
                    f"缺少配置项: {', '.join(missing_configs)}"
                )
        except Exception as e:
            self.log_test_result("配置加载", False, f"错误: {e}")

    def run_all_tests(self):
        """Run every check and print totals plus the list of failures.

        Returns:
            bool: True when every recorded result passed.
        """
        print("\n" + "="*60)
        print("FunASR集成测试开始")
        print("="*60)
        self.test_config_loading()
        self.test_funasr_client_creation()
        self.test_compatibility_wrapper()
        self.test_callback_mechanism()
        self.test_message_queue()
        self.test_audio_file_existence()
        self.test_connection_simulation()

        print("\n" + "="*60)
        print("测试总结")
        print("="*60)
        passed_tests = sum(1 for _, success, _ in self.test_results if success)
        total_tests = len(self.test_results)
        # Guard the ratio so an empty result list cannot divide by zero.
        success_rate = passed_tests / total_tests * 100 if total_tests else 0.0
        print(f"总测试数: {total_tests}")
        print(f"通过测试: {passed_tests}")
        print(f"失败测试: {total_tests - passed_tests}")
        print(f"成功率: {success_rate:.1f}%")
        failed_tests = [(name, msg) for name, success, msg in self.test_results if not success]
        if failed_tests:
            print("\n失败的测试:")
            for name, msg in failed_tests:
                print(f" - {name}: {msg}")
        print("\n" + "="*60)
        return passed_tests == total_tests
def main():
    """Run the FunASR integration checks and report the overall outcome.

    Returns:
        int: 0 when every check passed, 1 otherwise.
    """
    tester = TestFunASRIntegration()
    success = tester.run_all_tests()
    if success:
        print("\n🎉 所有测试通过!FunASR集成准备就绪。")
    else:
        print("\n⚠️ 部分测试失败,请检查相关配置和依赖。")
    return 0 if success else 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit() helper, which is meant
    # for interactive sessions and may be absent.
    sys.exit(main())
\ No newline at end of file
#!/usr/bin/env python3
# AIfeng/2024-12-19
# WebSocket通信测试服务器
import asyncio
import json
import time
import weakref
from aiohttp import web, WSMsgType
import aiohttp_cors
from typing import Dict
# Module-level state
websocket_connections: Dict[int, weakref.WeakSet] = {}  # sessionid -> its WebSocket connections


# WebSocket push helper
async def broadcast_message_to_session(sessionid: int, message_type: str, content: str, source: str = "测试服务器"):
    """Send one ``chat_message`` envelope to every open socket of a session.

    Args:
        sessionid: Key into ``websocket_connections``.
        message_type: Stored in the envelope's ``data.message_type``.
        content: Message text.
        source: Label of the sender.

    Nothing is sent when the session is unknown. Sockets whose ``closed``
    flag is set are skipped, and a failure on one socket is printed without
    stopping delivery to the others.
    """
    if sessionid not in websocket_connections:
        print(f'[SessionID:{sessionid}] No WebSocket connections found')
        return

    payload = {
        "sessionid": sessionid,
        "message_type": message_type,
        "content": content,
        "source": source,
        "timestamp": time.time(),
    }
    envelope = {"type": "chat_message", "data": payload}

    # Snapshot the members so the set may change while we await sends
    targets = list(websocket_connections[sessionid])
    print(f'[SessionID:{sessionid}] Broadcasting to {len(targets)} connections')

    for socket in targets:
        try:
            if socket.closed:
                continue
            await socket.send_str(json.dumps(envelope))
            print(f'[SessionID:{sessionid}] Message sent to WebSocket: {message_type}')
        except Exception as e:
            print(f'[SessionID:{sessionid}] Failed to send WebSocket message: {e}')
# WebSocket handler
def _unregister_websocket(sessionid, ws):
    """Remove *ws* from the session's connection set and drop the set once empty."""
    connections = websocket_connections.get(sessionid)
    if connections is None:
        return
    connections.discard(ws)
    if not connections:
        del websocket_connections[sessionid]


async def websocket_handler(request):
    """Serve one WebSocket connection of the test server.

    Text frames are parsed as JSON:

    * ``{"type": "login", "sessionid": N}`` registers the socket under
      session ``N`` (default 0) and answers with ``login_success``. A later
      login under a different id moves the socket to the new session.
    * ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.

    Invalid JSON and per-message errors are printed and the connection is
    kept. When the connection ends the socket is removed from the registry,
    and a session left without sockets is deleted, so the registry does not
    grow with every session id ever seen.

    Returns:
        The ``web.WebSocketResponse`` that was served.
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    sessionid = None  # set only after a successful registration
    print('New WebSocket connection established')
    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                try:
                    data = json.loads(msg.data)
                    print(f'Received WebSocket message: {data}')
                    if data.get('type') == 'login':
                        requested = data.get('sessionid', 0)
                        # Create the session's set on first use and join it.
                        websocket_connections.setdefault(requested, weakref.WeakSet()).add(ws)
                        # A re-login under another id must stop receiving the
                        # previous session's broadcasts.
                        if sessionid is not None and sessionid != requested:
                            _unregister_websocket(sessionid, ws)
                        sessionid = requested
                        print(f'[SessionID:{sessionid}] WebSocket client logged in')
                        # Login acknowledgement
                        await ws.send_str(json.dumps({
                            "type": "login_success",
                            "sessionid": sessionid,
                            "message": "WebSocket连接成功"
                        }))
                    elif data.get('type') == 'ping':
                        # Heartbeat
                        await ws.send_str(json.dumps({"type": "pong"}))
                        print('Sent pong response')
                except json.JSONDecodeError:
                    print('Invalid JSON received from WebSocket')
                except Exception as e:
                    print(f'Error processing WebSocket message: {e}')
            elif msg.type == WSMsgType.ERROR:
                print(f'WebSocket error: {ws.exception()}')
                break
    except Exception as e:
        print(f'WebSocket connection error: {e}')
    finally:
        if sessionid is not None:
            _unregister_websocket(sessionid, ws)
            print(f'[SessionID:{sessionid}] WebSocket connection closed')
        else:
            print('WebSocket connection closed')
    return ws
# Mock of the /human endpoint
async def human(request):
    """Accept a chat/echo request and relay it to the session's WebSockets.

    The JSON body may carry ``sessionid`` (default 0), ``text`` (default
    empty) and ``type`` (default ``"echo"``). The user's text is always
    broadcast first; an ``echo`` request then broadcasts the same text
    again as the echo, a ``chat`` request broadcasts a canned reply.

    Returns:
        A JSON ``web.Response``: ``code`` 0 on success, ``code`` -1 with
        the error text in ``msg`` when anything raised.
    """
    try:
        body = await request.json()
        sessionid = body.get('sessionid', 0)
        user_message = body.get('text', '')
        message_type = body.get('type', 'echo')
        print(f'[SessionID:{sessionid}] Received {message_type} message: {user_message}')

        # Relay what the user sent
        await broadcast_message_to_session(sessionid, message_type, user_message, "用户")

        if message_type == 'echo':
            # Relay the echo
            await broadcast_message_to_session(sessionid, 'echo', user_message, "回音")
        elif message_type == 'chat':
            # Canned stand-in for an AI answer
            ai_response = f"这是对 '{user_message}' 的AI回复"
            await broadcast_message_to_session(sessionid, 'chat', ai_response, "AI助手")

        outcome = {"code": 0, "data": "ok", "message": "消息已处理并推送"}
        return web.Response(content_type="application/json", text=json.dumps(outcome))
    except Exception as e:
        print(f'Error in human endpoint: {e}')
        failure = {"code": -1, "msg": str(e)}
        return web.Response(content_type="application/json", text=json.dumps(failure))
# Application factory
def create_app():
    """Build the aiohttp application of the WebSocket test server.

    Registers ``POST /human``, ``GET /ws`` and a static handler serving the
    ``web`` directory at ``/``, then attaches a CORS policy (credentials
    allowed, all headers allowed and exposed, for every origin) to each
    registered route.

    Returns:
        The configured ``web.Application``.
    """
    application = web.Application()

    # Routes
    router = application.router
    router.add_post("/human", human)
    router.add_get("/ws", websocket_handler)
    router.add_static('/', path='web')

    # CORS policy applied to every origin
    open_policy = aiohttp_cors.ResourceOptions(
        allow_credentials=True,
        expose_headers="*",
        allow_headers="*",
    )
    cors = aiohttp_cors.setup(application, defaults={"*": open_policy})

    # Attach the policy to each route (iterate over a snapshot)
    for route in list(router.routes()):
        cors.add(route)
    return application
if __name__ == '__main__':
    # Script entry: announce the endpoints, then serve on all interfaces.
    app = create_app()
    print(
        'Starting WebSocket test server on http://localhost:8000',
        'WebSocket endpoint: ws://localhost:8000/ws',
        'HTTP endpoint: http://localhost:8000/human',
        'Test page: http://localhost:8000/websocket_test.html',
        sep='\n',
    )
    web.run_app(app, host='0.0.0.0', port=8000)
\ No newline at end of file
... ... @@ -9,7 +9,7 @@ import _thread as thread
from aliyunsdkcore.client import AcsClient
from aliyunsdkcore.request import CommonRequest
from core import wsa_server
from core import get_web_instance, get_instance
from scheduler.thread_manager import MyThread
from utils import util
from utils import config_util as cfg
... ... @@ -92,19 +92,37 @@ class ALiNls:
if name == 'SentenceEnd':
self.done = True
self.finalResults = data['payload']['result']
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
if wsa_server.get_instance().is_connected_human(self.username):
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
wsa_server.get_instance().add_cmd(content)
if get_web_instance().is_connected(self.username):
import asyncio
# 创建chat_message直接推送
chat_message = {
"type": "chat_message",
"sender": "回音",
"content": self.finalResults, # 修复字段名:panelMsg -> content
"Username": self.username,
"model_info": "ALiNls"
}
# 使用直接发送方法,避免wsa_command封装
asyncio.create_task(get_web_instance().send_direct_message(chat_message))
# Human客户端通知改为日志记录(避免重复通知当前服务)
util.log(1, f"ALiNls识别结果[{self.username}]: {self.finalResults}")
ws.close()#TODO
elif name == 'TranscriptionResultChanged':
self.finalResults = data['payload']['result']
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
if wsa_server.get_instance().is_connected_human(self.username):
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
wsa_server.get_instance().add_cmd(content)
if get_web_instance().is_connected(self.username):
import asyncio
# 创建chat_message直接推送
chat_message = {
"type": "chat_message",
"sender": "回音",
"content": self.finalResults, # 修复字段名:panelMsg -> content
"Username": self.username,
"model_info": "ALiNls"
}
# 使用直接发送方法,避免wsa_command封装
asyncio.create_task(get_web_instance().send_direct_message(chat_message))
# Human客户端通知改为日志记录(避免重复通知当前服务)
util.log(1, f"ALiNls识别变化[{self.username}]: {self.finalResults}")
except Exception as e:
print(e)
... ...
... ... @@ -14,14 +14,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from funasr_asr import FunASRClient
# 修复导入路径
try:
from core import wsa_server
from core import get_web_instance, get_instance
except ImportError:
# 如果core模块不存在,创建一个模拟的wsa_server
class MockWSAServer:
def get_web_instance(self):
return MockWebInstance()
def get_instance(self):
return MockInstance()
# 如果core模块不存在,创建一个模拟的函数
def get_web_instance():
return MockWebInstance()
def get_instance():
return MockInstance()
class MockWebInstance:
def is_connected(self, username):
... ... @@ -34,8 +34,6 @@ except ImportError:
return False
def add_cmd(self, cmd):
print(f"Mock Human: {cmd}")
wsa_server = MockWSAServer()
try:
from utils import config_util as cfg
... ... @@ -92,11 +90,20 @@ class FunASR:
if self.on_message_callback:
self.on_message_callback(message)
if wsa_server.get_web_instance().is_connected(self.username):
wsa_server.get_web_instance().add_cmd({"panelMsg": self.finalResults, "Username" : self.username})
if wsa_server.get_instance().is_connected_human(self.username):
content = {'Topic': 'human', 'Data': {'Key': 'log', 'Value': self.finalResults}, 'Username' : self.username}
wsa_server.get_instance().add_cmd(content)
if get_web_instance().is_connected(self.username):
import asyncio
# 创建chat_message直接推送
chat_message = {
"type": "chat_message",
"sender": "回音",
"content": self.finalResults, # 修复字段名:panelMsg -> content
"Username": self.username,
"model_info": "FunASR"
}
# 使用直接发送方法,避免wsa_command封装
asyncio.create_task(get_web_instance().send_direct_message(chat_message))
# Human客户端通知改为日志记录(避免重复通知当前服务)
util.log(1, f"FunASR识别结果[{self.username}]: {self.finalResults}")
except Exception as e:
print(e)
... ...
... ... @@ -20,25 +20,42 @@ args = parser.parse_args()
# 初始化模型
print("model loading")
asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
punc_model="ct-punc-c", punc_model_revision="v2.0.4",
device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True)
# ,disable_update=True
print("model loaded")
try:
asr_model = AutoModel(model="paraformer-zh", model_revision="v2.0.4",
vad_model="fsmn-vad", vad_model_revision="v2.0.4",
punc_model="ct-punc-c", punc_model_revision="v2.0.4",
device=f"cuda:{args.gpu_id}" if args.ngpu else "cpu", disable_update=True)
# ,disable_update=True
print("model loaded")
except Exception as e:
print(f"模型加载失败: {e}")
import traceback
traceback.print_exc()
exit(1)
websocket_users = {}
task_queue = asyncio.Queue()
# 分块会话管理
chunk_sessions = {} # {user_id: {filename, chunks, total_chunks, received_chunks, temp_file}}
async def ws_serve(websocket, path):
global websocket_users
global websocket_users, chunk_sessions
user_id = id(websocket)
websocket_users[user_id] = websocket
try:
async for message in websocket:
if isinstance(message, str):
data = json.loads(message)
if 'url' in data:
await task_queue.put((websocket, data['url']))
# 处理分块协议
if 'type' in data:
await handle_chunked_protocol(websocket, data, user_id)
# 处理传统协议
elif 'url' in data:
# 处理文件URL
await task_queue.put((websocket, data['url'], 'url'))
elif 'audio_data' in data:
# 处理音频数据
await task_queue.put((websocket, data, 'audio_data'))
except websockets.exceptions.ConnectionClosed as e:
logger.info(f"Connection closed: {e.reason}")
except Exception as e:
... ... @@ -47,14 +64,28 @@ async def ws_serve(websocket, path):
logger.info(f"Cleaning up connection for user {user_id}")
if user_id in websocket_users:
del websocket_users[user_id]
# 清理分块会话
if user_id in chunk_sessions:
await cleanup_chunk_session(user_id)
await websocket.close()
logger.info("WebSocket closed")
async def worker():
while True:
websocket, url = await task_queue.get()
task_data = await task_queue.get()
websocket = task_data[0]
if websocket.open:
await process_wav_file(websocket, url)
if len(task_data) == 3: # 新格式: (websocket, data, type)
data, data_type = task_data[1], task_data[2]
if data_type == 'url':
await process_wav_file(websocket, data)
elif data_type == 'audio_data':
await process_audio_data(websocket, data)
elif data_type == 'chunked_audio':
await process_chunked_audio(websocket, data)
else: # 兼容旧格式: (websocket, url)
await process_wav_file(websocket, task_data[1])
else:
logger.info("WebSocket connection is already closed when trying to process file")
task_queue.task_done()
... ... @@ -77,8 +108,226 @@ async def process_wav_file(websocket, url):
except Exception as e:
print(f"Error during model.generate: {e}")
finally:
if os.path.exists(wav_path):
os.remove(wav_path)
# 注释掉文件删除操作,保留缓存文件用于测试
# if os.path.exists(wav_path):
# os.remove(wav_path)
print(f"保留音频文件用于测试: {wav_path}")
async def handle_chunked_protocol(websocket, data, user_id):
    """Handle one message of the chunked audio upload protocol.

    The message kind is ``data['type']``:

    * ``audio_start`` opens a session for *user_id*, reserves a temporary
      ``.wav`` path and answers with a ``ready`` status. A session that is
      still open for the same user is cleaned up first.
    * ``audio_chunk`` stores the base64-decoded ``chunk_data`` under
      ``chunk_index``. A chunk that is sent twice replaces the earlier copy
      and is counted once.
    * ``audio_end`` verifies the chunk count, writes the chunks in index
      order to the temporary path and queues the file on ``task_queue`` as
      ``'chunked_audio'``. The session entry is removed; the file is left
      for the consumer.

    Any exception is reported to the client as an ``error`` JSON message
    and the user's session is cleaned up.

    Args:
        websocket: Connection used for replies.
        data: Decoded JSON message.
        user_id: Key of the sender's session in ``chunk_sessions``.
    """
    global chunk_sessions
    try:
        msg_type = data.get('type')
        filename = data.get('filename', 'unknown.wav')

        if msg_type == 'audio_start':
            # Drop a previous unfinished upload so its temp file is not orphaned.
            if user_id in chunk_sessions:
                await cleanup_chunk_session(user_id)
            total_chunks = data.get('total_chunks', 0)
            total_size = data.get('total_size', 0)
            print(f"开始接收分块音频: {filename}, 总分块数: {total_chunks}, 总大小: {total_size} bytes")
            # Reserve a temporary path. Only the name is needed (the file is
            # rewritten with open() at audio_end), so the handle is closed
            # right away instead of leaking once the session entry is deleted.
            import tempfile
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            temp_file.close()
            chunk_sessions[user_id] = {
                'filename': filename,
                'total_chunks': total_chunks,
                'total_size': total_size,
                'received_chunks': 0,
                'temp_file': temp_file,
                'temp_path': temp_file.name,
                'chunks_data': {}  # {chunk_index: chunk_bytes}
            }
            await websocket.send(json.dumps({"status": "ready", "message": f"准备接收 {total_chunks} 个分块"}))

        elif msg_type == 'audio_chunk':
            if user_id not in chunk_sessions:
                await websocket.send(json.dumps({"error": "未找到分块会话,请先发送audio_start"}))
                return
            session = chunk_sessions[user_id]
            chunk_index = data.get('chunk_index', -1)
            chunk_data = data.get('chunk_data', '')
            is_last = data.get('is_last', False)
            if chunk_index >= 0 and chunk_data:
                import base64
                session['chunks_data'][chunk_index] = base64.b64decode(chunk_data)
                # Count distinct indices so a resent chunk is not counted twice.
                session['received_chunks'] = len(session['chunks_data'])
                # Progress feedback; a declared total of 0 must not divide by zero.
                total = session['total_chunks']
                progress = (session['received_chunks'] / total) * 100 if total else 0.0
                if session['received_chunks'] % 10 == 0 or is_last:
                    print(f"接收进度: {progress:.1f}% ({session['received_chunks']}/{session['total_chunks']})")

        elif msg_type == 'audio_end':
            if user_id not in chunk_sessions:
                await websocket.send(json.dumps({"error": "未找到分块会话"}))
                return
            session = chunk_sessions[user_id]
            # Refuse to assemble an incomplete upload.
            if session['received_chunks'] != session['total_chunks']:
                await websocket.send(json.dumps({
                    "error": f"分块不完整: 期望{session['total_chunks']}, 实际{session['received_chunks']}"
                }))
                await cleanup_chunk_session(user_id)
                return
            # Reassemble in index order.
            print(f"重组音频文件: {session['filename']}")
            with open(session['temp_path'], 'wb') as f:
                for i in range(session['total_chunks']):
                    if i in session['chunks_data']:
                        f.write(session['chunks_data'][i])
                    else:
                        print(f"警告: 分块 {i} 缺失")
            # Hand the file over to the worker queue.
            reconstructed_data = {
                'audio_file_path': session['temp_path'],
                'filename': session['filename']
            }
            await task_queue.put((websocket, reconstructed_data, 'chunked_audio'))
            # Forget the session but keep the temp file for the consumer.
            del chunk_sessions[user_id]
            print(f"分块音频重组完成: {session['filename']}")
    except Exception as e:
        print(f"处理分块协议时出错: {e}")
        await websocket.send(json.dumps({"error": f"分块处理错误: {str(e)}"}))
        if user_id in chunk_sessions:
            await cleanup_chunk_session(user_id)
async def cleanup_chunk_session(user_id):
    """Discard the chunk-upload session of *user_id*, if there is one.

    Closes the session's temporary file object, deletes the temporary file
    from disk when it still exists, and removes the session entry. Errors
    during the file cleanup are printed; the entry is removed regardless.
    """
    if user_id not in chunk_sessions:
        return
    stale = chunk_sessions[user_id]
    try:
        # Close, then delete, the temporary file
        if 'temp_file' in stale:
            stale['temp_file'].close()
        has_path = 'temp_path' in stale
        if has_path and os.path.exists(stale['temp_path']):
            os.remove(stale['temp_path'])
            print(f"清理临时文件: {stale['temp_path']}")
    except Exception as e:
        print(f"清理分块会话时出错: {e}")
    finally:
        del chunk_sessions[user_id]
async def process_chunked_audio(websocket, data):
    """Run recognition on an audio file reassembled from chunks.

    Args:
        websocket: Connection the result is sent to.
        data: Dict with ``audio_file_path`` and an optional ``filename``
            (used for logging only).

    If ``data/hotword.txt`` exists its lines are joined with spaces and
    passed to the recogniser as ``hotword``. The recognised text is sent
    as a plain string. When the recogniser returns nothing, or a result
    without ``text``, a failure string is sent instead, so the client is
    never left waiting for a reply. The audio file is deliberately kept on
    disk for test analysis.
    """
    audio_file_path = None
    try:
        audio_file_path = data.get('audio_file_path')
        filename = data.get('filename', 'chunked_audio.wav')
        if not audio_file_path or not os.path.exists(audio_file_path):
            await websocket.send(json.dumps({"error": "重组音频文件不存在"}))
            return
        print(f"处理分块重组音频: {filename}, 文件路径: {audio_file_path}")
        # Hot-word configuration
        param_dict = {"sentence_timestamp": False}
        try:
            with open("data/hotword.txt", "r", encoding="utf-8") as f:
                hotword = " ".join(line.strip() for line in f)
            print(f"热词:{hotword}")
            param_dict["hotword"] = hotword
        except FileNotFoundError:
            print("热词文件不存在,跳过热词配置")
        # Recognition
        res = asr_model.generate(input=audio_file_path, is_final=True, **param_dict)
        if websocket.open:
            if res and 'text' in res[0]:
                result_text = res[0]['text']
                print(f"分块音频识别结果: {result_text}")
                await websocket.send(result_text)
            else:
                # Also covers an empty result list, which previously
                # produced no reply at all.
                await websocket.send("识别失败:无法获取文本结果")
    except Exception as e:
        print(f"处理分块音频时出错: {e}")
        if websocket.open:
            await websocket.send(f"分块音频识别错误: {str(e)}")
    finally:
        # File deletion is intentionally disabled; the file is kept for testing.
        if audio_file_path is not None:
            print(f"保留分块重组音频文件用于测试: {audio_file_path}")
async def process_audio_data(websocket, data):
    """Run recognition on base64-encoded audio carried inside a message.

    Args:
        websocket: Connection the result is sent to.
        data: Dict with ``audio_data`` (base64 text) and an optional
            ``filename`` (used for logging only).

    The decoded bytes are written to a temporary ``.wav`` file, which is
    handed to the recogniser. If ``data/hotword.txt`` exists its lines are
    joined with spaces and passed as ``hotword``. The recognised text is
    sent as a plain string. When the recogniser returns nothing, or a
    result without ``text``, a failure string is sent instead, so the
    client is never left waiting for a reply. The temporary file is
    deliberately kept on disk for test analysis.
    """
    import base64
    import tempfile

    temp_path = None
    try:
        audio_data = data.get('audio_data')
        filename = data.get('filename', 'audio.wav')
        if not audio_data:
            await websocket.send(json.dumps({"error": "No audio data provided"}))
            return
        # Decode the payload and persist it for the recogniser
        audio_bytes = base64.b64decode(audio_data)
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_file.write(audio_bytes)
            temp_path = temp_file.name
        print(f"处理音频文件: {filename}, 临时路径: {temp_path}")
        # Hot-word configuration
        param_dict = {"sentence_timestamp": False}
        try:
            with open("data/hotword.txt", "r", encoding="utf-8") as f:
                hotword = " ".join(line.strip() for line in f)
            print(f"热词:{hotword}")
            param_dict["hotword"] = hotword
        except FileNotFoundError:
            print("热词文件不存在,跳过热词配置")
        # Recognition
        res = asr_model.generate(input=temp_path, is_final=True, **param_dict)
        if websocket.open:
            if res and 'text' in res[0]:
                result_text = res[0]['text']
                print(f"识别结果: {result_text}")
                await websocket.send(result_text)
            else:
                # Also covers an empty result list, which previously
                # produced no reply at all.
                await websocket.send("识别失败:无法获取文本结果")
    except Exception as e:
        print(f"处理音频数据时出错: {e}")
        if websocket.open:
            await websocket.send(f"识别错误: {str(e)}")
    finally:
        # File deletion is intentionally disabled; the file is kept for testing.
        if temp_path is not None:
            print(f"保留临时音频文件用于测试: {temp_path}")
async def main():
server = await websockets.serve(ws_serve, args.host, args.port, ping_interval=10)
... ... @@ -87,6 +336,7 @@ async def main():
try:
# 保持服务器运行,直到被手动中断
print(f"ASR服务器已启动,监听地址: {args.host}:{args.port}")
print("注意:此版本已禁用文件自动删除功能,用于测试分析")
await asyncio.Future() # 永久等待,直到程序被中断
except asyncio.CancelledError:
print("服务器正在关闭...")
... ... @@ -101,4 +351,11 @@ async def main():
await server.wait_closed()
# 使用 asyncio 运行主函数
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
logging.info("服务器已关闭")
except Exception as e:
logging.error(f"服务器启动失败: {e}")
import traceback
traceback.print_exc()
\ No newline at end of file
... ...