test_funasr_protocol_fix.py
17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
# AIfeng/2025-07-17 17:04:42
"""
FunASR协议兼容性修复测试脚本
测试ASR_server.py的分块协议支持
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import asyncio
import websockets
import json
import base64
import time
import numpy as np
from funasr_asr_sync import FunASRSync
from utils import util
class FunASRProtocolTest:
    """Compatibility tests for the FunASR websocket protocol served by ASR_server.py.

    Exercises the plain single-message protocol, the chunked
    (``audio_start`` / ``audio_chunk`` / ``audio_end``) protocol and the
    ``FunASRSync`` client, recording every outcome in ``test_results``.
    """

    # Size of the canonical PCM WAV header (RIFF + "fmt " + "data" chunk
    # headers) as produced by the ``wave`` writer and _generate_synthetic_audio.
    WAV_HEADER_SIZE = 44

    def __init__(self, server_url: str = "ws://127.0.0.1:10197"):
        """Create a test runner.

        Args:
            server_url: websocket URL of the ASR server under test.
        """
        # One dict per executed test: test_name, success, duration, message.
        self.test_results = []
        self.server_url = server_url

    def log_test_result(self, test_name: str, success: bool, duration: float = 0, message: str = ""):
        """Print a one-line verdict for a test and append it to ``test_results``.

        Args:
            test_name: human-readable test name.
            success: whether the test passed.
            duration: elapsed seconds; left out of the line when not positive.
            message: optional detail appended to the line.
        """
        status = "✓ 通过" if success else "✗ 失败"
        result = f"[{status}] {test_name}"
        if duration > 0:
            result += f" - 耗时: {duration:.2f}s"
        if message:
            result += f" - {message}"
        print(result)
        self.test_results.append({
            'test_name': test_name,
            'success': success,
            'duration': duration,
            'message': message,
        })

    def create_test_audio(self, size_mb: float) -> "bytes | None":
        """Build a WAV payload of roughly ``size_mb`` MiB for the protocol tests.

        Prefers ``speech.wav`` from the parent of this script's directory. Its
        audio frames are looped or truncated to reach the requested size.
        Falls back to a synthetic sine tone when the file is missing or cannot
        be parsed.

        Returns:
            The WAV bytes, or ``None`` if even the synthetic fallback failed.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        speech_file = os.path.join(project_root, 'speech.wav')
        if not os.path.exists(speech_file):
            print("警告: 未找到speech.wav文件,使用生成的音频数据")
            return self._generate_synthetic_audio(size_mb)
        try:
            with open(speech_file, 'rb') as f:
                real_audio_data = f.read()
            print(f"使用真实音频文件: {speech_file}, 原始大小: {len(real_audio_data)} bytes")
            target_size = int(size_mb * 1024 * 1024)
            final_audio = self._fit_wav_to_size(real_audio_data, target_size)
        except Exception as e:
            # Deliberately best-effort: any read/parse problem falls back to
            # synthetic audio so the protocol tests can still run.
            print(f"处理真实音频文件失败: {e},使用生成的音频数据")
            return self._generate_synthetic_audio(size_mb)
        if len(real_audio_data) < target_size:
            print(f"扩展音频文件到: {len(final_audio)} bytes")
        else:
            print(f"截取音频文件到: {len(final_audio)} bytes")
        return final_audio

    @staticmethod
    def _fit_wav_to_size(wav_bytes: bytes, target_size: int) -> bytes:
        """Loop or truncate the audio of a WAV file so the result is ~``target_size`` bytes.

        The input is parsed with the ``wave`` module, so it does not have to
        use the canonical 44-byte header. For example, LIST/INFO chunks before
        ``data`` are skipped rather than treated as audio. The output header
        is rewritten to match the new payload.

        The payload is cut on a whole-frame boundary. The result may
        therefore be a few bytes shorter than ``target_size``.

        Raises:
            wave.Error / EOFError: ``wav_bytes`` is not a readable WAV file.
            ValueError: the file contains no audio frames to repeat.
        """
        import io
        import wave

        with wave.open(io.BytesIO(wav_bytes), 'rb') as reader:
            params = reader.getparams()
            frames = reader.readframes(reader.getnframes())
        if not frames:
            raise ValueError("WAV file contains no audio frames")

        frame_size = params.nchannels * params.sampwidth
        payload_size = max(target_size - FunASRProtocolTest.WAV_HEADER_SIZE, 0)
        payload_size -= payload_size % frame_size  # never split a sample frame
        repeats = -(-payload_size // len(frames))  # ceiling division
        payload = (frames * repeats)[:payload_size]

        out = io.BytesIO()
        with wave.open(out, 'wb') as writer:
            # The writer fixes the RIFF/data size fields to the bytes written.
            writer.setparams(params)
            writer.writeframes(payload)
        return out.getvalue()

    def _generate_synthetic_audio(self, size_mb: float) -> "bytes | None":
        """Build a mono 16 kHz 16-bit PCM WAV of ``size_mb`` MiB containing a 440 Hz sine.

        This is the fallback used when ``speech.wav`` is unavailable.

        Returns:
            WAV bytes exactly ``int(size_mb * 1024 * 1024)`` long. Returns
            ``None`` if they could not be built, e.g. when the requested size
            is smaller than the 44-byte header.
        """
        try:
            size_bytes = int(size_mb * 1024 * 1024)
            # Canonical 44-byte PCM WAV header.
            wav_header = b'RIFF' + (size_bytes - 8).to_bytes(4, 'little') + b'WAVE'
            wav_header += b'fmt ' + (16).to_bytes(4, 'little')
            wav_header += (1).to_bytes(2, 'little')      # PCM format
            wav_header += (1).to_bytes(2, 'little')      # mono
            wav_header += (16000).to_bytes(4, 'little')  # sample rate 16 kHz
            wav_header += (32000).to_bytes(4, 'little')  # byte rate
            wav_header += (2).to_bytes(2, 'little')      # block align
            wav_header += (16).to_bytes(2, 'little')     # bits per sample
            wav_header += b'data' + (size_bytes - 44).to_bytes(4, 'little')

            audio_data_size = size_bytes - 44
            samples = audio_data_size // 2  # 16-bit samples
            frequency = 440  # A4
            sample_rate = 16000
            t = np.linspace(0, samples / sample_rate, samples, False)
            wave_samples = np.sin(2 * np.pi * frequency * t)
            audio_data = (wave_samples * 32767).astype(np.int16).tobytes()

            # Pad or trim so the data chunk matches the size in the header.
            if len(audio_data) < audio_data_size:
                audio_data += b'\x00' * (audio_data_size - len(audio_data))
            elif len(audio_data) > audio_data_size:
                audio_data = audio_data[:audio_data_size]
            return wav_header + audio_data
        except Exception as e:
            print(f"创建合成音频失败: {e}")
            return None

    async def test_server_connection(self) -> bool:
        """Test 1: open a websocket to the server and send a trivial JSON message.

        A missing reply within 2 s is accepted; only a connection/send error fails.
        """
        print("\n=== 测试1: 服务器连接测试 ===")
        try:
            start_time = time.time()
            async with websockets.connect(self.server_url) as websocket:
                await websocket.send(json.dumps({"test": "ping"}))
                # The server may legitimately not answer this message.
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=2.0)
                    print(f"服务器响应: {response}")
                except asyncio.TimeoutError:
                    print("服务器无响应(正常现象)")
            duration = time.time() - start_time
            self.log_test_result("服务器连接测试", True, duration, "连接成功")
            return True
        except Exception as e:
            duration = time.time() - start_time
            self.log_test_result("服务器连接测试", False, duration, f"连接失败: {e}")
            return False

    async def test_traditional_protocol(self) -> bool:
        """Test 2: send a ~0.5 MB WAV as a single base64 JSON message and await one reply (30 s)."""
        print("\n=== 测试2: 传统协议测试 ===")
        try:
            start_time = time.time()
            audio_data = self.create_test_audio(0.5)
            if audio_data is None:
                self.log_test_result("传统协议测试", False, 0, "创建测试音频失败")
                return False
            async with websockets.connect(self.server_url) as websocket:
                audio_b64 = base64.b64encode(audio_data).decode('utf-8')
                message = {
                    'audio_data': audio_b64,
                    'filename': 'test_traditional.wav',
                }
                await websocket.send(json.dumps(message))
                print(f"发送传统协议消息: {len(audio_data)} bytes")
                try:
                    result = await asyncio.wait_for(websocket.recv(), timeout=30.0)
                    duration = time.time() - start_time
                    print(f"识别结果: {result}")
                    self.log_test_result("传统协议测试", True, duration, f"识别成功: {result[:50]}...")
                    return True
                except asyncio.TimeoutError:
                    duration = time.time() - start_time
                    self.log_test_result("传统协议测试", False, duration, "等待结果超时")
                    return False
        except Exception as e:
            duration = time.time() - start_time
            self.log_test_result("传统协议测试", False, duration, f"异常: {e}")
            return False

    async def test_chunked_protocol(self, size_mb: float) -> bool:
        """Test 3: upload a ``size_mb`` MiB WAV with the chunked protocol.

        Sends ``audio_start``, then 512 KiB base64 ``audio_chunk`` messages,
        then ``audio_end``. It then treats the next message received (within
        60 s) as the recognition result.
        """
        print(f"\n=== 测试3: 分块协议测试 ({size_mb}MB) ===")
        try:
            start_time = time.time()
            audio_data = self.create_test_audio(size_mb)
            if audio_data is None:
                self.log_test_result(f"分块协议测试({size_mb}MB)", False, 0, "创建测试音频失败")
                return False
            async with websockets.connect(self.server_url) as websocket:
                filename = f'test_chunked_{size_mb}mb.wav'
                chunk_size = 512 * 1024
                total_size = len(audio_data)
                total_chunks = (total_size + chunk_size - 1) // chunk_size  # ceiling division
                print(f"开始分块发送: 总大小 {total_size} bytes, 分块数 {total_chunks}")

                # 1. Announce the upload.
                start_msg = {
                    'type': 'audio_start',
                    'filename': filename,
                    'total_size': total_size,
                    'total_chunks': total_chunks,
                    'chunk_size': chunk_size,
                }
                await websocket.send(json.dumps(start_msg))
                # An acknowledgement is optional; a timeout is only reported.
                try:
                    response = await asyncio.wait_for(websocket.recv(), timeout=5.0)
                    print(f"服务器确认: {response}")
                except asyncio.TimeoutError:
                    print("服务器无确认响应")

                # 2. Stream the chunks.
                for i in range(total_chunks):
                    start_pos = i * chunk_size
                    end_pos = min(start_pos + chunk_size, total_size)
                    chunk_b64 = base64.b64encode(audio_data[start_pos:end_pos]).decode('utf-8')
                    chunk_msg = {
                        'type': 'audio_chunk',
                        'filename': filename,
                        'chunk_index': i,
                        'chunk_data': chunk_b64,
                        'is_last': (i == total_chunks - 1),
                    }
                    await websocket.send(json.dumps(chunk_msg))
                    if (i + 1) % 5 == 0 or i == total_chunks - 1:
                        progress = ((i + 1) / total_chunks) * 100
                        print(f"发送进度: {progress:.1f}% ({i+1}/{total_chunks})")
                    # Small pause between chunks as simple flow control.
                    await asyncio.sleep(0.01)

                # 3. Signal the end of the upload.
                end_msg = {
                    'type': 'audio_end',
                    'filename': filename,
                }
                await websocket.send(json.dumps(end_msg))
                print("分块发送完成,等待识别结果...")

                # 4. Await the recognition result.
                try:
                    result = await asyncio.wait_for(websocket.recv(), timeout=60.0)
                    duration = time.time() - start_time
                    print(f"识别结果: {result}")
                    throughput = total_size / duration / 1024 / 1024
                    message = f"识别成功,吞吐量: {throughput:.2f}MB/s"
                    self.log_test_result(f"分块协议测试({size_mb}MB)", True, duration, message)
                    return True
                except asyncio.TimeoutError:
                    duration = time.time() - start_time
                    self.log_test_result(f"分块协议测试({size_mb}MB)", False, duration, "等待结果超时")
                    return False
        except Exception as e:
            duration = time.time() - start_time
            self.log_test_result(f"分块协议测试({size_mb}MB)", False, duration, f"异常: {e}")
            return False

    async def test_funasr_sync_client(self) -> bool:
        """Test 4: send a ~2 MB WAV through ``FunASRSync`` and wait for its result callback.

        Returns:
            True if a non-empty result arrives within about 60 s, else False.
        """
        print("\n=== 测试4: FunASRSync客户端测试 ===")
        # Bind before ``try`` so ``finally`` knows whether a client was created.
        # Previously, a failing constructor left ``client`` unbound here.
        client = None
        start_time = time.time()
        try:
            client = FunASRSync("test_user")
            received_result = None

            def on_result(result):
                nonlocal received_result
                received_result = result
                print(f"收到识别结果: {result}")

            client.set_result_callback(on_result)
            if not client.connect():
                self.log_test_result("FunASRSync客户端测试", False, 0, "连接失败")
                return False
            # Give the connection time to settle before sending.
            await asyncio.sleep(2)

            audio_data = self.create_test_audio(2.0)
            if audio_data is None:
                self.log_test_result("FunASRSync客户端测试", False, 0, "创建测试音频失败")
                return False
            if not client.send_audio_data(audio_data, "test_sync_client.wav"):
                self.log_test_result("FunASRSync客户端测试", False, 0, "发送音频失败")
                return False

            # Poll once per second, for at most 60 s, until the callback fires.
            waited = 0
            while received_result is None and waited < 60:
                await asyncio.sleep(1)
                waited += 1

            duration = time.time() - start_time
            if received_result is None:
                self.log_test_result("FunASRSync客户端测试", False, duration, "未收到识别结果")
                return False
            if not received_result:
                self.log_test_result("FunASRSync客户端测试", False, duration, "识别结果为空")
                return False
            # str() so non-string results (e.g. dicts) cannot break the preview slice.
            preview = str(received_result)[:50]
            self.log_test_result("FunASRSync客户端测试", True, duration, f"识别成功: {preview}...")
            return True
        except Exception as e:
            duration = time.time() - start_time
            self.log_test_result("FunASRSync客户端测试", False, duration, f"异常: {e}")
            return False
        finally:
            if client is not None:
                try:
                    client.end()
                except Exception as e:
                    # Best-effort cleanup: report instead of silently swallowing.
                    print(f"关闭FunASRSync客户端失败: {e}")

    async def run_all_tests(self):
        """Run tests 1-4 in order and print a summary.

        Stops early without a summary if the server is unreachable.
        """
        print("FunASR协议兼容性修复测试")
        print("=" * 50)
        if not await self.test_server_connection():
            print("\n❌ 服务器连接失败,请确保ASR_server.py正在运行")
            return
        await self.test_traditional_protocol()
        await asyncio.sleep(2)
        for size in [1.0, 3.0, 5.0]:  # MiB
            await self.test_chunked_protocol(size)
            await asyncio.sleep(2)
        await self.test_funasr_sync_client()
        self.print_summary()

    def print_summary(self):
        """Print pass/fail counts, the success rate and per-test details."""
        print("\n" + "=" * 50)
        print("测试总结")
        print("=" * 50)
        total_tests = len(self.test_results)
        passed_tests = sum(1 for result in self.test_results if result['success'])
        print(f"总测试数: {total_tests}")
        print(f"通过测试: {passed_tests}")
        print(f"失败测试: {total_tests - passed_tests}")
        print(f"成功率: {passed_tests/total_tests*100:.1f}%" if total_tests > 0 else "成功率: 0%")
        print("\n详细结果:")
        for result in self.test_results:
            status = "✓" if result['success'] else "✗"
            print(f" {status} {result['test_name']}")
            if result['duration'] > 0:
                print(f" 耗时: {result['duration']:.2f}s")
            if result['message']:
                print(f" 说明: {result['message']}")
async def main():
    """Script entry point: build the protocol test harness and run every test."""
    await FunASRProtocolTest().run_all_tests()


if __name__ == "__main__":
    asyncio.run(main())