websocket_router.py 10.4 KB
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
WebSocket路由管理器
统一管理所有WebSocket服务的路由和初始化
"""

import asyncio
import json
import time
import uuid
from typing import Dict, Any, Optional

from aiohttp import web, WSMsgType

from logger import logger
from .unified_websocket_manager import get_unified_manager
from .websocket_service_base import get_service_registry
from .asr_websocket_service import get_asr_service
from .digital_human_websocket_service import get_digital_human_service


class WebSocketRouter:
    """WebSocket routing manager.

    Owns the unified session manager and the service registry, registers the
    built-in services (ASR, digital human, WSA) and exposes a single aiohttp
    handler that decodes incoming JSON text frames and hands them to the
    manager for dispatch.
    """

    def __init__(self):
        self.manager = get_unified_manager()
        self.service_registry = get_service_registry()
        self.is_initialized = False

    async def initialize(self):
        """Register and initialize every service; returns early if already initialized."""
        if self.is_initialized:
            return

        logger.info('初始化WebSocket路由器...')

        # Services have to be registered before the registry can initialize them.
        await self._register_services()
        await self.service_registry.initialize_all()

        self.is_initialized = True
        logger.info('WebSocket路由器初始化完成')

    async def shutdown(self):
        """Shut down all services, then the session manager; no-op if not initialized.

        The manager shutdown and the reset of ``is_initialized`` happen even if
        a service fails to shut down; that failure is then re-raised.
        """
        if not self.is_initialized:
            return

        logger.info('关闭WebSocket路由器...')

        try:
            await self.service_registry.shutdown_all()
        finally:
            # Run the manager shutdown regardless of service failures.
            try:
                await self.manager.shutdown()
            finally:
                self.is_initialized = False

        logger.info('WebSocket路由器已关闭')

    async def _register_services(self):
        """Register the ASR, digital-human and WSA services with the registry."""
        logger.info('注册WebSocket服务...')

        # ASR service
        asr_service = get_asr_service()
        self.service_registry.register_service(asr_service)

        # Digital-human service
        digital_human_service = get_digital_human_service()
        self.service_registry.register_service(digital_human_service)

        # WSA service (imported lazily, as in the original module layout)
        from .wsa_websocket_service import WSAWebSocketService, initialize_wsa_service
        wsa_service = WSAWebSocketService(self.manager)
        self.service_registry.register_service(wsa_service)

        # Wire up the WSA compatibility interface to this service instance.
        initialize_wsa_service(wsa_service)

        logger.info(f'已注册 {len(self.service_registry.list_services())} 个WebSocket服务')

    @staticmethod
    def _resolve_session_id(request: web.Request) -> str:
        """Return the client-supplied ``X-Session-ID`` header, or a new unique id.

        An absent or empty header falls back to a random UUID so that
        connections opened within the same second never share an id.
        """
        return request.headers.get('X-Session-ID') or uuid.uuid4().hex

    async def websocket_handler(self, request: web.Request) -> web.WebSocketResponse:
        """Unified aiohttp WebSocket handler.

        Upgrades the request, registers a session with the manager, processes
        every TEXT frame as a JSON message and always removes the session when
        the connection ends.

        Args:
            request: The incoming aiohttp request to upgrade.

        Returns:
            The prepared ``web.WebSocketResponse``.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        session_id = self._resolve_session_id(request)
        session = self.manager.add_session(session_id, ws)
        logger.info(f'WebSocket连接建立: {session.session_id}')

        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    await self._process_text_frame(ws, session, msg.data)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f'WebSocket错误: {ws.exception()}')
                    break
            # aiohttp's async iterator stops on CLOSE/CLOSING/CLOSED instead of
            # yielding them, so the close is only observable once the loop exits.
            logger.info(f'WebSocket连接关闭: {session.session_id}')
        except ConnectionResetError:
            logger.warning(f'WebSocket连接被远程主机重置: {session.session_id}')
        except ConnectionAbortedError:
            logger.warning(f'WebSocket连接被中止: {session.session_id}')
        except Exception as e:
            logger.error(f'WebSocket处理异常: {e}')
        finally:
            # Clean up the session no matter how the connection ended.
            self.manager.remove_session(ws)

        return ws

    async def _process_text_frame(self, ws: web.WebSocketResponse, session: Any, raw: str):
        """Decode one TEXT frame and dispatch it, reporting failures to the client.

        Errors raised while sending the error reply itself propagate to the
        caller, which ends the connection loop.
        """
        try:
            data = json.loads(raw)
            await self._handle_message(ws, data)
        except json.JSONDecodeError as e:
            logger.error(f'JSON解析失败: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": "消息格式错误"}
            })
        except Exception as e:
            logger.error(f'消息处理失败: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"处理失败: {str(e)}"}
            })

    async def _send_error(self, ws: web.WebSocketResponse, text: str):
        """Send an ``error`` message to the session bound to *ws*, if one is registered."""
        session = self.manager.get_session(ws)
        if session:
            await session.send_message({
                "type": "error",
                "data": {"message": text}
            })

    async def _handle_message(self, ws: web.WebSocketResponse, data: Any):
        """Validate a decoded JSON message and forward it to the manager.

        Args:
            ws: The socket the message arrived on.
            data: The decoded JSON value. Only objects that carry a truthy
                ``type`` field are forwarded; anything else gets an error reply.
        """
        # json.loads may yield a list/str/number; only objects can carry a type.
        if not isinstance(data, dict):
            await self._send_error(ws, "消息格式错误")
            return

        if not data.get('type'):
            await self._send_error(ws, "缺少消息类型")
            return

        await self.manager.handle_websocket_message(ws, data)

    def get_router_stats(self) -> Dict[str, Any]:
        """Collect router, manager and registry stats plus per-service details.

        ``asr_stats`` / ``digital_human_stats`` are only present when the
        corresponding service is registered.
        """
        stats = {
            "initialized": self.is_initialized,
            "manager_stats": self.manager.get_session_stats(),
            "service_stats": self.service_registry.get_all_stats()
        }

        asr_service = self.service_registry.get_service("asr_service")
        if asr_service:
            stats["asr_stats"] = asr_service.get_asr_stats()

        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            stats["digital_human_stats"] = digital_human_service.get_digital_human_stats()

        return stats

    def setup_routes(self, app: web.Application, path: str = '/ws'):
        """Mount ``websocket_handler`` on *app* as a GET route at *path*."""
        app.router.add_get(path, self.websocket_handler)
        logger.info(f'WebSocket路由已设置: {path}')

    async def broadcast_system_message(self, message: str, level: str = 'info'):
        """Broadcast a ``system_message`` with the given text and level to all sessions."""
        await self.manager.broadcast_to_all('system_message', {
            'message': message,
            'level': level,
            # Event-loop clock (monotonic seconds), not a Unix epoch timestamp.
            'timestamp': asyncio.get_running_loop().time()
        }, source='system')

    async def send_to_session(self, session_id: str, message_type: str, content: Any):
        """Send a typed message to one session.

        *session_id* is normalized to ``str``, matching ``send_raw_to_session``,
        so callers holding integer ids reach the same session.
        """
        return await self.manager.broadcast_to_session(str(session_id), message_type, content, source='router')

    async def send_raw_to_session(self, session_id: str, message: Dict):
        """Send an already-formed message dict to one session without re-wrapping it."""
        return await self.manager.broadcast_raw_message_to_session(str(session_id), message)

    async def send_to_digital_human(self, human_id: str, message_type: str, content: Any):
        """Send a message to a digital human; returns False when that service is not registered."""
        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            return await digital_human_service.send_to_digital_human(human_id, message_type, content)
        return False

    async def get_asr_stats(self) -> Optional[Dict[str, Any]]:
        """Return ASR service stats, or None when the service is not registered."""
        asr_service = self.service_registry.get_service("asr_service")
        if asr_service:
            return asr_service.get_asr_stats()
        return None

    async def get_digital_human_stats(self) -> Optional[Dict[str, Any]]:
        """Return digital-human service stats, or None when the service is not registered."""
        digital_human_service = self.service_registry.get_service("digital_human_service")
        if digital_human_service:
            return digital_human_service.get_digital_human_stats()
        return None


# Process-wide router instance, created lazily by get_websocket_router().
_websocket_router = None


def get_websocket_router() -> WebSocketRouter:
    """Return the shared WebSocketRouter, constructing it on the first call."""
    global _websocket_router
    if _websocket_router is not None:
        return _websocket_router
    _websocket_router = WebSocketRouter()
    return _websocket_router


async def initialize_websocket_router():
    """Initialize the shared WebSocket router and return it.

    Repeated calls reuse the same router; its ``initialize()`` returns early
    once the router is already initialized.
    """
    shared_router = get_websocket_router()
    await shared_router.initialize()
    return shared_router


async def shutdown_websocket_router():
    """Shut down the shared router (if one exists) and drop the global reference."""
    global _websocket_router
    current = _websocket_router
    if not current:
        return
    await current.shutdown()
    _websocket_router = None


def setup_websocket_routes(app: web.Application, path: str = '/ws'):
    """Convenience wrapper: mount the shared router's handler on *app* at *path*.

    Returns the shared router instance.
    """
    shared = get_websocket_router()
    shared.setup_routes(app, path)
    return shared


# Compatibility interface
class WebSocketCompatibilityAPI:
    """Compatibility facade over the shared WebSocketRouter.

    Provides the simplified, dict-message based calls that existing code
    (e.g. app.py) expects.
    """

    def __init__(self):
        # Ensure the shared router exists up front (as before); the instance is
        # deliberately not cached here - see ``router``.
        get_websocket_router()

    @property
    def router(self) -> WebSocketRouter:
        """The current shared router.

        Resolved on every access instead of being captured at construction:
        this object is a long-lived singleton, whereas
        ``shutdown_websocket_router()`` discards the global router and a later
        ``get_websocket_router()`` creates a new one. A cached reference would
        keep routing through the shut-down instance.
        """
        return get_websocket_router()

    @staticmethod
    def _split_message(message: Dict[str, Any]):
        """Return ``(message_type, content)`` from a message dict.

        ``type`` defaults to ``'message'``; content is ``message['data']`` or,
        when that key is missing, the whole message.
        """
        return message.get('type', 'message'), message.get('data', message)

    async def broadcast_message_to_session(self, session_id: str, message: Dict[str, Any]):
        """Send a dict message to one session (app.py-compatible signature)."""
        message_type, content = self._split_message(message)
        return await self.router.send_to_session(session_id, message_type, content)

    async def broadcast_to_all_sessions(self, message: Dict[str, Any]):
        """Broadcast a dict message to every session."""
        message_type, content = self._split_message(message)
        return await self.router.manager.broadcast_to_all(message_type, content, source='compatibility')

    def get_active_sessions(self):
        """Return the keys of the manager's session table as a list."""
        # NOTE(review): reads the manager's private ``_sessions`` mapping; whether
        # its keys are session ids or socket objects is defined in
        # unified_websocket_manager - confirm, and prefer a public accessor.
        return list(self.router.manager._sessions.keys())

    def get_session_count(self):
        """Return the number of entries in the manager's session table."""
        return len(self.router.manager._sessions)

    async def send_asr_result(self, session_id: str, result: Dict[str, Any]):
        """Send an ``asr_result`` message to one session (app.py-compatible signature)."""
        return await self.router.send_to_session(session_id, 'asr_result', result)


# Process-wide compatibility API instance, created lazily.
_compatibility_api = None


def get_websocket_compatibility_api() -> WebSocketCompatibilityAPI:
    """Return the shared WebSocketCompatibilityAPI, constructing it on the first call."""
    global _compatibility_api
    if _compatibility_api is not None:
        return _compatibility_api
    _compatibility_api = WebSocketCompatibilityAPI()
    return _compatibility_api