test_unified_websocket_architecture.py 10.9 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27 16:52:46
统一WebSocket架构测试
验证新架构的核心功能是否正常工作
"""

import asyncio
import json
import logging
from typing import Dict, Any
from aiohttp import web, WSMsgType, ClientSession

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class UnifiedWebSocketTester:
    """统一WebSocket架构测试器"""
    
    def __init__(self, server_url: str = "ws://localhost:8010/ws"):
        self.server_url = server_url
        self.session: ClientSession = None
        self.websocket = None
        self.received_messages = []
        
    async def connect(self):
        """连接到WebSocket服务器"""
        try:
            self.session = ClientSession()
            self.websocket = await self.session.ws_connect(self.server_url)
            logger.info(f"已连接到 {self.server_url}")
            return True
        except Exception as e:
            logger.error(f"连接失败: {e}")
            return False
            
    async def disconnect(self):
        """断开连接"""
        if self.websocket:
            await self.websocket.close()
        if self.session:
            await self.session.close()
        logger.info("已断开连接")
        
    async def send_message(self, message: Dict[str, Any]):
        """发送消息"""
        if not self.websocket:
            logger.error("WebSocket未连接")
            return False
            
        try:
            await self.websocket.send_str(json.dumps(message))
            logger.info(f"发送消息: {message['type']}")
            return True
        except Exception as e:
            logger.error(f"发送消息失败: {e}")
            return False
            
    async def receive_messages(self, timeout: float = 5.0):
        """接收消息"""
        if not self.websocket:
            return []
            
        messages = []
        try:
            async for msg in self.websocket:
                if msg.type == WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    messages.append(data)
                    logger.info(f"收到消息: {data.get('type', 'unknown')}")
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket错误: {self.websocket.exception()}")
                    break
                elif msg.type == WSMsgType.CLOSE:
                    logger.info("WebSocket连接已关闭")
                    break
                    
                # 简单超时控制
                if len(messages) >= 10:  # 限制接收消息数量
                    break
                    
        except Exception as e:
            logger.error(f"接收消息失败: {e}")
            
        return messages
        
    async def test_login(self, session_id: str = "test_session_001"):
        """测试登录功能"""
        logger.info("=== 测试登录功能 ===")
        
        login_message = {
            "type": "login",
            "sessionid": session_id,
            "data": {
                "user_id": "test_user",
                "client_type": "test_client"
            }
        }
        
        success = await self.send_message(login_message)
        if success:
            # 等待响应
            await asyncio.sleep(1)
            logger.info("✅ 登录消息发送成功")
        else:
            logger.error("❌ 登录消息发送失败")
            
        return success
        
    async def test_heartbeat(self):
        """测试心跳功能"""
        logger.info("=== 测试心跳功能 ===")
        
        heartbeat_message = {
            "type": "heartbeat",
            "timestamp": asyncio.get_event_loop().time()
        }
        
        success = await self.send_message(heartbeat_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ 心跳消息发送成功")
        else:
            logger.error("❌ 心跳消息发送失败")
            
        return success
        
    async def test_asr_functionality(self):
        """测试ASR功能"""
        logger.info("=== 测试ASR功能 ===")
        
        # 测试开始ASR识别
        start_asr_message = {
            "type": "start_asr_recognition",
            "data": {
                "language": "zh-CN",
                "sample_rate": 16000
            }
        }
        
        success = await self.send_message(start_asr_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ ASR开始识别消息发送成功")
        else:
            logger.error("❌ ASR开始识别消息发送失败")
            return False
            
        # 测试发送音频数据
        audio_data_message = {
            "type": "asr_audio_data",
            "data": {
                "audio_data": "fake_audio_data_base64",
                "format": "wav"
            }
        }
        
        success = await self.send_message(audio_data_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ ASR音频数据发送成功")
        else:
            logger.error("❌ ASR音频数据发送失败")
            
        # 测试停止ASR识别
        stop_asr_message = {
            "type": "stop_asr_recognition"
        }
        
        success = await self.send_message(stop_asr_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ ASR停止识别消息发送成功")
        else:
            logger.error("❌ ASR停止识别消息发送失败")
            
        return True
        
    async def test_digital_human_functionality(self):
        """测试数字人功能"""
        logger.info("=== 测试数字人功能 ===")
        
        # 测试注册数字人
        register_message = {
            "type": "register_digital_human",
            "data": {
                "human_id": "test_human_001",
                "name": "测试数字人",
                "capabilities": ["speak", "gesture", "emotion"]
            }
        }
        
        success = await self.send_message(register_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ 数字人注册消息发送成功")
        else:
            logger.error("❌ 数字人注册消息发送失败")
            return False
            
        # 测试数字人说话
        speak_message = {
            "type": "digital_human_speak",
            "data": {
                "human_id": "test_human_001",
                "text": "你好,我是测试数字人",
                "voice_id": "default"
            }
        }
        
        success = await self.send_message(speak_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ 数字人说话消息发送成功")
        else:
            logger.error("❌ 数字人说话消息发送失败")
            
        # 测试获取数字人列表
        list_message = {
            "type": "get_digital_humans"
        }
        
        success = await self.send_message(list_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ 获取数字人列表消息发送成功")
        else:
            logger.error("❌ 获取数字人列表消息发送失败")
            
        return True
        
    async def test_wsa_functionality(self):
        """测试WSA功能"""
        logger.info("=== 测试WSA功能 ===")
        
        # 测试注册Web连接
        register_web_message = {
            "type": "wsa_register_web",
            "data": {
                "username": "test_web_user"
            }
        }
        
        success = await self.send_message(register_web_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ WSA Web注册消息发送成功")
        else:
            logger.error("❌ WSA Web注册消息发送失败")
            return False
            
        # 测试获取WSA状态
        status_message = {
            "type": "wsa_get_status"
        }
        
        success = await self.send_message(status_message)
        if success:
            await asyncio.sleep(1)
            logger.info("✅ WSA状态查询消息发送成功")
        else:
            logger.error("❌ WSA状态查询消息发送失败")
            
        return True
        
    async def run_comprehensive_test(self):
        """运行综合测试"""
        logger.info("🚀 开始统一WebSocket架构综合测试")
        
        # 连接到服务器
        if not await self.connect():
            logger.error("❌ 无法连接到服务器,测试终止")
            return False
            
        try:
            # 启动消息接收任务
            receive_task = asyncio.create_task(self.receive_messages())
            
            # 执行各项功能测试
            test_results = []
            
            # 基础功能测试
            test_results.append(await self.test_login())
            test_results.append(await self.test_heartbeat())
            
            # 服务功能测试
            test_results.append(await self.test_asr_functionality())
            test_results.append(await self.test_digital_human_functionality())
            test_results.append(await self.test_wsa_functionality())
            
            # 等待接收响应
            await asyncio.sleep(2)
            
            # 取消接收任务
            receive_task.cancel()
            try:
                await receive_task
            except asyncio.CancelledError:
                pass
                
            # 统计测试结果
            passed_tests = sum(test_results)
            total_tests = len(test_results)
            
            logger.info(f"\n📊 测试结果统计:")
            logger.info(f"   总测试数: {total_tests}")
            logger.info(f"   通过测试: {passed_tests}")
            logger.info(f"   失败测试: {total_tests - passed_tests}")
            logger.info(f"   成功率: {passed_tests/total_tests*100:.1f}%")
            
            if passed_tests == total_tests:
                logger.info("🎉 所有测试通过!统一WebSocket架构工作正常")
                return True
            else:
                logger.warning("⚠️ 部分测试失败,请检查服务器状态")
                return False
                
        except Exception as e:
            logger.error(f"测试过程中发生错误: {e}")
            return False
        finally:
            await self.disconnect()
            

async def main():
    """主函数"""
    tester = UnifiedWebSocketTester()
    success = await tester.run_comprehensive_test()
    
    if success:
        logger.info("\n✅ 统一WebSocket架构测试完成 - 所有功能正常")
    else:
        logger.error("\n❌ 统一WebSocket架构测试完成 - 发现问题")
        
    return success


if __name__ == "__main__":
    # 运行测试
    result = asyncio.run(main())
    exit(0 if result else 1)