websocket_service_base.py
7.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
WebSocket服务抽象基类
为不同类型的WebSocket服务提供统一接口和生命周期管理
"""
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from aiohttp import web
from logger import logger
from .unified_websocket_manager import get_unified_manager, WebSocketSession
class WebSocketServiceBase(ABC):
    """Abstract base class for WebSocket services.

    Gives every service a common lifecycle (``initialize`` / ``shutdown``)
    on top of the shared manager returned by ``get_unified_manager()``, plus
    helpers for handler registration, background-task tracking and
    broadcasting.

    Subclasses must implement ``_register_message_handlers``; the other
    underscore-prefixed hooks are optional overrides.
    """

    def __init__(self, service_name: str):
        """Create an uninitialized service.

        Args:
            service_name: Name of the service; used in log messages, as the
                default ``source`` of broadcasts and as the registry key.
        """
        self.service_name = service_name
        self.manager = get_unified_manager()
        self.is_initialized = False
        # Tasks created through add_background_task(); cancelled on shutdown.
        self._background_tasks: List[asyncio.Task] = []

    async def initialize(self):
        """Initialize the service; a no-op if it is already initialized.

        Runs the hooks in order: message handlers, event handlers, then
        background tasks. If any hook raises (or initialization is
        cancelled), background tasks started so far are cancelled before the
        exception propagates. Otherwise they would be orphaned:
        ``is_initialized`` stays False, so ``shutdown()`` would return early
        and never cancel them.
        """
        if self.is_initialized:
            return
        logger.info(f'初始化WebSocket服务: {self.service_name}')
        try:
            await self._register_message_handlers()
            await self._register_event_handlers()
            await self._start_background_tasks()
        except BaseException:
            # BaseException so that cancellation also triggers the cleanup;
            # the original exception is always re-raised unchanged.
            await self._cancel_background_tasks()
            raise
        self.is_initialized = True
        logger.info(f'WebSocket服务初始化完成: {self.service_name}')

    async def shutdown(self):
        """Shut the service down; a no-op if it is not initialized.

        Cancels and awaits every background task, then runs ``_cleanup()``.
        A background task that failed is logged rather than aborting the
        shutdown, so ``_cleanup()`` always gets a chance to run.
        """
        if not self.is_initialized:
            return
        logger.info(f'关闭WebSocket服务: {self.service_name}')
        await self._cancel_background_tasks()
        await self._cleanup()
        self.is_initialized = False
        logger.info(f'WebSocket服务已关闭: {self.service_name}')

    async def _cancel_background_tasks(self):
        """Cancel all tracked background tasks and wait for them to finish.

        All tasks are cancelled first and then awaited together, so one slow
        task does not delay cancelling the others. ``return_exceptions=True``
        retrieves every task's outcome (including tasks that had already
        failed), and failures other than cancellation are logged.
        """
        tasks = list(self._background_tasks)
        self._background_tasks.clear()
        for task in tasks:
            if not task.done():
                task.cancel()
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in results:
            # On Python 3.7 CancelledError is an Exception subclass; skip it
            # explicitly so an ordinary cancellation is not reported as an error.
            if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                logger.error(f'[{self.service_name}] 后台任务异常退出: {result!r}')

    @abstractmethod
    async def _register_message_handlers(self):
        """Register this service's message handlers (must be implemented)."""
        pass

    async def _register_event_handlers(self):
        """Register event handlers (optional override).

        The default subscribes ``_on_session_connected`` and
        ``_on_session_disconnected`` to the manager's ``'session_connected'``
        and ``'session_disconnected'`` events.
        """
        self.manager.register_event_handler('session_connected', self._on_session_connected)
        self.manager.register_event_handler('session_disconnected', self._on_session_disconnected)

    async def _start_background_tasks(self):
        """Start background tasks (optional override); default does nothing."""
        pass

    async def _cleanup(self):
        """Release service resources (optional override); default does nothing."""
        pass

    async def _on_session_connected(self, session: WebSocketSession, data: Dict[str, Any]):
        """Handle a ``'session_connected'`` event (optional override)."""
        # NOTE(review): this handler takes (session, data) while the
        # disconnect handler takes (session) only; confirm this matches how
        # the unified manager invokes each event's handlers.
        logger.info(f'[{self.service_name}] 会话连接: {session.session_id}')

    async def _on_session_disconnected(self, session: WebSocketSession):
        """Handle a ``'session_disconnected'`` event (optional override)."""
        logger.info(f'[{self.service_name}] 会话断开: {session.session_id}')

    def add_background_task(self, coro):
        """Schedule *coro* as a task tracked for cancellation on shutdown.

        Returns:
            The created ``asyncio.Task``.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.append(task)
        return task

    async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
                                   source: Optional[str] = None, metadata: Optional[Dict] = None):
        """Send a message to one session via the manager.

        ``source`` defaults to this service's name. Returns whatever
        ``manager.broadcast_to_session`` returns.
        """
        if source is None:
            source = self.service_name
        return await self.manager.broadcast_to_session(session_id, message_type, content, source, metadata)

    async def broadcast_to_all(self, message_type: str, content: Any,
                               source: Optional[str] = None, metadata: Optional[Dict] = None):
        """Send a message to all sessions via the manager.

        ``source`` defaults to this service's name. Returns whatever
        ``manager.broadcast_to_all`` returns.
        """
        if source is None:
            source = self.service_name
        return await self.manager.broadcast_to_all(message_type, content, source, metadata)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return the manager's session statistics."""
        return self.manager.get_session_stats()

    def create_message_handler(self, message_type: str):
        """Decorator that registers ``func(session, data)`` for *message_type*.

        The registered wrapper takes ``(websocket, data)`` and resolves the
        session with ``self.manager.get_session(websocket)``. If no session
        is found, a warning is logged. If ``func`` raises, the error is
        logged and an ``"error"`` message is sent back to the session. The
        decorator returns ``func`` itself, not the wrapper.
        """
        def decorator(func):
            async def wrapper(websocket: web.WebSocketResponse, data: Dict[str, Any]):
                session = self.manager.get_session(websocket)
                if not session:
                    logger.warning(f'[{self.service_name}] 未找到会话,无法处理消息: {message_type}')
                    return
                try:
                    await func(session, data)
                except Exception as e:
                    logger.error(f'[{self.service_name}] 消息处理器 {message_type} 执行失败: {e}')
                    # NOTE(review): str(e) is sent to the client verbatim and
                    # may expose internal details.
                    try:
                        await session.send_message({
                            "type": "error",
                            "data": {
                                "message": f"处理 {message_type} 消息失败: {str(e)}",
                                "service": self.service_name
                            }
                        })
                    except Exception as send_error:
                        # The peer may already be gone; a failed error report
                        # must not escape the handler.
                        logger.error(f'[{self.service_name}] 发送错误消息失败: {send_error}')
            self.manager.register_message_handler(message_type, wrapper)
            return func
        return decorator

    def create_event_handler(self, event_type: str):
        """Decorator that registers *func* directly as an *event_type* handler.

        Returns ``func`` unchanged.
        """
        def decorator(func):
            self.manager.register_event_handler(event_type, func)
            return func
        return decorator
class WebSocketServiceRegistry:
    """Registry of WebSocket services keyed by ``service_name``."""

    def __init__(self):
        # service_name -> service; insertion order = registration order.
        self._services: Dict[str, WebSocketServiceBase] = {}

    def register_service(self, service: WebSocketServiceBase):
        """Register *service*, replacing any service with the same name.

        A replaced service is not shut down here; only a warning is logged.
        """
        if service.service_name in self._services:
            logger.warning(f'服务已存在,将被覆盖: {service.service_name}')
        self._services[service.service_name] = service
        logger.info(f'注册WebSocket服务: {service.service_name}')

    def get_service(self, service_name: str) -> Optional[WebSocketServiceBase]:
        """Return the service registered as *service_name*, or ``None``."""
        return self._services.get(service_name)

    def list_services(self) -> List[str]:
        """Return the names of all registered services in registration order."""
        return list(self._services)

    async def initialize_all(self):
        """Initialize every registered service in registration order.

        Stops at the first service whose ``initialize()`` raises; that
        exception propagates to the caller.
        """
        logger.info('初始化所有WebSocket服务...')
        for service in self._services.values():
            await service.initialize()
        logger.info('所有WebSocket服务初始化完成')

    async def shutdown_all(self):
        """Shut down every registered service in registration order.

        This is best-effort: if one service's ``shutdown()`` raises, the error
        is logged and the remaining services are still shut down. Previously,
        the first failure left every later service running.
        """
        logger.info('关闭所有WebSocket服务...')
        for service in self._services.values():
            try:
                await service.shutdown()
            except Exception as e:
                logger.error(f'关闭WebSocket服务失败: {service.service_name}: {e}')
        logger.info('所有WebSocket服务已关闭')

    def get_all_stats(self) -> Dict[str, Any]:
        """Return per-service status plus the total number of services.

        Each entry under ``"services"`` reports ``initialized`` and the number
        of tracked ``background_tasks``.
        """
        return {
            "services": {
                name: {
                    "initialized": service.is_initialized,
                    "background_tasks": len(service._background_tasks)
                }
                for name, service in self._services.items()
            },
            "total_services": len(self._services)
        }
# Module-level service registry, created at import time; the accessor
# functions in this module read and write this instance.
_service_registry = WebSocketServiceRegistry()
def get_service_registry() -> WebSocketServiceRegistry:
    """Return the module-level WebSocket service registry instance."""
    registry = _service_registry
    return registry
def register_websocket_service(service: WebSocketServiceBase):
    """Add *service* to the module-level registry, keyed by its name."""
    registry = _service_registry
    registry.register_service(service)
def get_websocket_service(service_name: str) -> Optional[WebSocketServiceBase]:
    """Look up *service_name* in the module-level registry (``None`` if absent)."""
    registry = _service_registry
    found = registry.get_service(service_name)
    return found