wsa_server.py 2.75 KB
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
WebSocket服务器管理模块
提供Web和Human连接的管理功能
"""

import queue
from typing import Dict, Any, Optional
from threading import Lock

class WebSocketManager:
    """WebSocket连接管理器"""
    
    def __init__(self):
        self._connections = {}
        self._command_queue = queue.Queue()
        self._lock = Lock()
    
    def is_connected(self, username: str) -> bool:
        """检查用户是否已连接
        
        Args:
            username: 用户名
            
        Returns:
            是否已连接
        """
        with self._lock:
            return username in self._connections
    
    def is_connected_human(self, username: str) -> bool:
        """检查人类用户是否已连接
        
        Args:
            username: 用户名
            
        Returns:
            是否已连接
        """
        # 简化实现,与is_connected相同
        return self.is_connected(username)
    
    def add_connection(self, username: str, connection: Any):
        """添加连接
        
        Args:
            username: 用户名
            connection: 连接对象
        """
        with self._lock:
            self._connections[username] = connection
    
    def remove_connection(self, username: str):
        """移除连接
        
        Args:
            username: 用户名
        """
        with self._lock:
            self._connections.pop(username, None)
    
    def add_cmd(self, command: Dict[str, Any]):
        """添加命令到队列
        
        Args:
            command: 命令字典
        """
        try:
            self._command_queue.put(command, timeout=1.0)
        except queue.Full:
            print(f"警告: 命令队列已满,丢弃命令: {command}")
    
    def get_cmd(self, timeout: float = 1.0) -> Optional[Dict[str, Any]]:
        """从队列获取命令
        
        Args:
            timeout: 超时时间
            
        Returns:
            命令字典或None
        """
        try:
            return self._command_queue.get(timeout=timeout)
        except queue.Empty:
            return None
    
    def get_connection_count(self) -> int:
        """获取连接数量"""
        with self._lock:
            return len(self._connections)
    
    def get_usernames(self) -> list:
        """获取所有用户名列表"""
        with self._lock:
            return list(self._connections.keys())

# 全局实例
_web_instance = WebSocketManager()
_human_instance = WebSocketManager()

def get_web_instance() -> WebSocketManager:
    """获取Web WebSocket管理器实例"""
    return _web_instance

def get_instance() -> WebSocketManager:
    """获取Human WebSocket管理器实例"""
    return _human_instance