# unified_websocket_manager.py
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-15 14:41:21
统一WebSocket管理模块
提供统一的WebSocket连接管理、消息推送和事件处理功能
"""
import json
import time
import asyncio
import weakref
from typing import Dict, Any, Optional, Callable, List, Set
from threading import Lock
from aiohttp import web, WSMsgType
from logger import logger
class WebSocketSession:
    """Bookkeeping wrapper around a single WebSocket connection.

    Identity is the wrapped websocket object. Two WebSocketSession instances
    compare equal (and hash equally) exactly when they wrap the same
    websocket, whatever their session_id.
    """

    def __init__(self, session_id: str, websocket: web.WebSocketResponse):
        self.session_id = session_id
        self.websocket = websocket
        # Wall-clock seconds (time.time()) when the wrapper was created / last pinged.
        self.created_at = time.time()
        self.last_ping = time.time()
        # Free-form per-connection key/value store.
        self.metadata = {}

    def __eq__(self, other):
        """Equal only to another WebSocketSession wrapping the very same websocket."""
        return isinstance(other, WebSocketSession) and other.websocket is self.websocket

    def __hash__(self):
        """Hash on the websocket's identity so it stays consistent with __eq__."""
        ws_identity = id(self.websocket)
        return hash(ws_identity)

    def is_alive(self) -> bool:
        """Return True while the websocket has not been marked closed."""
        closed = self.websocket.closed
        return not closed

    async def send_message(self, message: Dict[str, Any]) -> bool:
        """Serialize *message* to JSON text and send it to the client.

        Returns:
            True if the frame was handed to the websocket; False if the
            socket is already closed or any exception occurred. Exceptions
            are logged, never raised.
        """
        try:
            if not self.is_alive():
                return False
            payload = json.dumps(message)
            await self.websocket.send_str(payload)
        except ConnectionResetError:
            logger.warning(f'[Session:{self.session_id}] 客户端连接已重置')
        except ConnectionAbortedError:
            logger.warning(f'[Session:{self.session_id}] 客户端连接已中止')
        except Exception as e:
            logger.error(f'[Session:{self.session_id}] 发送消息失败: {e}')
        else:
            return True
        return False

    def update_ping(self):
        """Record "now" as the time of the latest heartbeat."""
        self.last_ping = time.time()

    def set_metadata(self, key: str, value: Any):
        """Store *value* under *key* in this session's metadata."""
        self.metadata[key] = value

    def get_metadata(self, key: str, default=None):
        """Return the metadata value for *key*, or *default* when absent."""
        return self.metadata.get(key, default)

    async def close(self):
        """Close the websocket if still open; failures are logged, not raised."""
        try:
            if self.websocket.closed:
                return
            await self.websocket.close()
            logger.info(f'[Session:{self.session_id}] WebSocket连接已关闭')
        except ConnectionResetError:
            logger.warning(f'[Session:{self.session_id}] 连接已被远程主机重置,无需关闭')
        except ConnectionAbortedError:
            logger.warning(f'[Session:{self.session_id}] 连接已被中止,无需关闭')
        except Exception as e:
            logger.error(f'[Session:{self.session_id}] 关闭WebSocket连接失败: {e}')
class UnifiedWebSocketManager:
    """Unified WebSocket manager.

    Maintains two indexes over registered connections, both guarded by a
    ``threading.Lock``:

    * ``_sessions``   -- session_id -> set of WebSocketSession (one session_id
      may have several WebSocket connections attached);
    * ``_websockets`` -- websocket -> WebSocketSession.

    Inbound JSON messages are dispatched on their ``"type"`` field to
    registered message handlers. ``session_connected`` /
    ``session_disconnected`` events are emitted to registered event handlers.
    The lock is a plain threading lock and is never held across an ``await``.
    """

    def __init__(self):
        self._sessions: Dict[str, Set[WebSocketSession]] = {}  # session_id -> set of sessions
        self._websockets: Dict[web.WebSocketResponse, WebSocketSession] = {}  # websocket -> session
        self._message_handlers: Dict[str, Callable] = {}  # message type -> handler(websocket, data)
        self._event_handlers: Dict[str, List[Callable]] = {}  # event type -> handlers(**kwargs)
        self._lock = Lock()
        # Built-in protocol messages.
        self.register_message_handler('ping', self._handle_ping)
        self.register_message_handler('login', self._handle_login)

    # ------------------------------------------------------------------ registration

    def register_message_handler(self, message_type: str, handler: Callable):
        """Register *handler* (awaited as ``handler(websocket, data)``) for *message_type*.

        A later registration for the same type replaces the earlier one.
        """
        self._message_handlers[message_type] = handler
        logger.info(f'注册消息处理器: {message_type}')

    def register_event_handler(self, event_type: str, handler: Callable):
        """Append *handler* to the handlers invoked for *event_type*."""
        self._event_handlers.setdefault(event_type, []).append(handler)
        logger.info(f'注册事件处理器: {event_type}')

    async def _emit_event(self, event_type: str, **kwargs):
        """Invoke every handler registered for *event_type* with ``**kwargs``.

        Coroutine functions are awaited and plain callables are called
        directly. A failing handler is logged and does not stop the rest.
        """
        # Iterate over a snapshot so a handler that registers another handler
        # for the same event cannot mutate the list mid-iteration.
        for handler in list(self._event_handlers.get(event_type, ())):
            try:
                if asyncio.iscoroutinefunction(handler):
                    await handler(**kwargs)
                else:
                    handler(**kwargs)
            except Exception as e:
                logger.error(f'事件处理器执行失败 {event_type}: {e}')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_session_id(session_id: Any) -> str:
        """Coerce a session id to ``str`` so int and str ids map to one key."""
        return session_id if isinstance(session_id, str) else str(session_id)

    @staticmethod
    def _is_json_serializable(message: Any) -> bool:
        """Return True if ``json.dumps`` can encode *message*."""
        try:
            json.dumps(message)
        except (TypeError, ValueError):
            return False
        return True

    async def _drop_sessions(self, sessions) -> None:
        """Unregister *sessions* and emit ``session_disconnected`` for each one removed.

        Emitting here matters. Once a session is unregistered, the cleanup in
        ``websocket_handler`` finds nothing to remove and would otherwise
        never fire the event for it.
        """
        for session in sessions:
            removed = self.remove_session(session.websocket)
            if removed is not None:
                await self._emit_event('session_disconnected', session=removed)

    # ------------------------------------------------------------------ registry

    def add_session(self, session_id: str, websocket: web.WebSocketResponse) -> WebSocketSession:
        """Register *websocket* under *session_id* and return its WebSocketSession.

        If the websocket is already registered, the existing session is
        returned unchanged, even if it is filed under a different id.
        """
        with self._lock:
            existing_session = self._websockets.get(websocket)
            if existing_session is not None:
                logger.warning(f'[Session:{session_id}] WebSocket连接已存在 (WebSocket={id(websocket)}, 原Session={existing_session.session_id})')
                return existing_session

            session = WebSocketSession(session_id, websocket)
            bucket = self._sessions.setdefault(session_id, set())
            # Compare the set size before/after as a sanity check on the two indexes.
            before_count = len(bucket)
            bucket.add(session)
            after_count = len(bucket)
            self._websockets[websocket] = session
            logger.info(f'[Session:{session_id}] 添加WebSocket会话 (WebSocket={id(websocket)}), 连接数变化: {before_count} -> {after_count}')
            # An unchanged size means an equal session was already in the set,
            # i.e. the two indexes disagree.
            if before_count == after_count:
                logger.warning(f'[Session:{session_id}] 检测到可能的重复会话添加!Set大小未变化')
            return session

    def remove_session(self, websocket: web.WebSocketResponse):
        """Unregister *websocket*; return its WebSocketSession, or None if unknown."""
        with self._lock:
            session = self._websockets.pop(websocket, None)
            if session is None:
                return None
            session_id = session.session_id
            bucket = self._sessions.get(session_id)
            if bucket is not None:
                bucket.discard(session)
                if not bucket:  # drop empty buckets so session counts stay accurate
                    del self._sessions[session_id]
            logger.info(f'[Session:{session_id}] 移除WebSocket会话')
            return session

    def get_session(self, websocket: web.WebSocketResponse) -> Optional[WebSocketSession]:
        """Return the WebSocketSession registered for *websocket*, or None."""
        return self._websockets.get(websocket)

    def get_sessions_by_id(self, session_id: str) -> Set[WebSocketSession]:
        """Return a copy of all sessions registered under *session_id*.

        Besides the exact key, a decimal string also tries its int form and
        an int also tries its str form. This tolerates callers that
        registered ids with the other type. Returns an empty set if nothing
        matches.
        """
        candidates = [session_id]
        # isdecimal(), not isdigit(): characters such as "²" pass isdigit()
        # but make int() raise ValueError.
        if isinstance(session_id, str) and session_id.isdecimal():
            try:
                candidates.append(int(session_id))
            except ValueError:
                # int() can still refuse, e.g. past the int max-str-digits limit.
                pass
        elif isinstance(session_id, int):
            candidates.append(str(session_id))

        with self._lock:
            for key in candidates:
                found = self._sessions.get(key)
                if found:
                    return found.copy()
        return set()

    def _update_session_id(self, websocket: web.WebSocketResponse, old_session_id: str, new_session_id: str):
        """Re-file the session of *websocket* under *new_session_id*.

        The session's own current id is authoritative for the removal step.
        *old_session_id* is kept for interface compatibility, so a stale
        value cannot leave a dangling entry behind.

        Returns:
            True if *websocket* was registered and has been moved, else False.
        """
        with self._lock:
            session = self._websockets.get(websocket)
            if session is None:
                return False
            current_id = session.session_id
            bucket = self._sessions.get(current_id)
            if bucket is not None:
                bucket.discard(session)
                if not bucket:
                    del self._sessions[current_id]
            # Safe to mutate: WebSocketSession hashes on the websocket, not the id.
            session.session_id = new_session_id
            self._sessions.setdefault(new_session_id, set()).add(session)
            logger.info(f'[Session] 更新会话ID: {current_id} -> {new_session_id}')
            return True

    # ------------------------------------------------------------------ broadcasting

    async def broadcast_raw_message_to_session(self, session_id: str, message: Dict, source: str = "原数据") -> int:
        """Send *message* unchanged to every connection of *session_id*.

        *source* is accepted for interface compatibility but is not used.
        Connections whose send fails are unregistered, and
        ``session_disconnected`` is emitted for them. An un-encodable
        *message* is rejected up front, so healthy connections are not dropped.

        Returns:
            The number of connections the message was sent to.
        """
        session_id = self._normalize_session_id(session_id)
        sessions = self.get_sessions_by_id(session_id)
        if not sessions:
            logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
            return 0
        if not self._is_json_serializable(message):
            logger.error(f'[Session:{session_id}] 消息无法序列化为JSON,已取消广播')
            return 0

        targets = list(sessions)
        # Detailed debug trace of every target connection.
        logger.info(f'[Session:{session_id}] 开始广播消息,找到 {len(targets)} 个连接')
        for i, session in enumerate(targets):
            logger.info(f'[Session:{session_id}] 连接{i+1}: WebSocket={id(session.websocket)}, 创建时间={session.created_at}, 存活状态={session.is_alive()}')

        success_count = 0
        failed_sessions = []
        for i, session in enumerate(targets):
            logger.info(f'[Session:{session_id}] 正在向连接{i+1}发送消息 (WebSocket={id(session.websocket)})')
            if await session.send_message(message):
                success_count += 1
                logger.info(f'[Session:{session_id}] 连接{i+1}发送成功')
            else:
                failed_sessions.append(session)
                logger.warning(f'[Session:{session_id}] 连接{i+1}发送失败')

        await self._drop_sessions(failed_sessions)
        logger.info(f'[Session:{session_id}] 广播原始消息完成: 成功{success_count}/总计{len(targets)}, 失败{len(failed_sessions)}')
        return success_count

    async def broadcast_to_session(self, session_id: str, message_type: str, content: Any,
                                   source: str = "系统", metadata: Dict = None) -> int:
        """Wrap *content* in a standard envelope and send it to every connection of *session_id*.

        The envelope has ``type``, ``session_id``, ``content``, ``source`` and
        ``timestamp`` keys, plus the entries of *metadata* merged in last, so
        they can override the standard keys. Failed connections are
        unregistered (with ``session_disconnected``). An un-encodable
        envelope is rejected up front without dropping anyone.

        Returns:
            The number of connections the message was sent to.
        """
        session_id = self._normalize_session_id(session_id)
        sessions = self.get_sessions_by_id(session_id)
        if not sessions:
            logger.warning(f'[Session:{session_id}] 没有找到WebSocket连接')
            return 0

        message = {
            "type": message_type,
            "session_id": session_id,
            "content": content,
            "source": source,
            "timestamp": time.time(),
            **(metadata or {}),
        }
        if not self._is_json_serializable(message):
            logger.error(f'[Session:{session_id}] 消息无法序列化为JSON,已取消广播')
            return 0

        success_count = 0
        failed_sessions = []
        for session in sessions:
            if await session.send_message(message):
                success_count += 1
            else:
                failed_sessions.append(session)

        await self._drop_sessions(failed_sessions)
        logger.info(f'[Session:{session_id}] 广播消息成功: {success_count}/{len(sessions)}')
        return success_count

    async def broadcast_to_all(self, message_type: str, content: Any,
                               source: str = "系统", metadata: Dict = None) -> int:
        """Broadcast to every registered session id; return the total number of sends."""
        with self._lock:
            session_ids = list(self._sessions)
        total_sent = 0
        for session_id in session_ids:
            total_sent += await self.broadcast_to_session(session_id, message_type, content, source, metadata)
        logger.info(f'全局广播消息完成,总发送数: {total_sent}')
        return total_sent

    # ------------------------------------------------------------------ statistics

    def get_session_count(self) -> int:
        """Return the number of distinct session ids currently registered."""
        with self._lock:
            return len(self._sessions)

    def get_connection_count(self) -> int:
        """Return the number of registered websocket connections."""
        with self._lock:
            return len(self._websockets)

    def get_session_stats(self) -> Dict[str, Any]:
        """Return a snapshot of totals plus per-session connection details."""
        with self._lock:
            return {
                "total_sessions": len(self._sessions),
                "total_connections": len(self._websockets),
                "session_details": {
                    session_id: {
                        "connection_count": len(sessions),
                        "connections": [
                            {
                                "created_at": session.created_at,
                                "last_ping": session.last_ping,
                                "is_alive": session.is_alive(),
                                # Copy so callers cannot mutate live session metadata.
                                "metadata": dict(session.metadata),
                            }
                            for session in sessions
                        ],
                    }
                    for session_id, sessions in self._sessions.items()
                },
            }

    # ------------------------------------------------------------------ built-in handlers

    async def _handle_ping(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Refresh the heartbeat and reply with ``pong``; unregistered sockets are ignored."""
        session = self.get_session(websocket)
        if session:
            session.update_ping()
            await session.send_message({"type": "pong", "timestamp": time.time()})

    async def _handle_login(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Register *websocket* under the requested session id and confirm the login.

        The id comes from ``session_id``, then ``sessionid``. A missing or
        null value falls back to the current Unix time in whole seconds.
        NOTE(review): two anonymous logins in the same second share that
        fallback id. If the socket is already registered under another id,
        it is moved to the new one, so broadcasts reach it under the id
        reported in ``login_success``.
        """
        session_id = data.get('session_id')
        if session_id is None:
            session_id = data.get('sessionid')
        if session_id is None:
            session_id = str(int(time.time()))
        # Store ids as str to avoid int/str key mismatches.
        session_id = self._normalize_session_id(session_id)

        existing = self.get_session(websocket)
        if existing is None:
            session = self.add_session(session_id, websocket)
        else:
            if existing.session_id != session_id:
                self._update_session_id(websocket, existing.session_id, session_id)
            session = existing

        await self._emit_event('session_connected', session=session, data=data)
        await session.send_message({
            "type": "login_success",
            "data": {
                "session_id": session_id,
                "message": "WebSocket连接成功",
                "timestamp": time.time()
            }
        })

    # ------------------------------------------------------------------ dispatch

    async def handle_websocket_message(self, websocket: web.WebSocketResponse, data: Dict[str, Any]):
        """Dispatch one decoded JSON message to the handler registered for its ``type``.

        Non-object payloads (e.g. a JSON array or number) are logged and
        ignored. If a handler raises, the error is logged, and if the socket
        belongs to a registered session an ``error`` message is sent back.
        """
        if not isinstance(data, dict):
            logger.warning(f'忽略非JSON对象的WebSocket消息: {type(data).__name__}')
            return
        message_type = data.get('type')
        if message_type not in self._message_handlers:
            logger.warning(f'未知消息类型: {message_type}')
            return
        try:
            await self._message_handlers[message_type](websocket, data)
        except Exception as e:
            logger.error(f'消息处理器执行失败 {message_type}: {e}')
            session = self.get_session(websocket)
            if session:
                await session.send_message({
                    "type": "error",
                    "data": {
                        "message": f"消息处理失败: {str(e)}",
                        "original_type": message_type
                    }
                })

    async def websocket_handler(self, request) -> web.WebSocketResponse:
        """aiohttp request handler: upgrade to WebSocket and pump messages until it closes.

        TEXT frames are parsed as JSON and dispatched. On exit, for any
        reason, the socket is unregistered and ``session_disconnected`` is
        emitted if it was still registered.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)
        logger.info('新的WebSocket连接建立')
        try:
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                        await self.handle_websocket_message(ws, data)
                    except json.JSONDecodeError:
                        logger.error('收到无效的JSON数据')
                    except Exception as e:
                        logger.error(f'处理WebSocket消息时出错: {e}')
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f'WebSocket错误: {ws.exception()}')
                    break
                elif msg.type == WSMsgType.CLOSE:
                    logger.info('WebSocket连接正常关闭')
                    break
        except ConnectionResetError:
            logger.warning('WebSocket连接被远程主机重置')
        except ConnectionAbortedError:
            logger.warning('WebSocket连接被中止')
        except Exception as e:
            logger.error(f'WebSocket连接错误: {e}')
        finally:
            session = self.remove_session(ws)
            if session:
                await self._emit_event('session_disconnected', session=session)
            logger.info('WebSocket连接已关闭')
        return ws

    # ------------------------------------------------------------------ maintenance

    def get_expired_sessions(self, timeout: int = 60) -> List[WebSocketSession]:
        """Return sessions whose last heartbeat is more than *timeout* seconds old."""
        now = time.time()
        with self._lock:
            return [session for session in self._websockets.values()
                    if now - session.last_ping > timeout]

    async def cleanup_dead_connections(self):
        """Unregister sessions whose websocket is closed; return how many were found.

        ``session_disconnected`` is emitted for each session actually removed.
        """
        with self._lock:
            dead_sessions = [session for session in self._websockets.values()
                             if not session.is_alive()]
        await self._drop_sessions(dead_sessions)
        if dead_sessions:
            logger.info(f'清理了 {len(dead_sessions)} 个死连接')
        return len(dead_sessions)
# Process-wide shared manager instance, created once at import time.
_unified_manager: UnifiedWebSocketManager = UnifiedWebSocketManager()
def get_unified_manager() -> UnifiedWebSocketManager:
    """Return the module-level shared UnifiedWebSocketManager.

    Every call returns the same instance.
    """
    manager = _unified_manager
    return manager
# Backward-compatible wrapper kept for callers written against the older API.
async def broadcast_message_to_session(session_id: str, message_type: str, content: str,
                                       source: str = "数字人回复", model_info: str = None,
                                       request_source: str = "页面"):
    """Compatibility shim: broadcast one message to every connection of a session.

    ``model_info`` and ``request_source`` are forwarded as extra message
    fields (via the ``metadata`` argument of ``broadcast_to_session``), but
    only when truthy. Returns whatever ``broadcast_to_session`` returns.
    """
    optional_fields = (("model_info", model_info), ("request_source", request_source))
    metadata = {name: value for name, value in optional_fields if value}
    return await _unified_manager.broadcast_to_session(
        session_id, message_type, content, source, metadata
    )