# test_funasr_protocol_fix.py
# AIfeng/2025-07-17 17:04:42
"""
FunASR协议兼容性修复测试脚本
测试ASR_server.py的分块协议支持
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import asyncio
import websockets
import json
import base64
import time
import numpy as np
from funasr_asr_sync import FunASRSync
from utils import util

class FunASRProtocolTest:
    """Compatibility test harness for the FunASR WebSocket protocols.

    Runs four scenarios against the server at ``server_url``: a plain
    connection check, the traditional single-message upload, the chunked
    upload (``audio_start`` / ``audio_chunk`` / ``audio_end``) and the
    ``FunASRSync`` client.  Every outcome is printed and appended to
    ``test_results`` so that ``print_summary`` can report totals.
    """

    # Fallback header length, used when the ``data`` chunk of a WAV file
    # cannot be located by walking its RIFF chunks.
    WAV_HEADER_SIZE = 44
    DEFAULT_SERVER_URL = "ws://127.0.0.1:10197"
    DEFAULT_CHUNK_SIZE = 512 * 1024  # 512 KB per chunk message
    CHUNKED_TEST_SIZES_MB = (1.0, 3.0, 5.0)

    def __init__(self, server_url: str = DEFAULT_SERVER_URL):
        """Create a harness with an empty result list.

        Args:
            server_url: WebSocket URL of the server under test.
        """
        self.test_results = []
        self.server_url = server_url

    def log_test_result(self, test_name: str, success: bool, duration: float = 0, message: str = ""):
        """Print one test outcome and record it in ``test_results``.

        Args:
            test_name: Label shown in the log line and the summary.
            success: Whether the test passed.
            duration: Elapsed seconds; omitted from the line when not > 0.
            message: Optional detail appended to the line.
        """
        status = "✓ 通过" if success else "✗ 失败"
        result = f"[{status}] {test_name}"
        if duration > 0:
            result += f" - 耗时: {duration:.2f}s"
        if message:
            result += f" - {message}"

        print(result)
        self.test_results.append({
            'test_name': test_name,
            'success': success,
            'duration': duration,
            'message': message
        })

    @classmethod
    def _find_wav_data_offset(cls, raw: bytes) -> int:
        """Return the offset of the first payload byte of the ``data`` chunk.

        Walks the RIFF chunk list instead of assuming a fixed header, because
        WAV files may carry extra chunks before ``data``.  Falls back to
        ``WAV_HEADER_SIZE`` when ``raw`` is not a parseable RIFF/WAVE file.
        """
        if len(raw) >= 12 and raw[:4] == b'RIFF' and raw[8:12] == b'WAVE':
            pos = 12
            while pos + 8 <= len(raw):
                chunk_len = int.from_bytes(raw[pos + 4:pos + 8], 'little')
                if raw[pos:pos + 4] == b'data':
                    return pos + 8
                # RIFF chunks are padded to an even number of bytes.
                pos += 8 + chunk_len + (chunk_len & 1)
        return cls.WAV_HEADER_SIZE

    @classmethod
    def _resize_wav(cls, raw: bytes, target_size: int) -> bytes:
        """Return ``raw`` grown (by repetition) or cut to ``target_size`` bytes.

        The RIFF size field and the ``data`` chunk size field are rewritten in
        both directions, so the result always describes its own length.  When
        there is no payload to repeat, or the target cannot even hold the
        header, the input is returned sliced to ``target_size`` unchanged.
        """
        data_offset = cls._find_wav_data_offset(raw)
        payload = raw[data_offset:]
        wanted = target_size - data_offset
        if not payload or wanted <= 0:
            return raw[:max(target_size, 0)]

        if len(payload) < wanted:
            payload = payload * (wanted // len(payload) + 1)
        payload = payload[:wanted]

        header = bytearray(raw[:data_offset])
        # RIFF chunk size excludes the 8-byte "RIFF<size>" prefix.
        header[4:8] = (data_offset + len(payload) - 8).to_bytes(4, 'little')
        # The data chunk size sits in the 4 bytes right before the payload.
        header[data_offset - 4:data_offset] = len(payload).to_bytes(4, 'little')
        return bytes(header) + payload

    def create_test_audio(self, size_mb: float) -> bytes:
        """Build test audio of roughly ``size_mb`` megabytes.

        Prefers ``speech.wav`` located one directory above this file's
        directory, repeated or truncated to the requested size with its WAV
        size fields updated.  Falls back to ``_generate_synthetic_audio`` when
        the file is missing, empty or unreadable.

        Returns:
            The audio bytes, or ``None`` if the synthetic fallback also fails.
        """
        try:
            speech_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'speech.wav')

            if not os.path.exists(speech_file):
                print("警告: 未找到speech.wav文件,使用生成的音频数据")
                return self._generate_synthetic_audio(size_mb)

            with open(speech_file, 'rb') as f:
                real_audio_data = f.read()

            if not real_audio_data:
                # Nothing to repeat; take the synthetic fallback below.
                raise ValueError("speech.wav is empty")

            print(f"使用真实音频文件: {speech_file}, 原始大小: {len(real_audio_data)} bytes")

            target_size = int(size_mb * 1024 * 1024)
            resized = self._resize_wav(real_audio_data, target_size)

            if len(real_audio_data) < target_size:
                if len(resized) > len(real_audio_data):
                    print(f"扩展音频文件到: {len(resized)} bytes")
            else:
                print(f"截取音频文件到: {len(resized)} bytes")
            return resized

        except (OSError, ValueError, OverflowError) as e:
            print(f"处理真实音频文件失败: {e},使用生成的音频数据")
            return self._generate_synthetic_audio(size_mb)

    def _generate_synthetic_audio(self, size_mb: float) -> bytes:
        """Generate a 440 Hz sine wave as 16 kHz, mono, 16-bit PCM WAV bytes.

        The total length (header included) is ``int(size_mb * 1024 * 1024)``.

        Returns:
            The WAV bytes, or ``None`` when generation fails (for example when
            the requested size is smaller than the header).
        """
        try:
            size_bytes = int(size_mb * 1024 * 1024)
            audio_data_size = size_bytes - self.WAV_HEADER_SIZE
            if audio_data_size < 0:
                raise ValueError(
                    f"目标大小 {size_bytes} bytes 小于WAV头部 {self.WAV_HEADER_SIZE} bytes"
                )

            sample_rate = 16000
            bytes_per_sample = 2  # 16-bit samples
            frequency = 440  # A4

            def le(value: int, width: int) -> bytes:
                return value.to_bytes(width, 'little')

            wav_header = b''.join([
                b'RIFF', le(size_bytes - 8, 4), b'WAVE',
                b'fmt ', le(16, 4),
                le(1, 2),                               # PCM format tag
                le(1, 2),                               # mono
                le(sample_rate, 4),
                le(sample_rate * bytes_per_sample, 4),  # byte rate
                le(bytes_per_sample, 2),                # block align
                le(16, 2),                              # bits per sample
                b'data', le(audio_data_size, 4),
            ])

            samples = audio_data_size // bytes_per_sample
            t = np.linspace(0, samples / sample_rate, samples, False)
            wave = np.sin(2 * np.pi * frequency * t)
            audio_data = (wave * 32767).astype(np.int16).tobytes()

            # An odd payload size leaves one byte uncovered by whole samples;
            # pad it with silence so the length matches the header.
            audio_data = audio_data.ljust(audio_data_size, b'\x00')

            return wav_header + audio_data

        except Exception as e:
            # Deliberate best-effort: callers treat None as "no audio available".
            print(f"创建合成音频失败: {e}")
            return None

    async def test_server_connection(self) -> bool:
        """Check that a WebSocket connection to ``server_url`` can be opened.

        Sends a small JSON ping and waits up to 2 seconds for any reply; the
        absence of a reply is not treated as a failure.

        Returns:
            True if the connection was established, False otherwise.
        """
        test_name = "服务器连接测试"
        print("\n=== 测试1: 服务器连接测试 ===")

        # Taken before the try block so the except branch can always use it.
        start_time = time.time()
        try:
            async with websockets.connect(self.server_url) as websocket:
                await websocket.send(json.dumps({"test": "ping"}))

                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=2.0)
                    print(f"服务器响应: {response}")
                except asyncio.TimeoutError:
                    print("服务器无响应(正常现象)")

                self.log_test_result(test_name, True, time.time() - start_time, "连接成功")
                return True

        except Exception as e:
            self.log_test_result(test_name, False, time.time() - start_time, f"连接失败: {e}")
            return False

    async def test_traditional_protocol(self) -> bool:
        """Send a ~0.5 MB file as one base64 JSON message and await a result.

        Returns:
            True if any message arrives within 30 seconds, False on timeout,
            audio creation failure or any exception.
        """
        test_name = "传统协议测试"
        print("\n=== 测试2: 传统协议测试 ===")

        start_time = time.time()
        try:
            audio_data = self.create_test_audio(0.5)
            if audio_data is None:
                self.log_test_result(test_name, False, 0, "创建测试音频失败")
                return False

            async with websockets.connect(self.server_url) as websocket:
                message = {
                    'audio_data': base64.b64encode(audio_data).decode('utf-8'),
                    'filename': 'test_traditional.wav'
                }

                await websocket.send(json.dumps(message))
                print(f"发送传统协议消息: {len(audio_data)} bytes")

                try:
                    result = await asyncio.wait_for(websocket.recv(), timeout=30.0)
                except asyncio.TimeoutError:
                    self.log_test_result(test_name, False, time.time() - start_time, "等待结果超时")
                    return False

                duration = time.time() - start_time
                print(f"识别结果: {result}")
                self.log_test_result(test_name, True, duration, f"识别成功: {result[:50]}...")
                return True

        except Exception as e:
            self.log_test_result(test_name, False, time.time() - start_time, f"异常: {e}")
            return False

    async def test_chunked_protocol(self, size_mb: float, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bool:
        """Upload ``size_mb`` megabytes of audio with the chunked protocol.

        Sends ``audio_start``, then one ``audio_chunk`` message per
        ``chunk_size`` bytes, then ``audio_end``, and waits up to 60 seconds
        for the next message from the server.

        Args:
            size_mb: Size of the generated test audio in megabytes.
            chunk_size: Raw bytes carried by each chunk message.

        Returns:
            True if a message arrives after ``audio_end``, False otherwise.
        """
        test_name = f"分块协议测试({size_mb}MB)"
        print(f"\n=== 测试3: 分块协议测试 ({size_mb}MB) ===")

        start_time = time.time()
        try:
            audio_data = self.create_test_audio(size_mb)
            if audio_data is None:
                self.log_test_result(test_name, False, 0, "创建测试音频失败")
                return False

            async with websockets.connect(self.server_url) as websocket:
                filename = f'test_chunked_{size_mb}mb.wav'
                total_size = len(audio_data)
                total_chunks = (total_size + chunk_size - 1) // chunk_size

                print(f"开始分块发送: 总大小 {total_size} bytes, 分块数 {total_chunks}")

                # 1. Announce the upload.
                start_msg = {
                    'type': 'audio_start',
                    'filename': filename,
                    'total_size': total_size,
                    'total_chunks': total_chunks,
                    'chunk_size': chunk_size
                }
                await websocket.send(json.dumps(start_msg))

                # A missing acknowledgement is tolerated; the upload continues.
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=5.0)
                    print(f"服务器确认: {response}")
                except asyncio.TimeoutError:
                    print("服务器无确认响应")

                # 2. Stream the chunks.
                for i, start_pos in enumerate(range(0, total_size, chunk_size)):
                    chunk_data = audio_data[start_pos:start_pos + chunk_size]
                    chunk_msg = {
                        'type': 'audio_chunk',
                        'filename': filename,
                        'chunk_index': i,
                        'chunk_data': base64.b64encode(chunk_data).decode('utf-8'),
                        'is_last': (i == total_chunks - 1)
                    }

                    await websocket.send(json.dumps(chunk_msg))

                    if (i + 1) % 5 == 0 or i == total_chunks - 1:
                        progress = ((i + 1) / total_chunks) * 100
                        print(f"发送进度: {progress:.1f}% ({i+1}/{total_chunks})")

                    # Short pause between chunks as simple flow control.
                    await asyncio.sleep(0.01)

                # 3. Signal the end of the upload.
                await websocket.send(json.dumps({
                    'type': 'audio_end',
                    'filename': filename
                }))
                print("分块发送完成,等待识别结果...")

                # 4. Wait for the recognition result.
                # NOTE(review): the first message after audio_end is taken as
                # the result; confirm the server sends no per-chunk acks.
                try:
                    result = await asyncio.wait_for(websocket.recv(), timeout=60.0)
                except asyncio.TimeoutError:
                    self.log_test_result(test_name, False, time.time() - start_time, "等待结果超时")
                    return False

                duration = time.time() - start_time
                print(f"识别结果: {result}")

                # Guard against a zero elapsed time on coarse clocks.
                throughput = total_size / duration / 1024 / 1024 if duration > 0 else 0.0
                message = f"识别成功,吞吐量: {throughput:.2f}MB/s"

                self.log_test_result(test_name, True, duration, message)
                return True

        except Exception as e:
            self.log_test_result(test_name, False, time.time() - start_time, f"异常: {e}")
            return False

    async def test_funasr_sync_client(self) -> bool:
        """Send ~2 MB of audio through ``FunASRSync`` and wait for a callback.

        Polls once per second, for at most 60 seconds, until the result
        callback has stored a value.  The client is always closed afterwards.

        Returns:
            True if a truthy result was delivered, False otherwise.
        """
        test_name = "FunASRSync客户端测试"
        print("\n=== 测试4: FunASRSync客户端测试 ===")

        # Bound before the try block so cleanup never hits an unbound name
        # when the constructor itself raises.
        client = None
        start_time = time.time()
        try:
            client = FunASRSync("test_user")

            received_result = None

            def on_result(result):
                nonlocal received_result
                received_result = result
                print(f"收到识别结果: {result}")

            client.set_result_callback(on_result)

            if not client.connect():
                self.log_test_result(test_name, False, 0, "连接失败")
                return False

            # Give the connection a moment to settle before sending.
            await asyncio.sleep(2)

            audio_data = self.create_test_audio(2.0)
            if audio_data is None:
                self.log_test_result(test_name, False, 0, "创建测试音频失败")
                return False

            if not client.send_audio_data(audio_data, "test_sync_client.wav"):
                self.log_test_result(test_name, False, 0, "发送音频失败")
                return False

            for _ in range(60):
                if received_result is not None:
                    break
                await asyncio.sleep(1)

            duration = time.time() - start_time

            if received_result:
                # str() keeps the preview working for non-string payloads.
                preview = str(received_result)[:50]
                self.log_test_result(test_name, True, duration, f"识别成功: {preview}...")
                return True

            self.log_test_result(test_name, False, duration, "未收到识别结果")
            return False

        except Exception as e:
            self.log_test_result(test_name, False, time.time() - start_time, f"异常: {e}")
            return False
        finally:
            if client is not None:
                try:
                    client.end()
                except Exception as e:
                    # Cleanup stays best-effort, but is no longer silent.
                    print(f"关闭FunASRSync客户端失败: {e}")

    async def run_all_tests(self):
        """Run every scenario in order and print the summary.

        Stops early, without a summary, when the connection test fails,
        because the remaining tests need a reachable server.
        """
        print("FunASR协议兼容性修复测试")
        print("=" * 50)

        if not await self.test_server_connection():
            print("\n❌ 服务器连接失败,请确保ASR_server.py正在运行")
            return

        await self.test_traditional_protocol()
        await asyncio.sleep(2)

        for size in self.CHUNKED_TEST_SIZES_MB:
            await self.test_chunked_protocol(size)
            await asyncio.sleep(2)

        await self.test_funasr_sync_client()

        self.print_summary()

    def print_summary(self):
        """Print totals, the pass rate and the per-test details."""
        print("\n" + "=" * 50)
        print("测试总结")
        print("=" * 50)

        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result['success'])

        print(f"总测试数: {total_tests}")
        print(f"通过测试: {passed_tests}")
        print(f"失败测试: {total_tests - passed_tests}")
        print(f"成功率: {passed_tests/total_tests*100:.1f}%" if total_tests > 0 else "成功率: 0%")

        print("\n详细结果:")
        for result in self.test_results:
            status = "✓" if result['success'] else "✗"
            print(f"  {status} {result['test_name']}")
            if result['duration'] > 0:
                print(f"    耗时: {result['duration']:.2f}s")
            if result['message']:
                print(f"    说明: {result['message']}")

async def main():
    """Build the protocol test harness and run its full test sequence."""
    await FunASRProtocolTest().run_all_tests()

# Run the test suite only when executed as a script, not when imported.
if __name__ == "__main__":
    # asyncio.run starts an event loop and runs main() until it completes.
    asyncio.run(main())