test_funasr_integration.py 9.3 KB
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-27
FunASR集成测试脚本
测试新的FunASRClient与项目的集成效果
"""

import os
import sys
import time
import threading
from pathlib import Path

# 添加项目路径
sys.path.append(os.path.dirname(__file__))

from funasr_asr import FunASRClient
from web.asr.funasr import FunASR
import util

class TestFunASRIntegration:
    """FunASR集成测试类"""
    
    def __init__(self):
        self.test_results = []
        self.test_audio_files = [
            "yunxi.mp3",
            "yunxia.mp3", 
            "yunyang.mp3"
        ]
    
    def log_test_result(self, test_name: str, success: bool, message: str = ""):
        """记录测试结果"""
        status = "✓ 通过" if success else "✗ 失败"
        result = f"[{status}] {test_name}"
        if message:
            result += f" - {message}"
        
        self.test_results.append((test_name, success, message))
        print(result)
    
    def test_funasr_client_creation(self):
        """测试FunASRClient创建"""
        try:
            class SimpleOpt:
                def __init__(self):
                    self.username = "test_user"
            
            opt = SimpleOpt()
            client = FunASRClient(opt)
            
            # 检查基本属性
            assert hasattr(client, 'server_url')
            assert hasattr(client, 'connected')
            assert hasattr(client, 'running')
            
            self.log_test_result("FunASRClient创建", True, "客户端创建成功")
            return client
            
        except Exception as e:
            self.log_test_result("FunASRClient创建", False, f"错误: {e}")
            return None
    
    def test_compatibility_wrapper(self):
        """测试兼容性包装器"""
        try:
            funasr = FunASR("test_user")
            
            # 检查兼容性方法
            assert hasattr(funasr, 'start')
            assert hasattr(funasr, 'end')
            assert hasattr(funasr, 'send')
            assert hasattr(funasr, 'add_frame')
            assert hasattr(funasr, 'set_message_callback')
            
            self.log_test_result("兼容性包装器", True, "所有兼容性方法存在")
            return funasr
            
        except Exception as e:
            self.log_test_result("兼容性包装器", False, f"错误: {e}")
            return None
    
    def test_callback_mechanism(self):
        """测试回调机制"""
        try:
            funasr = FunASR("test_user")
            callback_called = threading.Event()
            received_message = []
            
            def test_callback(message):
                received_message.append(message)
                callback_called.set()
            
            funasr.set_message_callback(test_callback)
            
            # 模拟接收消息
            test_message = "测试识别结果"
            funasr._handle_result(test_message)
            
            # 等待回调
            if callback_called.wait(timeout=1.0):
                if received_message and received_message[0] == test_message:
                    self.log_test_result("回调机制", True, "回调函数正常工作")
                else:
                    self.log_test_result("回调机制", False, "回调消息不匹配")
            else:
                self.log_test_result("回调机制", False, "回调超时")
                
        except Exception as e:
            self.log_test_result("回调机制", False, f"错误: {e}")
    
    def test_audio_file_existence(self):
        """测试音频文件存在性"""
        existing_files = []
        missing_files = []
        
        for audio_file in self.test_audio_files:
            if os.path.exists(audio_file):
                existing_files.append(audio_file)
            else:
                missing_files.append(audio_file)
        
        if existing_files:
            self.log_test_result(
                "音频文件检查", 
                True, 
                f"找到 {len(existing_files)} 个文件: {', '.join(existing_files)}"
            )
        
        if missing_files:
            self.log_test_result(
                "音频文件缺失", 
                False, 
                f"缺少 {len(missing_files)} 个文件: {', '.join(missing_files)}"
            )
        
        return existing_files
    
    def test_connection_simulation(self):
        """测试连接模拟"""
        try:
            client = self.test_funasr_client_creation()
            if not client:
                return
            
            # 测试启动和停止
            client.start()
            time.sleep(0.5)  # 给连接一些时间
            
            # 检查运行状态
            if client.running:
                self.log_test_result("客户端启动", True, "客户端成功启动")
            else:
                self.log_test_result("客户端启动", False, "客户端启动失败")
            
            # 停止客户端
            client.stop()
            time.sleep(0.5)
            
            if not client.running:
                self.log_test_result("客户端停止", True, "客户端成功停止")
            else:
                self.log_test_result("客户端停止", False, "客户端停止失败")
                
        except Exception as e:
            self.log_test_result("连接模拟", False, f"错误: {e}")
    
    def test_message_queue(self):
        """测试消息队列"""
        try:
            client = self.test_funasr_client_creation()
            if not client:
                return
            
            # 测试消息入队
            test_message = {"test": "message"}
            client.message_queue.put(test_message)
            
            # 检查队列
            if not client.message_queue.empty():
                retrieved_message = client.message_queue.get_nowait()
                if retrieved_message == test_message:
                    self.log_test_result("消息队列", True, "消息队列正常工作")
                else:
                    self.log_test_result("消息队列", False, "消息内容不匹配")
            else:
                self.log_test_result("消息队列", False, "消息队列为空")
                
        except Exception as e:
            self.log_test_result("消息队列", False, f"错误: {e}")
    
    def test_config_loading(self):
        """测试配置加载"""
        try:
            import config_util as cfg
            
            # 检查关键配置项
            required_configs = [
                'local_asr_ip',
                'local_asr_port',
                'asr_timeout',
                'asr_reconnect_delay',
                'asr_max_reconnect_attempts'
            ]
            
            missing_configs = []
            for config_key in required_configs:
                try:
                    if hasattr(cfg, 'config'):
                        value = cfg.config.get(config_key)
                    else:
                        value = getattr(cfg, config_key, None)
                    if value is None:
                        missing_configs.append(config_key)
                except:
                    missing_configs.append(config_key)
            
            if not missing_configs:
                self.log_test_result("配置加载", True, "所有必需配置项存在")
            else:
                self.log_test_result(
                    "配置加载", 
                    False, 
                    f"缺少配置项: {', '.join(missing_configs)}"
                )
                
        except Exception as e:
            self.log_test_result("配置加载", False, f"错误: {e}")
    
    def run_all_tests(self):
        """运行所有测试"""
        print("\n" + "="*60)
        print("FunASR集成测试开始")
        print("="*60)
        
        # 运行各项测试
        self.test_config_loading()
        self.test_funasr_client_creation()
        self.test_compatibility_wrapper()
        self.test_callback_mechanism()
        self.test_message_queue()
        self.test_audio_file_existence()
        self.test_connection_simulation()
        
        # 输出测试总结
        print("\n" + "="*60)
        print("测试总结")
        print("="*60)
        
        passed_tests = sum(1 for _, success, _ in self.test_results if success)
        total_tests = len(self.test_results)
        
        print(f"总测试数: {total_tests}")
        print(f"通过测试: {passed_tests}")
        print(f"失败测试: {total_tests - passed_tests}")
        print(f"成功率: {passed_tests/total_tests*100:.1f}%")
        
        # 显示失败的测试
        failed_tests = [(name, msg) for name, success, msg in self.test_results if not success]
        if failed_tests:
            print("\n失败的测试:")
            for name, msg in failed_tests:
                print(f"  - {name}: {msg}")
        
        print("\n" + "="*60)
        
        return passed_tests == total_tests

def main():
    """主函数"""
    tester = TestFunASRIntegration()
    success = tester.run_all_tests()
    
    if success:
        print("\n🎉 所有测试通过!FunASR集成准备就绪。")
    else:
        print("\n⚠️  部分测试失败,请检查相关配置和依赖。")
    
    return 0 if success else 1

if __name__ == "__main__":
    exit(main())