# test_funasr_large_file.py 11.9 KB
# AIfeng/2025-07-17 16:38:52
"""
FunASR大文件处理测试脚本
测试优化后的FunASRSync分块发送功能
"""

import os
import random
import sys
import time
from typing import Optional

import numpy as np

# Make the parent directory importable before loading project-local modules.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from funasr_asr_sync import FunASRSync
from utils import util

class FunASRLargeFileTest:
    """Manual integration test harness for FunASRSync large-file sending.

    The harness connects a ``FunASRSync`` client, pushes synthetic WAV
    payloads of several sizes through ``send_audio_data`` and records one
    result dict (name, success flag, duration, size, message) per step in
    ``self.test_results``.
    """

    # Size in bytes of the PCM WAV header produced by create_test_audio.
    _WAV_HEADER_SIZE = 44
    # Prefix of the names logged by test_large_file; print_summary uses it to
    # select the results that feed the performance statistics.
    _LARGE_FILE_TEST_PREFIX = "大文件处理测试"

    def __init__(self):
        # One dict per executed test, appended by log_test_result.
        self.test_results = []
        # FunASRSync client instance; created by test_connection.
        self.client = None

    def log_test_result(self, test_name: str, success: bool, duration: float = 0, file_size: int = 0, message: str = ""):
        """Print a one-line result for a test and store it in ``test_results``.

        Args:
            test_name: Human-readable name of the test.
            success: Whether the test passed.
            duration: Elapsed seconds; omitted from the printed line when <= 0.
            file_size: Payload size in bytes; omitted from the line when <= 0.
            message: Optional free-form detail appended to the line.
        """
        status = "✓ 通过" if success else "✗ 失败"
        result = f"[{status}] {test_name}"
        if duration > 0:
            result += f" - 耗时: {duration:.2f}s"
        if file_size > 0:
            result += f" - 文件大小: {file_size/1024/1024:.2f}MB"
        if message:
            result += f" - {message}"

        print(result)
        self.test_results.append({
            'test_name': test_name,
            'success': success,
            'duration': duration,
            'file_size': file_size,
            'message': message
        })

    def create_test_audio(self, size_mb: float) -> Optional[bytes]:
        """Build an in-memory WAV payload (16 kHz, mono, 16-bit PCM, 440 Hz sine).

        Args:
            size_mb: Total payload size in MiB, header included.

        Returns:
            Exactly ``int(size_mb * 1024 * 1024)`` bytes of WAV data, or
            ``None`` when the payload cannot be built (for example the
            requested size is smaller than the 44-byte header). The failure
            reason is printed.
        """
        try:
            size_bytes = int(size_mb * 1024 * 1024)
            audio_data_size = size_bytes - self._WAV_HEADER_SIZE
            if audio_data_size < 0:
                # Fail explicitly instead of relying on to_bytes() rejecting a
                # negative chunk size further down.
                raise ValueError(
                    f"requested size {size_bytes} bytes is smaller than the "
                    f"{self._WAV_HEADER_SIZE}-byte WAV header"
                )

            sample_rate = 16000
            frequency = 440  # A4 note

            # Canonical 44-byte RIFF/WAVE header; all integers little-endian.
            wav_header = b"".join((
                b'RIFF', (size_bytes - 8).to_bytes(4, 'little'), b'WAVE',
                b'fmt ', (16).to_bytes(4, 'little'),   # fmt chunk length
                (1).to_bytes(2, 'little'),             # PCM format
                (1).to_bytes(2, 'little'),             # mono
                sample_rate.to_bytes(4, 'little'),     # sample rate
                (32000).to_bytes(4, 'little'),         # byte rate
                (2).to_bytes(2, 'little'),             # block align
                (16).to_bytes(2, 'little'),            # bits per sample
                b'data', audio_data_size.to_bytes(4, 'little'),
            ))

            # Two bytes per 16-bit sample.
            samples = audio_data_size // 2
            t = np.linspace(0, samples / sample_rate, samples, False)
            wave = np.sin(2 * np.pi * frequency * t)

            # WAV PCM samples are little-endian regardless of host byte order.
            audio_data = (wave * 32767).astype('<i2').tobytes()

            # An odd data size leaves one trailing byte; pad it with zero so
            # the payload length matches the size declared in the header.
            audio_data += b'\x00' * (audio_data_size - len(audio_data))

            return wav_header + audio_data

        except Exception as e:
            # Best effort: callers treat None as "could not create audio".
            print(f"创建测试音频失败: {e}")
            return None

    def _client_ready(self) -> bool:
        """Return True when a client exists and reports itself connected."""
        return bool(self.client and self.client.is_connected())

    def test_connection(self) -> bool:
        """Test 1: create the FunASRSync client and try to connect.

        Returns:
            The value returned by ``client.connect()``, or ``False`` when an
            exception is raised while creating or connecting the client.
        """
        print("\n=== 测试1: FunASR连接测试 ===")

        try:
            start_time = time.perf_counter()
            self.client = FunASRSync("test_user")

            # Echo every recognition result delivered by the client.
            def on_result(result):
                print(f"收到识别结果: {result}")

            self.client.set_result_callback(on_result)

            success = self.client.connect()
            duration = time.perf_counter() - start_time

            self.log_test_result("FunASR连接测试", success, duration, 0,
                               "连接成功" if success else "连接失败")

            return success

        except Exception as e:
            self.log_test_result("FunASR连接测试", False, 0, 0, f"异常: {e}")
            return False

    def test_small_file(self) -> bool:
        """Test 2: send a 0.5 MiB payload through the connected client.

        Returns:
            The value returned by ``send_audio_data``, or ``False`` when the
            client is not connected, the audio cannot be created, or an
            exception is raised.
        """
        print("\n=== 测试2: 小文件处理测试 ===")

        if not self._client_ready():
            self.log_test_result("小文件处理测试", False, 0, 0, "客户端未连接")
            return False

        try:
            file_size_mb = 0.5
            audio_data = self.create_test_audio(file_size_mb)

            if audio_data is None:
                self.log_test_result("小文件处理测试", False, 0, 0, "创建测试音频失败")
                return False

            start_time = time.perf_counter()
            success = self.client.send_audio_data(audio_data, "test_small.wav")
            duration = time.perf_counter() - start_time

            self.log_test_result("小文件处理测试", success, duration, len(audio_data),
                               "使用简单发送模式" if success else "发送失败")

            return success

        except Exception as e:
            self.log_test_result("小文件处理测试", False, 0, 0, f"异常: {e}")
            return False

    def test_large_file(self, size_mb: float) -> bool:
        """Test 3: send one payload of ``size_mb`` MiB and report throughput.

        Args:
            size_mb: Payload size in MiB passed to ``create_test_audio``.

        Returns:
            The value returned by ``send_audio_data``, or ``False`` when the
            client is not connected, the audio cannot be created, or an
            exception is raised.
        """
        print(f"\n=== 测试3: 大文件处理测试 ({size_mb}MB) ===")

        if not self._client_ready():
            self.log_test_result(f"大文件处理测试({size_mb}MB)", False, 0, 0, "客户端未连接")
            return False

        try:
            print(f"创建 {size_mb}MB 测试音频文件...")
            audio_data = self.create_test_audio(size_mb)

            if audio_data is None:
                self.log_test_result(f"大文件处理测试({size_mb}MB)", False, 0, 0, "创建测试音频失败")
                return False

            print(f"开始发送 {len(audio_data)} 字节音频数据...")
            start_time = time.perf_counter()
            success = self.client.send_audio_data(audio_data, f"test_large_{size_mb}mb.wav")
            duration = time.perf_counter() - start_time

            # MiB per second; guarded against a zero-length interval.
            throughput = len(audio_data) / duration / 1024 / 1024 if duration > 0 else 0
            message = f"使用分块发送模式,吞吐量: {throughput:.2f}MB/s" if success else "发送失败"

            self.log_test_result(f"大文件处理测试({size_mb}MB)", success, duration, len(audio_data), message)

            return success

        except Exception as e:
            self.log_test_result(f"大文件处理测试({size_mb}MB)", False, 0, 0, f"异常: {e}")
            return False

    def test_multiple_large_files(self) -> bool:
        """Test 4: send 2, 3 and 5 MiB payloads back to back.

        A one-second pause follows each send. Duration and size are
        accumulated across every send that was attempted.

        Returns:
            ``True`` only when every payload was created and sent
            successfully; ``False`` otherwise or when an exception is raised.
        """
        print("\n=== 测试4: 多个大文件连续处理测试 ===")

        if not self._client_ready():
            self.log_test_result("多文件连续处理测试", False, 0, 0, "客户端未连接")
            return False

        try:
            file_sizes = [2.0, 3.0, 5.0]  # MiB
            all_success = True
            total_duration = 0
            total_size = 0

            for index, size_mb in enumerate(file_sizes, start=1):
                print(f"\n处理第 {index}/{len(file_sizes)} 个文件 ({size_mb}MB)...")

                audio_data = self.create_test_audio(size_mb)
                if audio_data is None:
                    all_success = False
                    continue

                start_time = time.perf_counter()
                success = self.client.send_audio_data(audio_data, f"test_multi_{index}_{size_mb}mb.wav")
                duration = time.perf_counter() - start_time

                total_duration += duration
                total_size += len(audio_data)

                if not success:
                    all_success = False
                    print(f"文件 {index} 发送失败")
                else:
                    print(f"文件 {index} 发送成功,耗时: {duration:.2f}s")

                # Pause between files.
                time.sleep(1)

            avg_throughput = total_size / total_duration / 1024 / 1024 if total_duration > 0 else 0
            message = f"平均吞吐量: {avg_throughput:.2f}MB/s" if all_success else "部分文件发送失败"

            self.log_test_result("多文件连续处理测试", all_success, total_duration, total_size, message)

            return all_success

        except Exception as e:
            self.log_test_result("多文件连续处理测试", False, 0, 0, f"异常: {e}")
            return False

    def cleanup(self):
        """Close the client if one was created; errors are printed, not raised."""
        if self.client:
            try:
                self.client.end()
                print("\n客户端已关闭")
            except Exception as e:
                print(f"关闭客户端时出错: {e}")

    def run_all_tests(self):
        """Run every test in order, then clean up and print the summary.

        If the connection test fails, the remaining tests are skipped. The
        client cleanup and the summary always run, including after a failed
        connection or an unexpected exception, so a failed run still reports
        what happened.
        """
        print("FunASR大文件处理优化测试")
        print("=" * 50)

        try:
            # Test 1: connection.
            if not self.test_connection():
                print("\n❌ 连接测试失败,跳过后续测试")
                return

            # Give the connection time to settle.
            time.sleep(2)

            # Test 2: small file.
            self.test_small_file()
            time.sleep(1)

            # Test 3: large files of 2, 5 and 10 MiB.
            for size in [2.0, 5.0, 10.0]:
                self.test_large_file(size)
                time.sleep(2)  # pause between tests

            # Test 4: several large files in a row.
            self.test_multiple_large_files()

        finally:
            self.cleanup()
            # Report in the finally block so an early return or an exception
            # does not suppress the summary.
            self.print_summary()

    def print_summary(self):
        """Print totals, per-test details and large-file performance figures.

        The performance section covers successful results whose name starts
        with the large-file test prefix, or contains "large" in any letter
        case. It is printed only when at least one such result exists.
        """
        print("\n" + "=" * 50)
        print("测试总结")
        print("=" * 50)

        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result['success'])

        print(f"总测试数: {total_tests}")
        print(f"通过测试: {passed_tests}")
        print(f"失败测试: {total_tests - passed_tests}")
        print(f"成功率: {passed_tests/total_tests*100:.1f}%" if total_tests > 0 else "成功率: 0%")

        print("\n详细结果:")
        for result in self.test_results:
            status = "✓" if result['success'] else "✗"
            print(f"  {status} {result['test_name']}")
            if result['duration'] > 0:
                print(f"    耗时: {result['duration']:.2f}s")
            if result['file_size'] > 0:
                print(f"    文件大小: {result['file_size']/1024/1024:.2f}MB")
            if result['message']:
                print(f"    说明: {result['message']}")

        # Performance statistics. The names logged by test_large_file are
        # Chinese, so matching on the English word "large" alone never
        # selected anything; match the actual prefix and keep the English
        # match for externally logged names.
        large_file_tests = [
            r for r in self.test_results
            if r['success'] and (
                r['test_name'].startswith(self._LARGE_FILE_TEST_PREFIX)
                or 'large' in r['test_name'].lower()
            )
        ]
        if large_file_tests:
            total_size = sum(r['file_size'] for r in large_file_tests)
            total_time = sum(r['duration'] for r in large_file_tests)
            avg_duration = total_time / len(large_file_tests)
            avg_throughput = total_size / total_time / 1024 / 1024 if total_time > 0 else 0

            print(f"\n性能统计:")
            print(f"  大文件平均处理时间: {avg_duration:.2f}s")
            print(f"  平均吞吐量: {avg_throughput:.2f}MB/s")

def main():
    """Script entry point: build the test harness and run the full suite."""
    FunASRLargeFileTest().run_all_tests()

# Run the test suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()