# asr_websocket_service.py (11.5 KB)
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
ASR WebSocket服务实现
从app.py中抽离的ASR相关WebSocket功能
"""

import asyncio
import json
import weakref
from typing import Dict, Any, Optional
from aiohttp import web
from logger import logger
from .websocket_service_base import WebSocketServiceBase
from .unified_websocket_manager import WebSocketSession


class ASRWebSocketService(WebSocketServiceBase):
    """WebSocket service for ASR (speech recognition) messages.

    Tracks one ASR connection per session id in ``asr_connections`` and
    registers handlers on ``self.manager`` for login, heartbeat, audio data
    and start/stop recognition messages.
    """

    # Seconds slept between heartbeat sweeps.
    HEARTBEAT_CHECK_INTERVAL = 40
    # ``timeout`` value passed to ``manager.get_expired_sessions``.
    HEARTBEAT_TIMEOUT = 60
    # Seconds slept after an unexpected error in the heartbeat sweep.
    HEARTBEAT_ERROR_BACKOFF = 5

    def __init__(self):
        super().__init__("asr_service")
        # session id -> ASR connection object
        self.asr_connections: Dict[str, Any] = {}
        self._heartbeat_task = None

    async def _register_message_handlers(self):
        """Register the ASR-related message handlers with the manager."""
        handlers = {
            'login': self._handle_login,
            'heartbeat': self._handle_heartbeat,
            'asr_audio_data': self._handle_asr_audio_data,
            'start_asr_recognition': self._handle_start_asr_recognition,
            'stop_asr_recognition': self._handle_stop_asr_recognition,
        }
        for message_type, handler in handlers.items():
            self.manager.register_message_handler(message_type, handler)

    async def _start_background_tasks(self):
        """Start the heartbeat monitor as a background task."""
        self._heartbeat_task = self.add_background_task(self._heartbeat_monitor())

    async def _close_asr_connection(self, asr_conn: Any) -> None:
        """Await ``asr_conn.close()`` when the object has a ``close`` attribute.

        Exceptions raised by ``close()`` propagate to the caller, which
        decides how to log or report them.
        """
        if hasattr(asr_conn, 'close'):
            await asr_conn.close()

    async def _cleanup(self):
        """Close every tracked ASR connection and empty the registry."""
        # Iterate over a snapshot: awaiting close() yields control, and the
        # dict may be modified by other handlers in the meantime.
        for session_id, asr_conn in list(self.asr_connections.items()):
            try:
                await self._close_asr_connection(asr_conn)
            except Exception as e:
                logger.error(f'关闭ASR连接失败 {session_id}: {e}')

        self.asr_connections.clear()

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Run the base-class hook, then drop the session's ASR connection."""
        await super()._on_session_disconnected(session)

        if session.session_id not in self.asr_connections:
            return

        asr_conn = self.asr_connections.pop(session.session_id)
        try:
            await self._close_asr_connection(asr_conn)
            logger.info(f'已清理ASR连接: {session.session_id}')
        except Exception as e:
            logger.error(f'清理ASR连接失败 {session.session_id}: {e}')

    async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Handle a ``login`` message by re-keying the session.

        Reads ``data['sessionid']``. When present, the session and the
        manager mapping are updated to the new id and a success
        ``login_response`` is sent; otherwise an error ``login_response``
        is sent. Does nothing if the manager has no session for *websocket*.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = data.get('sessionid')
        if not session_id:
            await session.send_message({
                "type": "login_response",
                "data": {
                    "status": "error",
                    "message": "缺少sessionid"
                }
            })
            return

        old_session_id = session.session_id
        session.session_id = session_id

        # Keep the manager's session mapping in sync with the new id.
        self.manager._update_session_id(websocket, old_session_id, session_id)

        # An ASR connection opened before login is stored under the old id.
        # Every later lookup uses the new id, so it would never be found or
        # closed again; release it here instead of leaking it.
        if old_session_id != session_id and old_session_id in self.asr_connections:
            stale_conn = self.asr_connections.pop(old_session_id)
            try:
                await self._close_asr_connection(stale_conn)
                logger.info(f'已清理ASR连接: {old_session_id}')
            except Exception as e:
                logger.error(f'清理ASR连接失败 {old_session_id}: {e}')

        await session.send_message({
            "type": "login_response",
            "data": {
                "status": "success",
                "sessionid": session_id,
                "message": "登录成功"
            }
        })

        logger.info(f'用户登录成功: {session_id}')

    async def _handle_heartbeat(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Refresh the session's heartbeat and reply with ``heartbeat_response``."""
        session = self.manager.get_session(websocket)
        if session:
            session.update_last_heartbeat()
            await session.send_message({
                "type": "heartbeat_response",
                "data": {"status": "ok"}
            })

    async def _handle_asr_audio_data(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Forward ``data['audio_data']`` to the session's ASR connection.

        Sends an ``error`` message when the audio payload is missing/empty,
        when no ASR connection is registered for the session, or when
        forwarding raises.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id
        audio_data = data.get('audio_data')

        if not audio_data:
            await session.send_message({
                "type": "error",
                "data": {"message": "缺少音频数据"}
            })
            return

        # Connections are only created by the start-recognition handler;
        # this handler never creates one on demand.
        asr_conn = self.asr_connections.get(session_id)
        if not asr_conn:
            logger.warning(f'ASR连接不存在: {session_id}')
            await session.send_message({
                "type": "error",
                "data": {"message": "ASR连接未建立"}
            })
            return

        try:
            await self._forward_audio_to_asr(asr_conn, audio_data)
        except Exception as e:
            logger.error(f'转发音频数据失败 {session_id}: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"音频处理失败: {str(e)}"}
            })

    async def _handle_start_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Create an ASR connection for the session and register it.

        Replies with ``asr_recognition_started`` on success and ``error``
        when the connection could not be created or an exception occurred.
        A connection already registered for the same session is closed
        after the new one has been stored.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id

        try:
            asr_conn = await self._create_asr_connection(session_id)
            if not asr_conn:
                await session.send_message({
                    "type": "error",
                    "data": {"message": "创建ASR连接失败"}
                })
                return

            previous_conn = self.asr_connections.get(session_id)
            self.asr_connections[session_id] = asr_conn

            # A repeated start would otherwise drop the only reference to
            # the earlier connection without closing it.
            if previous_conn is not None and previous_conn is not asr_conn:
                try:
                    await self._close_asr_connection(previous_conn)
                except Exception as e:
                    logger.error(f'清理ASR连接失败 {session_id}: {e}')

            await session.send_message({
                "type": "asr_recognition_started",
                "data": {
                    "status": "success",
                    "message": "ASR识别已开始"
                }
            })

            logger.info(f'ASR识别已开始: {session_id}')
        except Exception as e:
            logger.error(f'开始ASR识别失败 {session_id}: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"开始识别失败: {str(e)}"}
            })

    async def _handle_stop_asr_recognition(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Close and unregister the session's ASR connection.

        Always answers with ``asr_recognition_stopped`` (also when nothing
        was running), except when closing raises, in which case an
        ``error`` message is sent instead.
        """
        session = self.manager.get_session(websocket)
        if not session:
            return

        session_id = session.session_id

        if session_id not in self.asr_connections:
            await session.send_message({
                "type": "asr_recognition_stopped",
                "data": {
                    "status": "success",
                    "message": "ASR识别未在运行"
                }
            })
            return

        asr_conn = self.asr_connections.pop(session_id)
        try:
            await self._close_asr_connection(asr_conn)

            await session.send_message({
                "type": "asr_recognition_stopped",
                "data": {
                    "status": "success",
                    "message": "ASR识别已停止"
                }
            })

            logger.info(f'ASR识别已停止: {session_id}')
        except Exception as e:
            logger.error(f'停止ASR识别失败 {session_id}: {e}')
            await session.send_message({
                "type": "error",
                "data": {"message": f"停止识别失败: {str(e)}"}
            })

    async def _create_asr_connection(self, session_id: str):
        """Create the ASR connection for *session_id*.

        Currently returns a ``MockASRConnection`` wired to
        ``_on_asr_result``; returns ``None`` if construction raises.
        """
        # TODO: replace the mock with a connection to the real ASR backend
        # (e.g. FunASR), for example:
        #   asr_conn = await create_funasr_connection(session_id, self._on_asr_result)
        logger.info(f'创建ASR连接: {session_id}')

        try:
            asr_conn = MockASRConnection(session_id, self._on_asr_result)
            return asr_conn
        except Exception as e:
            logger.error(f'创建ASR连接失败 {session_id}: {e}')
            return None

    async def _forward_audio_to_asr(self, asr_conn, audio_data):
        """Pass *audio_data* to ``asr_conn.send_audio`` if it exists.

        Logs a warning and drops the data when the connection object has
        no ``send_audio`` attribute.
        """
        if hasattr(asr_conn, 'send_audio'):
            await asr_conn.send_audio(audio_data)
        else:
            logger.warning('ASR连接不支持发送音频数据')

    async def _on_asr_result(self, session_id: str, result: Dict[str, Any]):
        """Callback for ASR results: broadcast *result* as ``asr_result``.

        Failures are logged and not re-raised.
        """
        try:
            await self.broadcast_to_session(session_id, 'asr_result', result)
            logger.debug(f'ASR结果已发送: {session_id}')
        except Exception as e:
            logger.error(f'发送ASR结果失败 {session_id}: {e}')

    async def _heartbeat_monitor(self):
        """Periodically close sessions the manager reports as expired.

        Runs until cancelled. Each sweep sleeps
        ``HEARTBEAT_CHECK_INTERVAL`` seconds, asks the manager for sessions
        expired under ``HEARTBEAT_TIMEOUT`` and closes them one by one.
        """
        while True:
            try:
                await asyncio.sleep(self.HEARTBEAT_CHECK_INTERVAL)

                expired_sessions = self.manager.get_expired_sessions(
                    timeout=self.HEARTBEAT_TIMEOUT
                )
                for session in expired_sessions:
                    logger.info(f'会话心跳超时,断开连接: {session.session_id}')
                    # Isolate each close so one failing session does not
                    # postpone closing the rest until the next sweep.
                    try:
                        await session.close()
                    except asyncio.CancelledError:
                        raise
                    except Exception as e:
                        logger.error(f'心跳监控异常: {e}')

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f'心跳监控异常: {e}')
                await asyncio.sleep(self.HEARTBEAT_ERROR_BACKOFF)

    def get_asr_stats(self) -> Dict[str, Any]:
        """Return the number of tracked ASR connections and their session ids."""
        return {
            "active_asr_connections": len(self.asr_connections),
            "asr_sessions": list(self.asr_connections.keys())
        }


class MockASRConnection:
    """Stand-in ASR connection used for testing.

    Each ``send_audio`` call waits briefly and then reports a fixed
    recognition result through ``result_callback``.
    """

    def __init__(self, session_id: str, result_callback):
        self.session_id = session_id
        # Awaited as ``result_callback(session_id, result)``; a falsy value
        # disables result delivery.
        self.result_callback = result_callback
        self.is_closed = False

    async def send_audio(self, audio_data):
        """Simulate recognition of *audio_data* and deliver a canned result.

        Returns immediately, without invoking the callback, when the
        connection has already been closed. *audio_data* itself is ignored.
        """
        if self.is_closed:
            return

        # Simulated processing latency.
        await asyncio.sleep(0.1)

        result = {
            "text": "模拟识别结果",
            "confidence": 0.95,
            # This coroutine always runs inside a loop, so ask for the
            # running loop explicitly instead of using get_event_loop().
            "timestamp": asyncio.get_running_loop().time()
        }

        if self.result_callback:
            await self.result_callback(self.session_id, result)

    async def close(self):
        """Mark the connection as closed and log it."""
        self.is_closed = True
        logger.info(f'模拟ASR连接已关闭: {self.session_id}')


# Module-level ASR service instance, created when this module is imported
asr_service = ASRWebSocketService()


def get_asr_service() -> ASRWebSocketService:
    """Return the module-level ``asr_service`` instance."""
    return asr_service