# websocket_service_base.py (7.83 KB)
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
WebSocket服务抽象基类
为不同类型的WebSocket服务提供统一接口和生命周期管理
"""

import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from aiohttp import web
from logger import logger
from .unified_websocket_manager import get_unified_manager, WebSocketSession


class WebSocketServiceBase(ABC):
    """Abstract base class for WebSocket services.

    Provides a shared lifecycle (``initialize`` / ``shutdown``), access to the
    unified WebSocket manager, tracking of background tasks, and helpers for
    broadcasting and for registering message/event handlers. Subclasses must
    implement ``_register_message_handlers``; the other underscore hooks are
    optional overrides.
    """

    def __init__(self, service_name: str):
        self.service_name = service_name
        # Shared manager; all handler registration and broadcasting goes through it.
        self.manager = get_unified_manager()
        self.is_initialized = False
        # Live tasks started via add_background_task(); each task removes itself
        # from this list once it finishes (see _on_background_task_done).
        self._background_tasks: List[asyncio.Task] = []

    async def initialize(self):
        """Initialize the service; a no-op if it is already initialized.

        Runs, in order, ``_register_message_handlers``,
        ``_register_event_handlers`` and ``_start_background_tasks``. The
        service is marked initialized only after all three succeed.

        If any step raises (or the call is cancelled), background tasks started
        so far are cancelled before the exception propagates: ``shutdown`` is a
        no-op while ``is_initialized`` is False, so nothing else would stop
        them. Handlers already registered with the manager are not rolled back.
        """
        if self.is_initialized:
            return

        logger.info(f'初始化WebSocket服务: {self.service_name}')

        try:
            await self._register_message_handlers()
            await self._register_event_handlers()
            await self._start_background_tasks()
        except (Exception, asyncio.CancelledError):
            await self._cancel_background_tasks()
            raise

        self.is_initialized = True
        logger.info(f'WebSocket服务初始化完成: {self.service_name}')

    async def shutdown(self):
        """Shut the service down; a no-op if it is not initialized.

        Cancels and awaits every running background task, then calls the
        ``_cleanup`` hook and clears ``is_initialized``.
        """
        if not self.is_initialized:
            return

        logger.info(f'关闭WebSocket服务: {self.service_name}')

        await self._cancel_background_tasks()

        # Subclass-specific teardown.
        await self._cleanup()

        self.is_initialized = False
        logger.info(f'WebSocket服务已关闭: {self.service_name}')

    async def _cancel_background_tasks(self):
        """Cancel every still-running tracked task and wait for all of them.

        A task that raises something other than ``CancelledError`` while being
        cancelled is logged rather than propagated, so the remaining tasks are
        still stopped and the caller can continue its teardown.
        """
        # Snapshot and clear first: done callbacks remove tasks from the list
        # while we await below.
        tasks = list(self._background_tasks)
        self._background_tasks.clear()

        # Already-finished tasks had their outcome handled by the done callback.
        pending = [task for task in tasks if not task.done()]
        if not pending:
            return
        for task in pending:
            task.cancel()

        # return_exceptions=True collects per-task outcomes instead of raising
        # on the first failure. If *this* coroutine is cancelled, gather still
        # raises CancelledError, so outer cancellation is not swallowed.
        results = await asyncio.gather(*pending, return_exceptions=True)
        for result in results:
            if isinstance(result, BaseException) and not isinstance(result, asyncio.CancelledError):
                logger.error(f'[{self.service_name}] 后台任务关闭时出错: {result!r}')

    def _on_background_task_done(self, task: asyncio.Task):
        """Done-callback: stop tracking *task* and log it if it crashed."""
        try:
            self._background_tasks.remove(task)
        except ValueError:
            # Already dropped by _cancel_background_tasks().
            pass
        if task.cancelled():
            return
        # Retrieving the exception also prevents asyncio's
        # "Task exception was never retrieved" warning.
        exc = task.exception()
        if exc is not None:
            logger.error(f'[{self.service_name}] 后台任务异常退出: {exc!r}')

    @abstractmethod
    async def _register_message_handlers(self):
        """Register message handlers (must be implemented by subclasses)."""
        pass

    async def _register_event_handlers(self):
        """Register event handlers (optional override).

        The default registers this service's session connect/disconnect hooks
        with the manager.
        """
        self.manager.register_event_handler('session_connected', self._on_session_connected)
        self.manager.register_event_handler('session_disconnected', self._on_session_disconnected)

    async def _start_background_tasks(self):
        """Start background tasks (optional override; default does nothing)."""
        pass

    async def _cleanup(self):
        """Release service resources on shutdown (optional override; default does nothing)."""
        pass

    async def _on_session_connected(self, session: WebSocketSession, data: Dict[str, Any]):
        """Handle a session-connected event (optional override; default logs it)."""
        logger.info(f'[{self.service_name}] 会话连接: {session.session_id}')

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Handle a session-disconnected event (optional override; default logs it)."""
        logger.info(f'[{self.service_name}] 会话断开: {session.session_id}')

    def add_background_task(self, coro) -> asyncio.Task:
        """Schedule *coro* as a tracked task and return it.

        Tracked tasks are cancelled by ``shutdown``. A task is dropped from
        tracking once it finishes, and an unexpected exception is logged.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.append(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
                                   source: Optional[str] = None, metadata: Optional[Dict] = None):
        """Send a message to one session via the manager.

        ``source`` defaults to this service's name. Returns whatever the
        manager's ``broadcast_to_session`` returns.
        """
        if source is None:
            source = self.service_name
        return await self.manager.broadcast_to_session(session_id, message_type, content, source, metadata)

    async def broadcast_to_all(self, message_type: str, content: Any,
                               source: Optional[str] = None, metadata: Optional[Dict] = None):
        """Send a message to all sessions via the manager.

        ``source`` defaults to this service's name. Returns whatever the
        manager's ``broadcast_to_all`` returns.
        """
        if source is None:
            source = self.service_name
        return await self.manager.broadcast_to_all(message_type, content, source, metadata)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return the manager's session statistics."""
        return self.manager.get_session_stats()

    def create_message_handler(self, message_type: str):
        """Decorator: register ``func(session, data)`` for ``message_type``.

        The manager is given a wrapper taking ``(websocket, data)``. The wrapper
        looks up the session for the websocket and calls ``func(session, data)``.
        If ``func`` raises, the error is logged and an ``"error"`` message is
        sent back to the session. If no session is found, a warning is logged and
        ``func`` is not called. The undecorated ``func`` is returned unchanged.
        """
        def decorator(func):
            async def wrapper(websocket: web.WebSocketResponse, data: Dict[str, Any]):
                session = self.manager.get_session(websocket)
                if not session:
                    logger.warning(f'[{self.service_name}] 未找到会话,无法处理消息: {message_type}')
                    return
                try:
                    await func(session, data)
                except Exception as e:
                    logger.error(f'[{self.service_name}] 消息处理器 {message_type} 执行失败: {e}')
                    try:
                        await session.send_message({
                            "type": "error",
                            "data": {
                                "message": f"处理 {message_type} 消息失败: {str(e)}",
                                "service": self.service_name
                            }
                        })
                    except Exception as send_error:
                        # Best effort: the peer may already be gone; the original
                        # failure is logged above, so don't let the reply failure
                        # escape the handler.
                        logger.warning(f'[{self.service_name}] 发送错误响应失败: {send_error}')

            self.manager.register_message_handler(message_type, wrapper)
            return func
        return decorator

    def create_event_handler(self, event_type: str):
        """Decorator: register ``func`` with the manager for ``event_type`` and return it unchanged."""
        def decorator(func):
            self.manager.register_event_handler(event_type, func)
            return func
        return decorator


class WebSocketServiceRegistry:
    """Registry of WebSocket services keyed by ``service_name``.

    Also drives bulk lifecycle operations (initialize / shut down every
    registered service) and aggregates per-service statistics.
    """

    def __init__(self):
        self._services: Dict[str, WebSocketServiceBase] = {}

    def register_service(self, service: WebSocketServiceBase):
        """Register *service*; an existing service with the same name is replaced (with a warning)."""
        if service.service_name in self._services:
            logger.warning(f'服务已存在,将被覆盖: {service.service_name}')

        self._services[service.service_name] = service
        logger.info(f'注册WebSocket服务: {service.service_name}')

    def get_service(self, service_name: str) -> Optional[WebSocketServiceBase]:
        """Return the service registered under *service_name*, or ``None``."""
        return self._services.get(service_name)

    def list_services(self) -> List[str]:
        """Return the names of all registered services."""
        return list(self._services.keys())

    async def initialize_all(self):
        """Initialize every registered service in insertion order.

        Stops at, and propagates, the first failure.
        """
        logger.info('初始化所有WebSocket服务...')
        # Iterate a snapshot: a service could register another one while we await.
        for service in list(self._services.values()):
            await service.initialize()
        logger.info('所有WebSocket服务初始化完成')

    async def shutdown_all(self):
        """Shut down every registered service, even if some of them fail.

        Each failure is logged. After all services have been processed, the
        first failure is re-raised so callers still see the error.
        """
        logger.info('关闭所有WebSocket服务...')
        first_error: Optional[Exception] = None
        # Snapshot for the same reason as initialize_all().
        for service in list(self._services.values()):
            try:
                await service.shutdown()
            except Exception as e:
                logger.error(f'关闭WebSocket服务失败: {service.service_name}: {e}')
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error
        logger.info('所有WebSocket服务已关闭')

    def get_all_stats(self) -> Dict[str, Any]:
        """Return initialization state and tracked background-task count per service."""
        return {
            "services": {
                name: {
                    "initialized": service.is_initialized,
                    "background_tasks": len(service._background_tasks),
                }
                for name, service in self._services.items()
            },
            "total_services": len(self._services),
        }


# Module-level registry instance shared by the helper functions below.
_service_registry = WebSocketServiceRegistry()


def get_service_registry() -> WebSocketServiceRegistry:
    """Return the module-level WebSocket service registry."""
    return _service_registry


def register_websocket_service(service: WebSocketServiceBase):
    """Add *service* to the module-level registry."""
    get_service_registry().register_service(service)


def get_websocket_service(service_name: str) -> Optional[WebSocketServiceBase]:
    """Look up a service in the module-level registry; ``None`` if not registered."""
    return get_service_registry().get_service(service_name)