app_websocket_migration.py 8.99 KB
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
app.py WebSocket功能迁移脚本
将app.py中的WebSocket功能迁移到统一架构
"""

import asyncio
import json
import time
import weakref
from typing import Any, Dict, Optional

from aiohttp import web

from logger import logger

from .websocket_router import get_websocket_router, get_websocket_compatibility_api
from .asr_websocket_service import get_asr_service
from .digital_human_websocket_service import get_digital_human_service


class AppWebSocketMigration:
    """Expose the legacy app.py WebSocket API on top of the unified WebSocket stack.

    The original app.py owned its WebSocket/ASR bookkeeping directly. This class
    keeps the same entry points (names and signatures) but delegates the work to
    the unified router, the ASR service and the digital-human service returned by
    their module-level getters.
    """

    def __init__(self):
        self.router = get_websocket_router()
        self.compatibility_api = get_websocket_compatibility_api()
        self.asr_service = get_asr_service()
        self.digital_human_service = get_digital_human_service()

        # Compatibility attributes mirroring the original app.py globals so that
        # existing callers still find them. This class never adds entries to
        # either dict; asr_connections is only pruned in cleanup_session().
        self.websocket_connections = {}
        self.asr_connections = {}

    async def initialize(self):
        """Initialize the underlying router."""
        await self.router.initialize()
        logger.info('WebSocket迁移组件初始化完成')

    async def shutdown(self):
        """Shut down the underlying router."""
        await self.router.shutdown()
        logger.info('WebSocket迁移组件已关闭')

    def setup_routes(self, app: web.Application):
        """Register the WebSocket routes that replace the ones in the original app.py.

        ``/ws`` is wired by the unified router. ``/ws_legacy`` is kept as an
        alias whose handler forwards to the router's own handler.
        """
        self.router.setup_routes(app, '/ws')
        app.router.add_get('/ws_legacy', self._legacy_websocket_handler)

    async def _legacy_websocket_handler(self, request: web.Request):
        """Serve ``/ws_legacy`` by delegating to the unified router's handler."""
        return await self.router.websocket_handler(request)

    # ----- Compatibility methods (same signatures as the original app.py) -----

    async def broadcast_message_to_session(self, sessionid: int, message_type: str,
                                           content: str, source: str = "数字人回复",
                                           model_info: Optional[str] = None,
                                           request_source: str = "页面"):
        """Send a ``'chat_message'`` to session *sessionid* through the router.

        Args:
            sessionid: Session id; converted to ``str`` for the router.
            message_type: Caller-defined message type, forwarded verbatim.
            content: Message text.
            source: Label describing where the message came from.
            model_info: Optional model description; ``None`` when not supplied.
            request_source: Label describing where the request originated.

        Returns:
            Whatever ``router.send_to_session`` returns.
        """
        message_data = {
            "sessionid": sessionid,
            "message_type": message_type,
            "content": content,
            "source": source,
            "model_info": model_info,
            "request_source": request_source,
            # Wall-clock epoch seconds. The event loop's clock used previously is
            # monotonic with an arbitrary origin, so it is meaningless to clients
            # as a timestamp.
            "timestamp": time.time(),
        }
        return await self.router.send_to_session(str(sessionid), 'chat_message', message_data)

    async def handle_asr_audio_data(self, data: Dict[str, Any], sessionid: int, ws):
        """Forward a legacy ASR audio message for *ws* to the ASR service.

        Does nothing when the router has no session registered for *ws*.
        Only ``data['audio_data']`` (or ``None`` if absent) and *sessionid* are
        forwarded.
        """
        message_data = {
            'audio_data': data.get('audio_data'),
            'sessionid': sessionid,
        }
        if self.router.manager.get_session(ws):
            await self.asr_service._handle_asr_audio_data(ws, message_data)

    async def handle_start_asr_recognition(self, sessionid: int, ws):
        """Ask the ASR service to start recognition for *ws*, if it has a session."""
        if self.router.manager.get_session(ws):
            await self.asr_service._handle_start_asr_recognition(ws, {'sessionid': sessionid})

    async def handle_stop_asr_recognition(self, sessionid: int, ws):
        """Ask the ASR service to stop recognition for *ws*, if it has a session."""
        if self.router.manager.get_session(ws):
            await self.asr_service._handle_stop_asr_recognition(ws, {'sessionid': sessionid})

    async def send_asr_result(self, sessionid: int, result: Dict[str, Any]):
        """Send an ``'asr_result'`` message built from *result* to *sessionid*.

        Missing keys default to ``''`` (text), ``False`` (is_final) and
        ``0.0`` (confidence).
        """
        return await self.router.send_to_session(str(sessionid), 'asr_result', {
            "text": result.get('text', ''),
            "is_final": result.get('is_final', False),
            "confidence": result.get('confidence', 0.0),
        })

    async def send_normal_asr_result(self, sessionid: int, result: Dict[str, Any]):
        """Send *result* unchanged to *sessionid*; the caller owns its content and shape."""
        return await self.router.send_raw_to_session(str(sessionid), result)

    def get_websocket_connections(self):
        """Return a snapshot ``{session_id: websocket}`` of the router's live sessions.

        ``router.manager._sessions`` maps a session id to a set of session
        objects. Only one connection per id is reported: whichever the set
        yields first. Ids whose set is empty are skipped.
        """
        return {
            session_id: next(iter(session_set)).websocket
            for session_id, session_set in self.router.manager._sessions.items()
            if session_set
        }

    def get_session_count(self):
        """Return the session count reported by the compatibility API."""
        return self.compatibility_api.get_session_count()

    async def cleanup_session(self, sessionid: int):
        """Drop all state held for *sessionid*.

        Removes the entry from ``asr_connections`` if present. It then removes
        every router connection registered under ``str(sessionid)``, using the
        same ``session_id -> set of sessions`` layout that
        get_websocket_connections() reads. The previous loop treated the dict as
        ``websocket -> session``, which contradicts that layout.
        """
        self.asr_connections.pop(sessionid, None)

        session_set = self.router.manager._sessions.get(str(sessionid))
        if not session_set:
            return
        # Copy first: remove_session() may mutate the set while we iterate.
        for session in list(session_set):
            await self.router.manager.remove_session(session.websocket)

    def get_migration_stats(self) -> Dict[str, Any]:
        """Collect statistics from the router and services plus compatibility dict sizes."""
        return {
            "router_stats": self.router.get_router_stats(),
            "asr_stats": self.asr_service.get_asr_stats(),
            "digital_human_stats": self.digital_human_service.get_digital_human_stats(),
            # Sizes of the compatibility dicts, not of the router's live sessions.
            "compatibility_sessions": len(self.websocket_connections),
            "compatibility_asr_connections": len(self.asr_connections),
        }


# Process-wide singleton backing get_app_websocket_migration(); created lazily.
_migration_instance = None


def get_app_websocket_migration() -> AppWebSocketMigration:
    """Return the shared AppWebSocketMigration, constructing it on the first call."""
    global _migration_instance
    if _migration_instance is not None:
        return _migration_instance
    _migration_instance = AppWebSocketMigration()
    return _migration_instance


async def initialize_app_websocket_migration():
    """Initialize the shared migration adapter and hand it back.

    Fetches (or lazily creates) the singleton, awaits its initialize(), then
    returns that same instance.
    """
    adapter = get_app_websocket_migration()
    await adapter.initialize()
    return adapter


async def shutdown_app_websocket_migration():
    """Shut down the shared migration adapter, if one exists, and forget it.

    Does nothing when no adapter has been created. The module-level reference
    is cleared only after shutdown() returns, so a later
    get_app_websocket_migration() call builds a fresh instance.
    """
    global _migration_instance
    current = _migration_instance
    if current is None:
        return
    await current.shutdown()
    _migration_instance = None


def setup_app_websocket_routes(app: web.Application):
    """Register the migration adapter's WebSocket routes on *app* (convenience helper).

    Returns the shared AppWebSocketMigration whose setup_routes() was called.
    """
    adapter = get_app_websocket_migration()
    adapter.setup_routes(app)
    return adapter


# Compatibility functions (same call signatures as the original app.py helpers)
async def broadcast_message_to_session(sessionid: int, message_type: str, content: str,
                                       source: str = "数字人回复", model_info: Optional[str] = None,
                                       request_source: str = "页面"):
    """Module-level equivalent of the original app.py message-push helper.

    Delegates to ``AppWebSocketMigration.broadcast_message_to_session`` on the
    shared adapter and returns its result. ``model_info`` is annotated as
    ``Optional[str]`` because its default is ``None`` (PEP 484 disallows the
    implicit-Optional form ``str = None``).
    """
    migration = get_app_websocket_migration()
    return await migration.broadcast_message_to_session(
        sessionid, message_type, content,
        source=source, model_info=model_info, request_source=request_source,
    )


async def handle_asr_audio_data(data: Dict[str, Any], sessionid: int, ws):
    """Module-level equivalent of the original app.py ASR audio-data handler.

    Forwards to ``AppWebSocketMigration.handle_asr_audio_data`` on the shared
    adapter and returns whatever that coroutine returns.
    """
    return await get_app_websocket_migration().handle_asr_audio_data(data, sessionid, ws)


async def handle_start_asr_recognition(sessionid: int, ws):
    """Module-level equivalent of the original app.py "start ASR" handler.

    Forwards to ``AppWebSocketMigration.handle_start_asr_recognition`` on the
    shared adapter.
    """
    return await get_app_websocket_migration().handle_start_asr_recognition(sessionid, ws)


async def handle_stop_asr_recognition(sessionid: int, ws):
    """Module-level equivalent of the original app.py "stop ASR" handler.

    Forwards to ``AppWebSocketMigration.handle_stop_asr_recognition`` on the
    shared adapter.
    """
    return await get_app_websocket_migration().handle_stop_asr_recognition(sessionid, ws)


async def send_asr_result(sessionid: int, result: Dict[str, Any]):
    """Module-level equivalent of the original app.py ASR-result sender.

    Forwards to ``AppWebSocketMigration.send_asr_result`` on the shared adapter,
    which repackages *result* into a fixed ``text``/``is_final``/``confidence``
    payload.
    """
    return await get_app_websocket_migration().send_asr_result(sessionid, result)

async def send_normal_asr_result(sessionid: int, result: Dict[str, Any]):
    """Send *result* to *sessionid* unchanged via the shared adapter.

    Unlike send_asr_result, no payload is built here: the caller decides the
    content and structure, and ``AppWebSocketMigration.send_normal_asr_result``
    passes it to the router's raw send.
    """
    return await get_app_websocket_migration().send_normal_asr_result(sessionid, result)


# Accessors standing in for the original app.py module-level globals
def get_websocket_connections():
    """Return the shared adapter's ``websocket_connections`` dict (the object itself, not a copy).

    NOTE(review): nothing in this module adds entries to this dict. The live
    ``session_id -> websocket`` mapping is built by
    ``AppWebSocketMigration.get_websocket_connections()``. Confirm which of
    the two the callers actually expect.
    """
    return get_app_websocket_migration().websocket_connections


def get_asr_connections():
    """Return the shared adapter's ``asr_connections`` dict (the object itself, not a copy)."""
    return get_app_websocket_migration().asr_connections