# test_funasr_connection.py
# AIfeng/2025-01-27
"""
FunASR服务连接测试脚本
用于验证本地FunASR WebSocket服务是否可以正常连接

使用方法:
1. 先启动FunASR服务:python -u web/asr/funasr/ASR_server.py --host "127.0.0.1" --port 10197 --ngpu 0
2. 运行此测试脚本:python test_funasr_connection.py
"""

import asyncio
import websockets
import json
import os
import wave
import numpy as np
from pathlib import Path

class FunASRConnectionTest:
    def __init__(self, host="127.0.0.1", port=10197):
        self.host = host
        self.port = port
        self.uri = f"ws://{host}:{port}"
        
    async def test_basic_connection(self):
        """测试基本WebSocket连接"""
        print(f"🔍 测试连接到 {self.uri}")
        try:
            async with websockets.connect(self.uri) as websocket:
                print("✅ FunASR WebSocket服务连接成功")
                return True
        except ConnectionRefusedError:
            print("❌ 连接被拒绝,请确认FunASR服务已启动")
            print("   启动命令: python -u web/asr/funasr/ASR_server.py --host \"127.0.0.1\" --port 10197 --ngpu 0")
            return False
        except Exception as e:
            print(f"❌ 连接失败: {e}")
            return False
    
    def create_test_wav(self, filename="test_audio.wav", duration=2, sample_rate=16000):
        """创建测试用的WAV文件"""
        # 生成简单的正弦波音频
        t = np.linspace(0, duration, int(sample_rate * duration), False)
        frequency = 440  # A4音符
        audio_data = np.sin(2 * np.pi * frequency * t) * 0.3
        
        # 转换为16位整数
        audio_data = (audio_data * 32767).astype(np.int16)
        
        # 保存为WAV文件
        with wave.open(filename, 'wb') as wav_file:
            wav_file.setnchannels(1)  # 单声道
            wav_file.setsampwidth(2)  # 16位
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(audio_data.tobytes())
        
        print(f"📁 创建测试音频文件: {filename}")
        return filename
    
    async def test_audio_recognition(self):
        """测试音频识别功能"""
        print("\n🎵 测试音频识别功能")
        
        # 创建测试音频文件
        test_file = self.create_test_wav()
        test_file_path = os.path.abspath(test_file)
        
        try:
            async with websockets.connect(self.uri) as websocket:
                print("✅ 连接成功,发送音频文件路径")
                
                # 发送音频文件路径
                message = {"url": test_file_path}
                await websocket.send(json.dumps(message))
                print(f"📤 发送消息: {message}")
                
                # 等待识别结果
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=10)
                    print(f"📥 收到识别结果: {response}")
                    return True
                except asyncio.TimeoutError:
                    print("⏰ 等待响应超时(10秒)")
                    print("   这可能是正常的,因为测试音频是纯音调,无法识别为文字")
                    return True  # 超时也算连接成功
                    
        except Exception as e:
            print(f"❌ 音频识别测试失败: {e}")
            return False
        finally:
            # 清理测试文件
            if os.path.exists(test_file):
                os.remove(test_file)
                print(f"🗑️ 清理测试文件: {test_file}")
    
    async def test_real_audio_files(self):
        """测试实际音频文件的识别效果"""
        print("\n🎤 测试实际音频文件识别")
        
        # 实际音频文件列表
        audio_files = [
            "yunxi.mp3",
            "yunxia.mp3", 
            "yunyang.mp3"
        ]
        
        results = []
        
        for audio_file in audio_files:
            file_path = os.path.abspath(audio_file)
            
            # 检查文件是否存在
            if not os.path.exists(file_path):
                print(f"⚠️ 音频文件不存在: {file_path}")
                continue
                
            print(f"\n🎵 测试音频文件: {audio_file}")
            
            try:
                async with websockets.connect(self.uri) as websocket:
                    print(f"✅ 连接成功,发送音频文件: {audio_file}")
                    
                    # 发送音频文件路径
                    message = {"url": file_path}
                    await websocket.send(json.dumps(message))
                    print(f"📤 发送消息: {message}")
                    
                    # 等待识别结果
                    try:
                        response = await asyncio.wait_for(websocket.recv(), timeout=30)
                        print(f"📥 识别结果: {response}")
                        
                        # 解析响应
                        try:
                            result_data = json.loads(response)
                            if isinstance(result_data, dict) and 'text' in result_data:
                                recognized_text = result_data['text']
                                print(f"🎯 识别文本: {recognized_text}")
                                results.append({
                                    'file': audio_file,
                                    'text': recognized_text,
                                    'status': 'success'
                                })
                            else:
                                print(f"📄 原始响应: {response}")
                                results.append({
                                    'file': audio_file,
                                    'response': response,
                                    'status': 'received'
                                })
                        except json.JSONDecodeError:
                            print(f"📄 非JSON响应: {response}")
                            results.append({
                                'file': audio_file,
                                'response': response,
                                'status': 'received'
                            })
                            
                    except asyncio.TimeoutError:
                        print(f"⏰ 等待响应超时(30秒)- {audio_file}")
                        results.append({
                            'file': audio_file,
                            'status': 'timeout'
                        })
                        
            except Exception as e:
                print(f"❌ 测试 {audio_file} 失败: {e}")
                results.append({
                    'file': audio_file,
                    'error': str(e),
                    'status': 'error'
                })
                
            # 文件间等待,避免服务器压力
            await asyncio.sleep(1)
        
        # 输出测试总结
        print("\n" + "="*50)
        print("📊 实际音频文件测试总结:")
        for i, result in enumerate(results, 1):
            print(f"\n{i}. 文件: {result['file']}")
            if result['status'] == 'success':
                print(f"   ✅ 识别成功: {result['text']}")
            elif result['status'] == 'received':
                print(f"   📥 收到响应: {result.get('response', 'N/A')}")
            elif result['status'] == 'timeout':
                print(f"   ⏰ 响应超时")
            elif result['status'] == 'error':
                print(f"   ❌ 测试失败: {result.get('error', 'N/A')}")
        
        return len(results) > 0
    
    async def test_message_format(self):
        """测试消息格式兼容性"""
        print("\n📋 测试消息格式兼容性")
        
        try:
            async with websockets.connect(self.uri) as websocket:
                # 测试不同的消息格式
                test_messages = [
                    {"url": "nonexistent.wav"},
                    {"test": "message"},
                    "invalid_json"
                ]
                
                for i, msg in enumerate(test_messages, 1):
                    try:
                        if isinstance(msg, dict):
                            await websocket.send(json.dumps(msg))
                            print(f"✅ 消息 {i} 发送成功: {msg}")
                        else:
                            await websocket.send(msg)
                            print(f"✅ 消息 {i} 发送成功: {msg}")
                        
                        # 短暂等待,避免消息堆积
                        await asyncio.sleep(0.5)
                        
                    except Exception as e:
                        print(f"⚠️ 消息 {i} 发送失败: {e}")
                
                return True
                
        except Exception as e:
            print(f"❌ 消息格式测试失败: {e}")
            return False
    
    def check_dependencies(self):
        """检查依赖项"""
        print("🔍 检查依赖项...")
        
        required_modules = [
            'websockets',
            'asyncio', 
            'json',
            'wave',
            'numpy'
        ]
        
        missing_modules = []
        for module in required_modules:
            try:
                __import__(module)
                print(f"✅ {module}")
            except ImportError:
                print(f"❌ {module} (缺失)")
                missing_modules.append(module)
        
        if missing_modules:
            print(f"\n⚠️ 缺失依赖项: {', '.join(missing_modules)}")
            print("安装命令: pip install " + ' '.join(missing_modules))
            return False
        
        print("✅ 所有依赖项检查通过")
        return True
    
    def check_funasr_server_file(self):
        """检查FunASR服务器文件是否存在"""
        print("\n📁 检查FunASR服务器文件...")
        
        server_path = Path("web/asr/funasr/ASR_server.py")
        if server_path.exists():
            print(f"✅ 找到服务器文件: {server_path.absolute()}")
            return True
        else:
            print(f"❌ 未找到服务器文件: {server_path.absolute()}")
            print("   请确认文件路径是否正确")
            return False
    
    async def run_all_tests(self):
        """运行所有测试"""
        print("🚀 开始FunASR连接测试\n")
        
        # 检查依赖
        if not self.check_dependencies():
            return False
        
        # 检查服务器文件
        if not self.check_funasr_server_file():
            return False
        
        # 基本连接测试
        print("\n" + "="*50)
        if not await self.test_basic_connection():
            return False
        
        # 音频识别测试
        print("\n" + "="*50)
        if not await self.test_audio_recognition():
            return False
        
        # 实际音频文件测试
        print("\n" + "="*50)
        await self.test_real_audio_files()
        
        # 消息格式测试
        print("\n" + "="*50)
        if not await self.test_message_format():
            return False
        
        print("\n" + "="*50)
        print("🎉 所有测试完成!FunASR服务连接正常")
        print("\n💡 集成建议:")
        print("   1. 服务使用WebSocket协议,非gRPC")
        print("   2. 默认监听端口: 10197")
        print("   3. 消息格式: JSON字符串,包含'url'字段指向音频文件路径")
        print("   4. 可以集成到现有项目的ASR模块中")
        
        return True

async def main():
    """Run the full test suite against the default server address.

    Returns:
        int: Process exit code; 0 if all required tests passed, 1 otherwise.
    """
    if await FunASRConnectionTest().run_all_tests():
        return 0

    print("\n❌ 测试失败,请检查FunASR服务状态")
    return 1

if __name__ == "__main__":
    # The bare `exit()` builtin is added by the `site` module for interactive
    # use and is absent when Python runs with -S; sys.exit is always available.
    import sys

    try:
        exit_code = asyncio.run(main())
        # SystemExit is not an Exception subclass, so the handlers below
        # do not intercept it.
        sys.exit(exit_code)
    except KeyboardInterrupt:
        print("\n⏹️ 测试被用户中断")
        sys.exit(1)
    except Exception as e:
        print(f"\n💥 测试过程中发生错误: {e}")
        sys.exit(1)