ASR WebSocket服务实现:1.FunASR本地方案 2.豆包
1.已实现音频文件的处理,包括小文件直接转换以及大文件分割识别。 2.豆包接入流式识别,但封装仍需要修改 3.线程管理 4.websocket集中管理
Showing
17 changed files
with
4150 additions
and
1 deletions
| @@ -3,6 +3,8 @@ build/ | @@ -3,6 +3,8 @@ build/ | ||
| 3 | *.egg-info/ | 3 | *.egg-info/ |
| 4 | *.so | 4 | *.so |
| 5 | *.mp4 | 5 | *.mp4 |
| 6 | +*.mp3 | ||
| 7 | +*.wav | ||
| 6 | 8 | ||
| 7 | tmp* | 9 | tmp* |
| 8 | trial*/ | 10 | trial*/ |
| @@ -12,7 +14,6 @@ data_utils/face_tracking/3DMM/* | @@ -12,7 +14,6 @@ data_utils/face_tracking/3DMM/* | ||
| 12 | data_utils/face_parsing/79999_iter.pth | 14 | data_utils/face_parsing/79999_iter.pth |
| 13 | 15 | ||
| 14 | pretrained | 16 | pretrained |
| 15 | -*.mp4 | ||
| 16 | .DS_Store | 17 | .DS_Store |
| 17 | workspace/log_ngp.txt | 18 | workspace/log_ngp.txt |
| 18 | .idea | 19 | .idea |
| @@ -21,3 +22,4 @@ models/ | @@ -21,3 +22,4 @@ models/ | ||
| 21 | *.log | 22 | *.log |
| 22 | dist | 23 | dist |
| 23 | .vscode/launch.json | 24 | .vscode/launch.json |
| 25 | +/speech.wav |
asr/doubao/README.md
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | + | ||
| 3 | +# 豆包ASR语音识别服务 | ||
| 4 | + | ||
| 5 | +基于豆包(Doubao)语音识别API的通用ASR服务,支持流式和非流式语音识别,提供简洁易用的Python接口。 | ||
| 6 | + | ||
| 7 | +## 🚀 特性 | ||
| 8 | + | ||
| 9 | +- **多种识别模式**: 支持流式和非流式语音识别 | ||
| 10 | +- **多格式支持**: 支持WAV、MP3、PCM等音频格式 | ||
| 11 | +- **灵活配置**: 支持配置文件、环境变量、代码配置等多种方式 | ||
| 12 | +- **异步支持**: 基于asyncio的异步API,支持高并发 | ||
| 13 | +- **实时回调**: 流式识别支持实时结果回调 | ||
| 14 | +- **错误处理**: 完善的错误处理和重试机制 | ||
| 15 | +- **易于集成**: 简洁的API设计,易于集成到现有项目 | ||
| 16 | + | ||
| 17 | +## 📦 安装 | ||
| 18 | + | ||
| 19 | +### 依赖要求 | ||
| 20 | + | ||
| 21 | +```bash | ||
| 22 | +pip install websockets aiofiles | ||
| 23 | +``` | ||
| 24 | + | ||
| 25 | +### 项目结构 | ||
| 26 | + | ||
| 27 | +``` | ||
| 28 | +asr/doubao/ | ||
| 29 | +├── __init__.py # 模块初始化和公共API | ||
| 30 | +├── config.json # 默认配置文件 | ||
| 31 | +├── config_manager.py # 配置管理器 | ||
| 32 | +├── protocol.py # 豆包协议处理 | ||
| 33 | +├── audio_utils.py # 音频处理工具 | ||
| 34 | +├── asr_client.py # ASR客户端核心 | ||
| 35 | +├── service_factory.py # 服务工厂和便捷接口 | ||
| 36 | +├── example.py # 使用示例 | ||
| 37 | +└── README.md # 项目文档 | ||
| 38 | +``` | ||
| 39 | + | ||
| 40 | +## 🔧 配置 | ||
| 41 | + | ||
| 42 | +### 1. 获取API密钥 | ||
| 43 | + | ||
| 44 | +参考[豆包语音识别API文档](https://www.volcengine.com/docs/6561/1354869),在[豆包开放平台](https://www.volcengine.com/)获取: | ||
| 45 | +- `app_key`: 应用密钥 | ||
| 46 | +- `access_key`: 访问密钥 | ||
| 47 | + | ||
| 48 | +### 2. 配置方式 | ||
| 49 | + | ||
| 50 | +#### 方式1: 环境变量(推荐) | ||
| 51 | + | ||
| 52 | +```bash | ||
| 53 | +export DOUBAO_APP_KEY="your_app_key" | ||
| 54 | +export DOUBAO_ACCESS_KEY="your_access_key" | ||
| 55 | +``` | ||
| 56 | + | ||
| 57 | +#### 方式2: 配置文件 | ||
| 58 | + | ||
| 59 | +创建 `config.json`: | ||
| 60 | + | ||
| 61 | +```json | ||
| 62 | +{ | ||
| 63 | + "auth_config": { | ||
| 64 | + "app_key": "your_app_key", | ||
| 65 | + "access_key": "your_access_key" | ||
| 66 | + }, | ||
| 67 | + "asr_config": { | ||
| 68 | + "streaming_mode": true, | ||
| 69 | + "enable_punc": true, | ||
| 70 | + "seg_duration": 200 | ||
| 71 | + } | ||
| 72 | +} | ||
| 73 | +``` | ||
| 74 | + | ||
| 75 | +#### 方式3: 代码配置 | ||
| 76 | + | ||
| 77 | +```python | ||
| 78 | +service = create_asr_service( | ||
| 79 | + app_key="your_app_key", | ||
| 80 | + access_key="your_access_key" | ||
| 81 | +) | ||
| 82 | +``` | ||
| 83 | + | ||
| 84 | +## 🎯 快速开始 | ||
| 85 | + | ||
| 86 | +### 1. 简单文件识别 | ||
| 87 | + | ||
| 88 | +```python | ||
| 89 | +import asyncio | ||
| 90 | +from asr.doubao import recognize_file | ||
| 91 | + | ||
| 92 | +async def simple_recognition(): | ||
| 93 | + result = await recognize_file( | ||
| 94 | + audio_path="path/to/your/audio.wav", | ||
| 95 | + app_key="your_app_key", | ||
| 96 | + access_key="your_access_key", | ||
| 97 | + streaming=True | ||
| 98 | + ) | ||
| 99 | + print(f"识别结果: {result}") | ||
| 100 | + | ||
| 101 | +# 运行 | ||
| 102 | +asyncio.run(simple_recognition()) | ||
| 103 | +``` | ||
| 104 | + | ||
| 105 | +### 2. 流式识别(带实时回调) | ||
| 106 | + | ||
| 107 | +```python | ||
| 108 | +import asyncio | ||
| 109 | +from asr.doubao import create_asr_service | ||
| 110 | + | ||
| 111 | +async def streaming_recognition(): | ||
| 112 | + # 定义结果回调函数 | ||
| 113 | + def on_result(result): | ||
| 114 | + if result.get('payload_msg'): | ||
| 115 | + print(f"实时结果: {result['payload_msg']}") | ||
| 116 | + | ||
| 117 | + # 创建服务实例 | ||
| 118 | + service = create_asr_service( | ||
| 119 | + app_key="your_app_key", | ||
| 120 | + access_key="your_access_key", | ||
| 121 | + streaming=True | ||
| 122 | + ) | ||
| 123 | + | ||
| 124 | + try: | ||
| 125 | + result = await service.recognize_file( | ||
| 126 | + "path/to/your/audio.wav", | ||
| 127 | + result_callback=on_result | ||
| 128 | + ) | ||
| 129 | + print(f"最终结果: {result}") | ||
| 130 | + finally: | ||
| 131 | + await service.close() | ||
| 132 | + | ||
| 133 | +# 运行 | ||
| 134 | +asyncio.run(streaming_recognition()) | ||
| 135 | +``` | ||
| 136 | + | ||
| 137 | +### 2.1 优化的文本输出(推荐) | ||
| 138 | + | ||
| 139 | +针对豆包ASR输出完整报文但只需要文本的问题,提供了专门的结果处理器: | ||
| 140 | + | ||
| 141 | +```python | ||
| 142 | +import asyncio | ||
| 143 | +from asr.doubao import create_asr_service | ||
| 144 | +from asr.doubao.result_processor import create_text_only_callback | ||
| 145 | + | ||
| 146 | +async def optimized_streaming(): | ||
| 147 | + # 只处理文本内容的回调函数 | ||
| 148 | + def on_text(text: str): | ||
| 149 | + print(f"识别文本: {text}") | ||
| 150 | + # 流式数据特点:后一次覆盖前一次,最终结果会不停刷新 | ||
| 151 | + | ||
| 152 | + # 创建优化的回调(自动提取text字段) | ||
| 153 | + optimized_callback = create_text_only_callback( | ||
| 154 | + user_callback=on_text, | ||
| 155 | + enable_streaming_log=False # 关闭中间结果日志 | ||
| 156 | + ) | ||
| 157 | + | ||
| 158 | + service = create_asr_service( | ||
| 159 | + app_key="your_app_key", | ||
| 160 | + access_key="your_access_key", | ||
| 161 | + streaming=True | ||
| 162 | + ) | ||
| 163 | + | ||
| 164 | + try: | ||
| 165 | + await service.recognize_file( | ||
| 166 | + "path/to/your/audio.wav", | ||
| 167 | + result_callback=optimized_callback | ||
| 168 | + ) | ||
| 169 | + finally: | ||
| 170 | + await service.close() | ||
| 171 | + | ||
| 172 | +# 运行 | ||
| 173 | +asyncio.run(optimized_streaming()) | ||
| 174 | +``` | ||
| 175 | + | ||
| 176 | +### 3. 音频数据识别 | ||
| 177 | + | ||
| 178 | +```python | ||
| 179 | +import asyncio | ||
| 180 | +from asr.doubao import recognize_audio_data | ||
| 181 | + | ||
| 182 | +async def data_recognition(): | ||
| 183 | + # 读取音频数据 | ||
| 184 | + with open("path/to/your/audio.wav", "rb") as f: | ||
| 185 | + audio_data = f.read() | ||
| 186 | + | ||
| 187 | + result = await recognize_audio_data( | ||
| 188 | + audio_data=audio_data, | ||
| 189 | + audio_format="wav", | ||
| 190 | + app_key="your_app_key", | ||
| 191 | + access_key="your_access_key" | ||
| 192 | + ) | ||
| 193 | + print(f"识别结果: {result}") | ||
| 194 | + | ||
| 195 | +# 运行 | ||
| 196 | +asyncio.run(data_recognition()) | ||
| 197 | +``` | ||
| 198 | + | ||
| 199 | +### 4. 同步方式(简单场景) | ||
| 200 | + | ||
| 201 | +```python | ||
| 202 | +from asr.doubao import run_recognition | ||
| 203 | + | ||
| 204 | +# 同步识别 | ||
| 205 | +result = run_recognition( | ||
| 206 | + audio_path="path/to/your/audio.wav", | ||
| 207 | + app_key="your_app_key", | ||
| 208 | + access_key="your_access_key" | ||
| 209 | +) | ||
| 210 | +print(f"识别结果: {result}") | ||
| 211 | +``` | ||
| 212 | + | ||
| 213 | +## 📚 详细用法 | ||
| 214 | + | ||
| 215 | +### 服务实例管理 | ||
| 216 | + | ||
| 217 | +```python | ||
| 218 | +from asr.doubao import DoubaoASRService, create_asr_service | ||
| 219 | + | ||
| 220 | +# 创建服务实例 | ||
| 221 | +service = create_asr_service( | ||
| 222 | + app_key="your_app_key", | ||
| 223 | + access_key="your_access_key", | ||
| 224 | + streaming=True, | ||
| 225 | + debug=True | ||
| 226 | +) | ||
| 227 | + | ||
| 228 | +# 执行多次识别(复用连接) | ||
| 229 | +result1 = await service.recognize_file("audio1.wav") | ||
| 230 | +result2 = await service.recognize_file("audio2.wav") | ||
| 231 | + | ||
| 232 | +# 关闭服务 | ||
| 233 | +await service.close() | ||
| 234 | +``` | ||
| 235 | + | ||
| 236 | +### 批量识别 | ||
| 237 | + | ||
| 238 | +```python | ||
| 239 | +async def batch_recognition(audio_files): | ||
| 240 | + service = create_asr_service( | ||
| 241 | + app_key="your_app_key", | ||
| 242 | + access_key="your_access_key" | ||
| 243 | + ) | ||
| 244 | + | ||
| 245 | + results = [] | ||
| 246 | + try: | ||
| 247 | + for audio_file in audio_files: | ||
| 248 | + result = await service.recognize_file(audio_file) | ||
| 249 | + results.append({ | ||
| 250 | + 'file': audio_file, | ||
| 251 | + 'result': result | ||
| 252 | + }) | ||
| 253 | + finally: | ||
| 254 | + await service.close() | ||
| 255 | + | ||
| 256 | + return results | ||
| 257 | + | ||
| 258 | +# 使用 | ||
| 259 | +audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"] | ||
| 260 | +results = await batch_recognition(audio_files) | ||
| 261 | +``` | ||
| 262 | + | ||
| 263 | +### 自定义配置 | ||
| 264 | + | ||
| 265 | +```python | ||
| 266 | +custom_config = { | ||
| 267 | + 'asr_config': { | ||
| 268 | + 'enable_punc': True, | ||
| 269 | + 'seg_duration': 300, # 自定义分段时长 | ||
| 270 | + 'streaming_mode': True | ||
| 271 | + }, | ||
| 272 | + 'connection_config': { | ||
| 273 | + 'timeout': 60, # 自定义超时时间 | ||
| 274 | + 'retry_times': 5 | ||
| 275 | + }, | ||
| 276 | + 'logging_config': { | ||
| 277 | + 'enable_debug': True | ||
| 278 | + } | ||
| 279 | +} | ||
| 280 | + | ||
| 281 | +service = create_asr_service( | ||
| 282 | + app_key="your_app_key", | ||
| 283 | + access_key="your_access_key", | ||
| 284 | + custom_config=custom_config | ||
| 285 | +) | ||
| 286 | +``` | ||
| 287 | + | ||
| 288 | +### 配置文件使用 | ||
| 289 | + | ||
| 290 | +```python | ||
| 291 | +# 使用配置文件 | ||
| 292 | +result = await recognize_file( | ||
| 293 | + audio_path="audio.wav", | ||
| 294 | + config_path="asr/doubao/config.json" | ||
| 295 | +) | ||
| 296 | + | ||
| 297 | +# 或者 | ||
| 298 | +service = create_asr_service(config_path="asr/doubao/config.json") | ||
| 299 | +``` | ||
| 300 | + | ||
| 301 | +## 🔧 配置参数 | ||
| 302 | + | ||
| 303 | +### ASR配置 (asr_config) | ||
| 304 | + | ||
| 305 | +| 参数 | 类型 | 默认值 | 说明 | | ||
| 306 | +|------|------|--------|------| | ||
| 307 | +| `ws_url` | str | wss://openspeech.bytedance.com/api/v3/sauc/bigmodel | 流式识别WebSocket URL | | ||
| 308 | +| `ws_url_nostream` | str | wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream | 非流式识别WebSocket URL | | ||
| 309 | +| `resource_id` | str | volc.bigasr.sauc.duration | 资源ID | | ||
| 310 | +| `model_name` | str | bigmodel | 模型名称 | | ||
| 311 | +| `enable_punc` | bool | true | 是否启用标点符号 | | ||
| 312 | +| `streaming_mode` | bool | true | 是否启用流式模式 | | ||
| 313 | +| `seg_duration` | int | 200 | 音频分段时长(ms) | | ||
| 314 | +| `mp3_seg_size` | int | 1000 | MP3分段大小(bytes) | | ||
| 315 | + | ||
| 316 | +### 认证配置 (auth_config) | ||
| 317 | + | ||
| 318 | +| 参数 | 类型 | 说明 | | ||
| 319 | +|------|------|------| | ||
| 320 | +| `app_key` | str | 应用密钥 | | ||
| 321 | +| `access_key` | str | 访问密钥 | | ||
| 322 | + | ||
| 323 | +### 音频配置 (audio_config) | ||
| 324 | + | ||
| 325 | +| 参数 | 类型 | 默认值 | 说明 | | ||
| 326 | +|------|------|--------|------| | ||
| 327 | +| `default_format` | str | wav | 默认音频格式 | | ||
| 328 | +| `default_rate` | int | 16000 | 默认采样率 | | ||
| 329 | +| `default_bits` | int | 16 | 默认位深度 | | ||
| 330 | +| `default_channel` | int | 1 | 默认声道数 | | ||
| 331 | +| `supported_formats` | list | ["wav", "mp3", "pcm"] | 支持的音频格式 | | ||
| 332 | + | ||
| 333 | +### 连接配置 (connection_config) | ||
| 334 | + | ||
| 335 | +| 参数 | 类型 | 默认值 | 说明 | | ||
| 336 | +|------|------|--------|------| | ||
| 337 | +| `max_size` | int | 1000000000 | 最大消息大小 | | ||
| 338 | +| `timeout` | int | 30 | 连接超时时间(秒) | | ||
| 339 | +| `retry_times` | int | 3 | 重试次数 | | ||
| 340 | +| `retry_delay` | int | 1 | 重试延迟(秒) | | ||
| 341 | + | ||
| 342 | +## 🎵 支持的音频格式 | ||
| 343 | + | ||
| 344 | +| 格式 | 扩展名 | 说明 | | ||
| 345 | +|------|--------|------| | ||
| 346 | +| WAV | .wav | 无损音频格式,推荐使用 | | ||
| 347 | +| MP3 | .mp3 | 压缩音频格式 | | ||
| 348 | +| PCM | .pcm | 原始音频数据 | | ||
| 349 | + | ||
| 350 | +### 音频要求 | ||
| 351 | + | ||
| 352 | +- **采样率**: 推荐16kHz,支持8kHz、16kHz、24kHz、48kHz | ||
| 353 | +- **位深度**: 推荐16bit | ||
| 354 | +- **声道**: 推荐单声道(mono) | ||
| 355 | +- **编码**: PCM编码 | ||
| 356 | + | ||
| 357 | +## 🔍 API参考 | ||
| 358 | + | ||
| 359 | +### 便捷函数 | ||
| 360 | + | ||
| 361 | +#### `recognize_file()` | ||
| 362 | + | ||
| 363 | +```python | ||
| 364 | +async def recognize_file( | ||
| 365 | + audio_path: str, | ||
| 366 | + app_key: str = None, | ||
| 367 | + access_key: str = None, | ||
| 368 | + config_path: str = None, | ||
| 369 | + streaming: bool = True, | ||
| 370 | + result_callback: callable = None, | ||
| 371 | + **kwargs | ||
| 372 | +) -> dict: | ||
| 373 | +``` | ||
| 374 | + | ||
| 375 | +识别音频文件。 | ||
| 376 | + | ||
| 377 | +**参数:** | ||
| 378 | +- `audio_path`: 音频文件路径 | ||
| 379 | +- `app_key`: 应用密钥 | ||
| 380 | +- `access_key`: 访问密钥 | ||
| 381 | +- `config_path`: 配置文件路径 | ||
| 382 | +- `streaming`: 是否使用流式识别 | ||
| 383 | +- `result_callback`: 结果回调函数 | ||
| 384 | +- `**kwargs`: 其他配置参数 | ||
| 385 | + | ||
| 386 | +**返回:** 识别结果字典 | ||
| 387 | + | ||
| 388 | +#### `recognize_audio_data()` | ||
| 389 | + | ||
| 390 | +```python | ||
| 391 | +async def recognize_audio_data( | ||
| 392 | + audio_data: bytes, | ||
| 393 | + audio_format: str, | ||
| 394 | + app_key: str = None, | ||
| 395 | + access_key: str = None, | ||
| 396 | + config_path: str = None, | ||
| 397 | + streaming: bool = True, | ||
| 398 | + result_callback: callable = None, | ||
| 399 | + **kwargs | ||
| 400 | +) -> dict: | ||
| 401 | +``` | ||
| 402 | + | ||
| 403 | +识别音频数据。 | ||
| 404 | + | ||
| 405 | +**参数:** | ||
| 406 | +- `audio_data`: 音频数据(bytes) | ||
| 407 | +- `audio_format`: 音频格式("wav", "mp3", "pcm") | ||
| 408 | +- 其他参数同`recognize_file()` | ||
| 409 | + | ||
| 410 | +#### `run_recognition()` | ||
| 411 | + | ||
| 412 | +```python | ||
| 413 | +def run_recognition( | ||
| 414 | + audio_path: str = None, | ||
| 415 | + audio_data: bytes = None, | ||
| 416 | + audio_format: str = None, | ||
| 417 | + **kwargs | ||
| 418 | +) -> dict: | ||
| 419 | +``` | ||
| 420 | + | ||
| 421 | +同步方式执行识别。 | ||
| 422 | + | ||
| 423 | +#### `create_asr_service()` | ||
| 424 | + | ||
| 425 | +```python | ||
| 426 | +def create_asr_service( | ||
| 427 | + app_key: str = None, | ||
| 428 | + access_key: str = None, | ||
| 429 | + config_path: str = None, | ||
| 430 | + custom_config: dict = None, | ||
| 431 | + **kwargs | ||
| 432 | +) -> DoubaoASRService: | ||
| 433 | +``` | ||
| 434 | + | ||
| 435 | +创建ASR服务实例。 | ||
| 436 | + | ||
| 437 | +### 核心类 | ||
| 438 | + | ||
| 439 | +#### `DoubaoASRService` | ||
| 440 | + | ||
| 441 | +主要服务类,提供高级API。 | ||
| 442 | + | ||
| 443 | +**方法:** | ||
| 444 | +- `recognize_file(audio_path, result_callback=None)`: 识别文件 | ||
| 445 | +- `recognize_audio_data(audio_data, audio_format, result_callback=None)`: 识别音频数据 | ||
| 446 | +- `get_status()`: 获取服务状态 | ||
| 447 | +- `close()`: 关闭服务 | ||
| 448 | + | ||
| 449 | +#### `DoubaoASRClient` | ||
| 450 | + | ||
| 451 | +底层客户端类,处理WebSocket通信。 | ||
| 452 | + | ||
| 453 | +#### `ConfigManager` | ||
| 454 | + | ||
| 455 | +配置管理器,处理配置加载、验证、合并等。 | ||
| 456 | + | ||
| 457 | +**方法:** | ||
| 458 | +- `load_config(config_path)`: 加载配置文件 | ||
| 459 | +- `save_config(config, config_path)`: 保存配置文件 | ||
| 460 | +- `validate_config(config)`: 验证配置 | ||
| 461 | +- `merge_configs(base_config, override_config)`: 合并配置 | ||
| 462 | +- `create_default_config()`: 创建默认配置 | ||
| 463 | + | ||
| 464 | +#### `DoubaoResultProcessor` | ||
| 465 | + | ||
| 466 | +结果处理器,专门处理豆包ASR流式识别结果,解决输出完整报文但只需要文本的问题。 | ||
| 467 | + | ||
| 468 | +**特性:** | ||
| 469 | +- 自动提取`payload_msg.result.text`字段 | ||
| 470 | +- 处理流式数据覆盖更新特性 | ||
| 471 | +- 可配置日志输出级别 | ||
| 472 | +- 支持自定义文本回调函数 | ||
| 473 | + | ||
| 474 | +**方法:** | ||
| 475 | +- `extract_text_from_result(result)`: 从完整结果中提取文本 | ||
| 476 | +- `process_streaming_result(result)`: 处理流式结果 | ||
| 477 | +- `create_optimized_callback(user_callback)`: 创建优化的回调函数 | ||
| 478 | +- `get_current_result()`: 获取当前识别状态 | ||
| 479 | +- `reset()`: 重置处理器状态 | ||
| 480 | + | ||
| 481 | +**便捷函数:** | ||
| 482 | +- `create_text_only_callback(user_callback, enable_streaming_log)`: 创建只处理文本的回调 | ||
| 483 | +- `extract_text_only(result)`: 快速提取文本内容 | ||
| 484 | + | ||
| 485 | +**使用示例:** | ||
| 486 | +```python | ||
| 487 | +from asr.doubao.result_processor import DoubaoResultProcessor, create_text_only_callback | ||
| 488 | + | ||
| 489 | +# 方式1: 使用处理器类 | ||
| 490 | +processor = DoubaoResultProcessor(text_only=True, enable_streaming_log=False) | ||
| 491 | +callback = processor.create_optimized_callback(lambda text: print(f"文本: {text}")) | ||
| 492 | + | ||
| 493 | +# 方式2: 使用便捷函数 | ||
| 494 | +callback = create_text_only_callback( | ||
| 495 | + user_callback=lambda text: print(f"文本: {text}"), | ||
| 496 | + enable_streaming_log=False | ||
| 497 | +) | ||
| 498 | + | ||
| 499 | +# 在ASR服务中使用 | ||
| 500 | +service = create_asr_service(...) | ||
| 501 | +await service.recognize_file("audio.wav", result_callback=callback) | ||
| 502 | +``` | ||
| 503 | + | ||
| 504 | +## 🧪 测试 | ||
| 505 | + | ||
| 506 | +运行测试套件: | ||
| 507 | + | ||
| 508 | +```bash | ||
| 509 | +python -m pytest test/test_doubao_asr.py -v | ||
| 510 | +``` | ||
| 511 | + | ||
| 512 | +或者直接运行测试文件: | ||
| 513 | + | ||
| 514 | +```bash | ||
| 515 | +python test/test_doubao_asr.py | ||
| 516 | +``` | ||
| 517 | + | ||
| 518 | +测试包括: | ||
| 519 | +- 单元测试 | ||
| 520 | +- 集成测试 | ||
| 521 | +- 性能测试 | ||
| 522 | +- 错误处理测试 | ||
| 523 | + | ||
| 524 | +## 🔧 故障排除 | ||
| 525 | + | ||
| 526 | +### 常见问题 | ||
| 527 | + | ||
| 528 | +#### 1. 认证失败 | ||
| 529 | + | ||
| 530 | +``` | ||
| 531 | +AuthenticationError: Invalid app_key or access_key | ||
| 532 | +``` | ||
| 533 | + | ||
| 534 | +**解决方案:** | ||
| 535 | +- 检查app_key和access_key是否正确 | ||
| 536 | +- 确认密钥是否已激活 | ||
| 537 | +- 检查网络连接 | ||
| 538 | + | ||
| 539 | +#### 2. 音频格式不支持 | ||
| 540 | + | ||
| 541 | +``` | ||
| 542 | +AudioFormatError: Unsupported audio format | ||
| 543 | +``` | ||
| 544 | + | ||
| 545 | +**解决方案:** | ||
| 546 | +- 确认音频格式为WAV、MP3或PCM | ||
| 547 | +- 检查音频文件是否损坏 | ||
| 548 | +- 转换音频格式到支持的格式 | ||
| 549 | + | ||
| 550 | +#### 3. 连接超时 | ||
| 551 | + | ||
| 552 | +``` | ||
| 553 | +ConnectionTimeoutError: Connection timeout | ||
| 554 | +``` | ||
| 555 | + | ||
| 556 | +**解决方案:** | ||
| 557 | +- 检查网络连接 | ||
| 558 | +- 增加timeout配置 | ||
| 559 | +- 检查防火墙设置 | ||
| 560 | + | ||
| 561 | +#### 4. 音频文件过大 | ||
| 562 | + | ||
| 563 | +``` | ||
| 564 | +FileSizeError: Audio file too large | ||
| 565 | +``` | ||
| 566 | + | ||
| 567 | +**解决方案:** | ||
| 568 | +- 分割音频文件 | ||
| 569 | +- 压缩音频质量 | ||
| 570 | +- 使用流式识别 | ||
| 571 | + | ||
| 572 | +### 调试模式 | ||
| 573 | + | ||
| 574 | +启用调试模式获取详细日志: | ||
| 575 | + | ||
| 576 | +```python | ||
| 577 | +service = create_asr_service( | ||
| 578 | + app_key="your_app_key", | ||
| 579 | + access_key="your_access_key", | ||
| 580 | + debug=True | ||
| 581 | +) | ||
| 582 | +``` | ||
| 583 | + | ||
| 584 | +或在配置文件中设置: | ||
| 585 | + | ||
| 586 | +```json | ||
| 587 | +{ | ||
| 588 | + "logging_config": { | ||
| 589 | + "enable_debug": true, | ||
| 590 | + "log_requests": true, | ||
| 591 | + "log_responses": true | ||
| 592 | + } | ||
| 593 | +} | ||
| 594 | +``` | ||
| 595 | + | ||
| 596 | +## 📈 性能优化 | ||
| 597 | + | ||
| 598 | +### 1. 连接复用 | ||
| 599 | + | ||
| 600 | +对于批量识别,使用服务实例复用连接: | ||
| 601 | + | ||
| 602 | +```python | ||
| 603 | +service = create_asr_service(...) | ||
| 604 | +try: | ||
| 605 | + for audio_file in audio_files: | ||
| 606 | + result = await service.recognize_file(audio_file) | ||
| 607 | +finally: | ||
| 608 | + await service.close() | ||
| 609 | +``` | ||
| 610 | + | ||
| 611 | +### 2. 并发处理 | ||
| 612 | + | ||
| 613 | +使用asyncio进行并发识别: | ||
| 614 | + | ||
| 615 | +```python | ||
| 616 | +import asyncio | ||
| 617 | + | ||
| 618 | +async def concurrent_recognition(audio_files): | ||
| 619 | + tasks = [] | ||
| 620 | + for audio_file in audio_files: | ||
| 621 | + task = recognize_file(audio_file, ...) | ||
| 622 | + tasks.append(task) | ||
| 623 | + | ||
| 624 | + results = await asyncio.gather(*tasks) | ||
| 625 | + return results | ||
| 626 | +``` | ||
| 627 | + | ||
| 628 | +### 3. 音频预处理 | ||
| 629 | + | ||
| 630 | +- 使用合适的音频格式和参数 | ||
| 631 | +- 预先分割大文件 | ||
| 632 | +- 去除静音段 | ||
| 633 | + | ||
| 634 | +## 🤝 贡献 | ||
| 635 | + | ||
| 636 | +欢迎提交Issue和Pull Request! | ||
| 637 | + | ||
| 638 | +### 开发环境设置 | ||
| 639 | + | ||
| 640 | +1. 克隆项目 | ||
| 641 | +2. 安装依赖:`pip install -r requirements.txt` | ||
| 642 | +3. 运行测试:`python -m pytest` | ||
| 643 | +4. 提交代码前请确保测试通过 | ||
| 644 | + | ||
| 645 | +## 📄 许可证 | ||
| 646 | + | ||
| 647 | +MIT License | ||
| 648 | + | ||
| 649 | +## 🔗 相关链接 | ||
| 650 | + | ||
| 651 | +- [豆包语音识别API文档](https://www.volcengine.com/docs/6561/1354869) | ||
| 652 | +- [豆包开放平台](https://www.volcengine.com/) | ||
| 653 | +- [WebSocket协议](https://tools.ietf.org/html/rfc6455) | ||
| 654 | + | ||
| 655 | +## 📞 支持 | ||
| 656 | + | ||
| 657 | +如有问题,请: | ||
| 658 | + | ||
| 659 | +1. 查看本文档的故障排除部分 | ||
| 660 | +2. 搜索已有的Issue | ||
| 661 | +3. 创建新的Issue并提供详细信息 | ||
| 662 | +4. 联系技术支持 | ||
| 663 | + | ||
| 664 | +--- | ||
| 665 | + | ||
| 666 | +**作者**: AIfeng | ||
| 667 | +**版本**: 1.0.0 | ||
| 668 | +**更新时间**: 2025-07-11 |
asr/doubao/__init__.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR语音识别服务模块 | ||
| 4 | +提供完整的语音识别功能,支持流式和非流式识别 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +__version__ = "1.0.0" | ||
| 8 | +__author__ = "AIfeng" | ||
| 9 | +__description__ = "豆包ASR语音识别服务模块" | ||
| 10 | + | ||
| 11 | +# 导入核心类和函数 | ||
| 12 | +from .asr_client import DoubaoASRClient | ||
| 13 | +from .config_manager import ConfigManager | ||
| 14 | +from .service_factory import ( | ||
| 15 | + DoubaoASRService, | ||
| 16 | + create_asr_service, | ||
| 17 | + recognize_file, | ||
| 18 | + recognize_audio_data, | ||
| 19 | + run_recognition | ||
| 20 | +) | ||
| 21 | +from .protocol import DoubaoProtocol, MessageType, MessageFlags, SerializationMethod, CompressionType | ||
| 22 | +from .audio_utils import AudioProcessor | ||
| 23 | +from .result_processor import ( | ||
| 24 | + DoubaoResultProcessor, | ||
| 25 | + ASRResult, | ||
| 26 | + create_text_only_callback, | ||
| 27 | + extract_text_only | ||
| 28 | +) | ||
| 29 | + | ||
| 30 | +# 公共API | ||
| 31 | +__all__ = [ | ||
| 32 | + # 核心类 | ||
| 33 | + 'DoubaoASRClient', | ||
| 34 | + 'DoubaoASRService', | ||
| 35 | + 'ConfigManager', | ||
| 36 | + 'DoubaoProtocol', | ||
| 37 | + 'AudioProcessor', | ||
| 38 | + 'DoubaoResultProcessor', | ||
| 39 | + 'ASRResult', | ||
| 40 | + | ||
| 41 | + # 便捷函数 | ||
| 42 | + 'create_asr_service', | ||
| 43 | + 'recognize_file', | ||
| 44 | + 'recognize_audio_data', | ||
| 45 | + 'run_recognition', | ||
| 46 | + 'create_text_only_callback', | ||
| 47 | + 'extract_text_only', | ||
| 48 | + | ||
| 49 | + # 协议常量 | ||
| 50 | + 'MessageType', | ||
| 51 | + 'MessageFlags', | ||
| 52 | + 'SerializationMethod', | ||
| 53 | + 'CompressionType', | ||
| 54 | + | ||
| 55 | + # 版本信息 | ||
| 56 | + '__version__', | ||
| 57 | + '__author__', | ||
| 58 | + '__description__' | ||
| 59 | +] | ||
| 60 | + | ||
| 61 | + | ||
| 62 | +# 快速开始示例 | ||
def get_quick_start_example() -> str:
    """
    Return ready-to-copy quick-start sample code for this package.

    The sample covers four usages: the async convenience function, a service
    instance with a result callback, config-file based recognition, and the
    synchronous wrapper.

    Returns:
        str: The sample source code (starts and ends with a newline).
    """
    example_source = '''
# 豆包ASR快速开始示例

import asyncio
from asr.doubao import recognize_file, create_asr_service

# 方式1: 使用便捷函数(推荐用于简单场景)
async def simple_recognition():
    result = await recognize_file(
        audio_path="path/to/your/audio.wav",
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True
    )
    print(result)

# 方式2: 使用服务实例(推荐用于复杂场景)
async def advanced_recognition():
    # 创建服务实例
    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True,
        debug=True
    )

    # 定义结果回调函数
    def on_result(result):
        if result.get('payload_msg'):
            print(f"实时结果: {result['payload_msg']}")

    try:
        # 执行识别
        result = await service.recognize_file(
            "path/to/your/audio.wav",
            result_callback=on_result
        )
        print(f"最终结果: {result}")
    finally:
        await service.close()

# 方式3: 使用配置文件
async def config_based_recognition():
    result = await recognize_file(
        audio_path="path/to/your/audio.wav",
        config_path="path/to/config.json"
    )
    print(result)

# 同步方式(简单场景)
def sync_recognition():
    from asr.doubao import run_recognition

    result = run_recognition(
        audio_path="path/to/your/audio.wav",
        app_key="your_app_key",
        access_key="your_access_key"
    )
    print(result)

# 运行示例
if __name__ == "__main__":
    # 选择一种方式运行
    asyncio.run(simple_recognition())
    # asyncio.run(advanced_recognition())
    # asyncio.run(config_based_recognition())
    # sync_recognition()
'''
    return example_source
| 138 | + | ||
| 139 | + | ||
def get_config_template() -> str:
    """
    Return a JSON configuration template for the ASR service.

    The template has five sections: ``asr_config``, ``auth_config``,
    ``audio_config``, ``connection_config`` and ``logging_config``. The
    credentials in ``auth_config`` are placeholders to be replaced.

    Returns:
        str: The JSON template text (starts and ends with a newline).
    """
    template_text = '''
{
    "asr_config": {
        "ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        "ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
        "resource_id": "volc.bigasr.sauc.duration",
        "model_name": "bigmodel",
        "enable_punc": true,
        "streaming_mode": true,
        "seg_duration": 200,
        "mp3_seg_size": 1000
    },
    "auth_config": {
        "app_key": "your_app_key_here",
        "access_key": "your_access_key_here"
    },
    "audio_config": {
        "default_format": "wav",
        "default_rate": 16000,
        "default_bits": 16,
        "default_channel": 1,
        "default_codec": "raw",
        "supported_formats": ["wav", "mp3", "pcm"]
    },
    "connection_config": {
        "max_size": 1000000000,
        "timeout": 30,
        "retry_times": 3,
        "retry_delay": 1
    },
    "logging_config": {
        "enable_debug": false,
        "log_requests": true,
        "log_responses": true
    }
}
'''
    return template_text
| 184 | + | ||
| 185 | + | ||
def print_info():
    """
    Print the module banner to stdout.

    The output lists the version, author and description (taken from the
    module-level ``__version__``, ``__author__`` and ``__description__``),
    followed by the supported features and a two-line quick-start hint.
    """
    banner_lines = (
        f"豆包ASR语音识别服务模块 v{__version__}",
        f"作者: {__author__}",
        f"描述: {__description__}",
        # A leading "\n" produces the blank line that separates each section.
        "\n支持的功能:",
        "- 流式语音识别",
        "- 非流式语音识别",
        "- 多种音频格式支持 (WAV, MP3, PCM)",
        "- 灵活的配置管理",
        "- 异步和同步API",
        "- 实时结果回调",
        "\n快速开始:",
        "from asr.doubao import recognize_file",
        "result = await recognize_file('audio.wav', app_key='...', access_key='...')",
    )
    for line in banner_lines:
        print(line)
| 204 | + | ||
| 205 | +if __name__ == "__main__": | ||
| 206 | + print_info() |
asr/doubao/__main__.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 14:15:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR模块主入口 | ||
| 4 | +支持通过 python -m asr.doubao 运行示例 | ||
| 5 | +""" | ||
| 6 | + | ||
if __name__ == '__main__':
    import asyncio

    from .example import run_all_examples

    # Banner shown before the examples start.
    for banner_line in ("=== 豆包ASR语音识别服务示例 ===", "正在运行所有示例..."):
        print(banner_line)

    try:
        asyncio.run(run_all_examples())
    except KeyboardInterrupt:
        # Ctrl+C is reported as a normal stop.
        print("\n用户中断执行")
    except Exception as exc:
        # Any other failure is reported with a hint about the required setup
        # (credential environment variables and a test audio file).
        failure_messages = (
            f"执行失败: {exc}",
            "请确保已设置环境变量: DOUBAO_APP_KEY, DOUBAO_ACCESS_KEY",
            "并准备好测试音频文件",
        )
        for message in failure_messages:
            print(message)
asr/doubao/asr_client.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR客户端核心模块 | ||
| 4 | +提供完整的语音识别服务接口,支持流式和非流式识别 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import asyncio | ||
| 8 | +import json | ||
| 9 | +import logging | ||
| 10 | +import time | ||
| 11 | +import uuid | ||
| 12 | +from pathlib import Path | ||
| 13 | +from typing import Dict, Any, Optional, Callable, AsyncGenerator | ||
| 14 | + | ||
| 15 | +import aiofiles | ||
| 16 | +import websockets | ||
| 17 | +from websockets.exceptions import ConnectionClosedError, WebSocketException | ||
| 18 | + | ||
| 19 | +from .protocol import DoubaoProtocol, MessageType | ||
| 20 | +from .audio_utils import AudioProcessor | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +class DoubaoASRClient: | ||
| 24 | + """豆包ASR客户端""" | ||
| 25 | + | ||
| 26 | + def __init__(self, config: Dict[str, Any]): | ||
| 27 | + """ | ||
| 28 | + 初始化ASR客户端 | ||
| 29 | + | ||
| 30 | + Args: | ||
| 31 | + config: 配置字典 | ||
| 32 | + """ | ||
| 33 | + self.config = config | ||
| 34 | + self.asr_config = config.get('asr_config', {}) | ||
| 35 | + self.auth_config = config.get('auth_config', {}) | ||
| 36 | + self.audio_config = config.get('audio_config', {}) | ||
| 37 | + self.connection_config = config.get('connection_config', {}) | ||
| 38 | + self.logging_config = config.get('logging_config', {}) | ||
| 39 | + | ||
| 40 | + # 设置日志 | ||
| 41 | + self.logger = self._setup_logger() | ||
| 42 | + | ||
| 43 | + # 协议处理器 | ||
| 44 | + self.protocol = DoubaoProtocol() | ||
| 45 | + | ||
| 46 | + # 音频处理器 | ||
| 47 | + self.audio_processor = AudioProcessor() | ||
| 48 | + | ||
| 49 | + # 连接状态 | ||
| 50 | + self.is_connected = False | ||
| 51 | + self.current_session_id = None | ||
| 52 | + | ||
| 53 | + def _setup_logger(self) -> logging.Logger: | ||
| 54 | + """设置日志记录器""" | ||
| 55 | + logger = logging.getLogger('doubao_asr') | ||
| 56 | + if not logger.handlers: | ||
| 57 | + handler = logging.StreamHandler() | ||
| 58 | + formatter = logging.Formatter( | ||
| 59 | + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
| 60 | + ) | ||
| 61 | + handler.setFormatter(formatter) | ||
| 62 | + logger.addHandler(handler) | ||
| 63 | + | ||
| 64 | + if self.logging_config.get('enable_debug', False): | ||
| 65 | + logger.setLevel(logging.DEBUG) | ||
| 66 | + else: | ||
| 67 | + logger.setLevel(logging.INFO) | ||
| 68 | + | ||
| 69 | + return logger | ||
| 70 | + | ||
| 71 | + def _get_ws_url(self, streaming: bool = True) -> str: | ||
| 72 | + """获取WebSocket URL""" | ||
| 73 | + if streaming: | ||
| 74 | + return self.asr_config.get('ws_url', 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel') | ||
| 75 | + else: | ||
| 76 | + return self.asr_config.get('ws_url_nostream', 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream') | ||
| 77 | + | ||
def _build_auth_headers(self, request_id: str) -> Dict[str, str]:
    """Build the HTTP headers sent with the WebSocket handshake.

    Args:
        request_id: Value placed in the ``X-Api-Request-Id`` header.

    Returns:
        A dict with the resource id (from ``asr_config``), the access and
        app keys (from ``auth_config``, empty string when absent) and the
        request id.
    """
    resource_id = self.asr_config.get('resource_id', 'volc.bigasr.sauc.duration')
    access_key = self.auth_config.get('access_key', '')
    app_key = self.auth_config.get('app_key', '')

    return {
        'X-Api-Resource-Id': resource_id,
        'X-Api-Access-Key': access_key,
        'X-Api-App-Key': app_key,
        'X-Api-Request-Id': request_id,
    }
| 87 | + | ||
def _build_request_params(
    self,
    request_id: str,
    audio_format: str = 'wav',
    sample_rate: int = 16000,
    bits: int = 16,
    channels: int = 1,
    uid: str = 'default_user'
) -> Dict[str, Any]:
    """Assemble the request payload describing the user, audio and model.

    Args:
        request_id: Accepted for interface compatibility; not used in the
            payload built here.
        audio_format: Value for ``audio.format``.
        sample_rate: Value for ``audio.sample_rate``.
        bits: Value for ``audio.bits``.
        channels: Value for ``audio.channel``.
        uid: Value for ``user.uid``.

    Returns:
        A dict with ``user``, ``audio`` and ``request`` sections. The codec
        comes from ``audio_config['default_codec']`` and the model name and
        punctuation flag from ``asr_config``.
    """
    user_section = {'uid': uid}

    audio_section = {
        'format': audio_format,
        'sample_rate': sample_rate,
        'bits': bits,
        'channel': channels,
        'codec': self.audio_config.get('default_codec', 'raw'),
    }

    request_section = {
        'model_name': self.asr_config.get('model_name', 'bigmodel'),
        'enable_punc': self.asr_config.get('enable_punc', True),
    }

    return {
        'user': user_section,
        'audio': audio_section,
        'request': request_section,
    }
| 114 | + | ||
async def recognize_file(
    self,
    audio_path: str,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """Read an audio file and run recognition on its bytes.

    Args:
        audio_path: Path of the audio file to read.
        streaming: Forwarded to ``recognize_audio_data``.
        result_callback: Forwarded to ``recognize_audio_data``.
        **kwargs: Forwarded unchanged to ``recognize_audio_data``.

    Returns:
        The dict returned by ``recognize_audio_data``. If reading the file
        or the recognition call raises, a dict with ``success`` False, the
        error text and ``audio_path`` is returned instead of raising.
    """
    try:
        # Load the whole file into memory before handing it over.
        async with aiofiles.open(audio_path, mode='rb') as audio_file:
            payload = await audio_file.read()

        self.logger.info(f"开始识别音频文件: {audio_path}, 大小: {len(payload)} 字节")

        outcome = await self.recognize_audio_data(
            payload,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )
    except Exception as e:
        self.logger.error(f"识别音频文件失败: {e}")
        return {
            'success': False,
            'error': str(e),
            'audio_path': audio_path
        }

    return outcome
| 156 | + | ||
async def recognize_audio_data(
    self,
    audio_data: bytes,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """Recognize in-memory audio bytes.

    A fresh UUID is generated as the request id and stored in
    ``current_session_id``. The audio processor supplies the detected
    format, chunk size and metadata, from which the request parameters are
    built before dispatching to the streaming or non-streaming path.

    Args:
        audio_data: Raw bytes of the audio to recognize.
        streaming: Use the streaming path when true, otherwise the
            non-streaming path.
        result_callback: Passed to the streaming path only.
        **kwargs: Only ``uid`` is read (default ``'default_user'``).

    Returns:
        The dict produced by the selected recognition path, or a dict with
        ``success`` False, the error text and ``request_id`` if anything
        raises.
    """
    session_id = str(uuid.uuid4())
    self.current_session_id = session_id

    try:
        detected_format, chunk_bytes, audio_meta = self.audio_processor.prepare_audio_for_recognition(
            audio_data,
            segment_duration_ms=self.asr_config.get('seg_duration', 200)
        )

        self.logger.info(f"音频格式: {detected_format}, 分片大小: {chunk_bytes}, 元数据: {audio_meta}")

        # Metadata stores the sample width in bytes; the request wants bits.
        bit_depth = audio_meta.get('sample_width', 2) * 8
        params = self._build_request_params(
            session_id,
            audio_format=detected_format,
            sample_rate=audio_meta.get('sample_rate', 16000),
            bits=bit_depth,
            channels=audio_meta.get('channels', 1),
            uid=kwargs.get('uid', 'default_user')
        )

        if not streaming:
            return await self._non_streaming_recognize(
                audio_data,
                params,
                session_id
            )

        return await self._streaming_recognize(
            audio_data,
            params,
            chunk_bytes,
            session_id,
            result_callback
        )

    except Exception as e:
        self.logger.error(f"识别音频数据失败: {e}")
        return {
            'success': False,
            'error': str(e),
            'request_id': session_id
        }
| 221 | + | ||
async def _streaming_recognize(
    self,
    audio_data: bytes,
    request_params: Dict[str, Any],
    segment_size: int,
    request_id: str,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None
) -> Dict[str, Any]:
    """Open a WebSocket, run one streaming recognition session and summarize it.

    Args:
        audio_data: Audio bytes to send in chunks.
        request_params: Payload for the initial request.
        segment_size: Chunk size in bytes passed to the connection handler.
        request_id: Id used for the auth headers and echoed in the result.
        result_callback: Optional callable passed to the connection handler.

    Returns:
        On success a dict with ``success`` True, ``request_id``, the list of
        collected ``results``, ``final_result`` (the package flagged
        ``is_last_package``, or None if none was seen) and ``total_results``.
        On any exception a dict with ``success`` False and the error text;
        connection-closed errors additionally carry ``error_code``.
    """
    ws_url = self._get_ws_url(streaming=True)
    headers = self._build_auth_headers(request_id)

    results = []

    try:
        connect_kwargs = {
            'uri': ws_url,
            'max_size': self.connection_config.get('max_size', 1000000000)
        }

        # Only the connection attempt is guarded by the TypeError fallback.
        # Newer websockets releases take ``additional_headers``; older ones
        # take ``extra_headers``. Guarding the whole session would re-run
        # recognition (and hide the real error) whenever the session itself
        # raised a TypeError.
        try:
            ws = await websockets.connect(
                **connect_kwargs,
                additional_headers=headers
            )
        except TypeError:
            ws = await websockets.connect(
                **connect_kwargs,
                extra_headers=headers
            )

        try:
            final_result = await self._handle_streaming_connection(
                ws, audio_data, request_params, segment_size,
                request_id, result_callback, results, None
            )
        finally:
            # Same cleanup that ``async with`` would perform on exit.
            await ws.close()

        # A handler that does not return the final package leaves None
        # here; recover it from the collected results so callers are not
        # always handed ``final_result=None``.
        if final_result is None:
            final_result = next(
                (item for item in reversed(results) if item.get('is_last_package')),
                None
            )

        return {
            'success': True,
            'request_id': request_id,
            'results': results,
            'final_result': final_result,
            'total_results': len(results)
        }

    except ConnectionClosedError as e:
        self.logger.error(f"WebSocket连接关闭: {e.code} - {e.reason}")
        return {
            'success': False,
            'error': f"连接关闭: {e.reason}",
            'error_code': e.code,
            'request_id': request_id
        }

    except WebSocketException as e:
        self.logger.error(f"WebSocket异常: {e}")
        return {
            'success': False,
            'error': str(e),
            'request_id': request_id
        }

    except Exception as e:
        self.logger.error(f"流式识别异常: {e}")
        return {
            'success': False,
            'error': str(e),
            'request_id': request_id
        }

    finally:
        self.is_connected = False
| 294 | + | ||
async def _handle_streaming_connection(
    self,
    ws,
    audio_data: bytes,
    request_params: Dict[str, Any],
    segment_size: int,
    request_id: str,
    result_callback: Optional[Callable[[Dict[str, Any]], None]],
    results: list,
    final_result: Any
) -> Optional[Dict[str, Any]]:
    """Drive one streaming session over an already-open WebSocket.

    Sends the initial request, then sends the audio chunk by chunk,
    reading one response after every send. Responses with a truthy
    ``payload_msg`` are appended to ``results`` and passed to
    ``result_callback``.

    Args:
        ws: Open connection offering awaitable ``send`` and ``recv``.
        audio_data: Audio bytes to slice and send.
        request_params: Payload for the initial request.
        segment_size: Chunk size in bytes.
        request_id: Not used here; kept for interface compatibility.
        result_callback: Optional callable invoked with each stored result;
            exceptions it raises are logged and ignored.
        results: List mutated in place with the stored results.
        final_result: Initial value of the final package (normally None).
            Rebinding a parameter is invisible to the caller, so the final
            package is returned instead.

    Returns:
        The response flagged ``is_last_package`` if one with a truthy
        ``payload_msg`` was received, otherwise the ``final_result`` value
        passed in.
    """
    self.is_connected = True
    self.logger.info("WebSocket连接建立成功")

    # Initial request carries sequence number 1.
    seq = 1
    full_request = self.protocol.build_full_request(request_params, seq)
    await ws.send(full_request)

    response = await ws.recv()
    result = self.protocol.parse_response(response)

    if self.logging_config.get('log_responses', True):
        self.logger.debug(f"初始响应: {result}")

    # Loop-invariant pacing settings.
    pace_sending = self.asr_config.get('streaming_mode', True)
    segment_seconds = self.asr_config.get('seg_duration', 200) / 1000.0

    for chunk, is_last in self.audio_processor.slice_audio_data(audio_data, segment_size):
        seq += 1
        if is_last:
            # The last chunk is sent with a negated sequence number.
            seq = -seq

        start_time = time.time()

        audio_request = self.protocol.build_audio_request(
            chunk, seq, is_last
        )
        await ws.send(audio_request)

        response = await ws.recv()
        result = self.protocol.parse_response(response)

        if result.get('payload_msg'):
            results.append(result)

            if result_callback:
                try:
                    result_callback(result)
                except Exception as e:
                    self.logger.warning(f"回调函数执行失败: {e}")

            if result.get('is_last_package'):
                final_result = result
                break

        # Sleep for the remainder of the segment duration so chunks are
        # sent no faster than one per ``seg_duration``.
        if pace_sending:
            elapsed = time.time() - start_time
            sleep_time = max(0, segment_seconds - elapsed)
            if sleep_time > 0:
                await asyncio.sleep(sleep_time)

    return final_result
| 363 | + | ||
async def _non_streaming_recognize(
    self,
    audio_data: bytes,
    request_params: Dict[str, Any],
    request_id: str
) -> Dict[str, Any]:
    """Open a WebSocket and run one non-streaming recognition request.

    Args:
        audio_data: Audio bytes sent in a single audio request.
        request_params: Payload for the initial request.
        request_id: Id used for the auth headers and echoed in the result.

    Returns:
        The dict produced by ``_handle_non_streaming_connection``, or a
        dict with ``success`` False, the error text and ``request_id`` if
        anything raises.
    """
    ws_url = self._get_ws_url(streaming=False)
    headers = self._build_auth_headers(request_id)

    try:
        connect_kwargs = {
            'uri': ws_url,
            'max_size': self.connection_config.get('max_size', 1000000000)
        }

        # Only the connection attempt is guarded by the TypeError fallback.
        # Newer websockets releases take ``additional_headers``; older ones
        # take ``extra_headers``. Guarding the whole exchange would send
        # the audio a second time (and hide the real error) whenever the
        # exchange itself raised a TypeError.
        try:
            ws = await websockets.connect(
                **connect_kwargs,
                additional_headers=headers
            )
        except TypeError:
            ws = await websockets.connect(
                **connect_kwargs,
                extra_headers=headers
            )

        try:
            return await self._handle_non_streaming_connection(
                ws, audio_data, request_params, request_id
            )
        finally:
            # Same cleanup that ``async with`` would perform on exit.
            await ws.close()

    except Exception as e:
        self.logger.error(f"非流式识别异常: {e}")
        return {
            'success': False,
            'error': str(e),
            'request_id': request_id
        }

    finally:
        self.is_connected = False
| 406 | + | ||
async def _handle_non_streaming_connection(
    self,
    ws,
    audio_data: bytes,
    request_params: Dict[str, Any],
    request_id: str
) -> Dict[str, Any]:
    """Run one request/response exchange over an already-open WebSocket.

    Sends the initial request (sequence 1), then the complete audio as a
    single last-chunk request (sequence -1), and parses one response.

    Args:
        ws: Open connection offering awaitable ``send`` and ``recv``.
        audio_data: Audio bytes sent in one audio request.
        request_params: Payload for the initial request.
        request_id: Echoed in the returned dict.

    Returns:
        A dict with ``success`` True, ``request_id`` and the parsed
        ``result``.
    """
    self.is_connected = True
    self.logger.info("WebSocket连接建立成功")

    # Parameters first, then the audio itself as the final package.
    await ws.send(self.protocol.build_full_request(request_params, 1))
    await ws.send(self.protocol.build_audio_request(audio_data, -1, is_last=True))

    parsed = self.protocol.parse_response(await ws.recv())

    self.is_connected = False

    return {
        'success': True,
        'request_id': request_id,
        'result': parsed
    }
| 439 | + | ||
async def close(self):
    """Reset the client's connection flag and session id, then log it."""
    self.current_session_id = None
    self.is_connected = False
    self.logger.info("ASR客户端已关闭")
| 445 | + | ||
def get_status(self) -> Dict[str, Any]:
    """Return a snapshot of the connection state and key settings.

    Returns:
        A dict with ``is_connected``, ``current_session_id`` and a
        ``config`` section holding the streaming WebSocket URL plus the
        ``model_name`` and ``streaming_mode`` entries of ``asr_config``
        (None when not configured).
    """
    status: Dict[str, Any] = {
        'is_connected': self.is_connected,
        'current_session_id': self.current_session_id,
    }
    status['config'] = {
        'ws_url': self._get_ws_url(),
        'model_name': self.asr_config.get('model_name'),
        'streaming_mode': self.asr_config.get('streaming_mode'),
    }
    return status
asr/doubao/audio_utils.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR音频处理工具模块 | ||
| 4 | +提供音频格式检测、分片处理、元数据提取等功能 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import wave | ||
| 8 | +from io import BytesIO | ||
| 9 | +from typing import Tuple, Generator, Dict, Any | ||
| 10 | + | ||
| 11 | + | ||
class AudioProcessor:
    """Stateless helpers for audio format detection, chunking and WAV metadata."""

    @staticmethod
    def read_wav_info(audio_data: bytes) -> Tuple[int, int, int, int, bytes]:
        """Parse WAV bytes and return their parameters and frame data.

        Args:
            audio_data: Complete WAV file content.

        Returns:
            ``(channels, sample_width, frame_rate, frame_count, frames)``.

        Raises:
            ValueError: If the data cannot be read as WAV; the original
                exception is attached as the cause.
        """
        try:
            with BytesIO(audio_data) as audio_io:
                with wave.open(audio_io, 'rb') as wave_fp:
                    nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
                    wave_bytes = wave_fp.readframes(nframes)
            return nchannels, sampwidth, framerate, nframes, wave_bytes
        except Exception as e:
            raise ValueError(f"读取WAV文件失败: {e}") from e

    @staticmethod
    def is_wav_format(audio_data: bytes) -> bool:
        """Return True when the data is at least 44 bytes and has RIFF/WAVE markers.

        Args:
            audio_data: Audio bytes to inspect.
        """
        if len(audio_data) < 44:
            return False
        return audio_data[0:4] == b"RIFF" and audio_data[8:12] == b"WAVE"

    @staticmethod
    def detect_audio_format(audio_data: bytes) -> str:
        """Guess the container format from the leading bytes.

        Args:
            audio_data: Audio bytes to inspect.

        Returns:
            ``'unknown'`` for fewer than 4 bytes, ``'wav'`` for RIFF/WAVE
            data, ``'mp3'`` when the data starts with ``ID3`` or
            ``0xFF 0xFB``, and ``'pcm'`` for everything else.
        """
        if len(audio_data) < 4:
            return 'unknown'

        if AudioProcessor.is_wav_format(audio_data):
            return 'wav'

        if audio_data[0:3] == b"ID3" or audio_data[0:2] == b"\xff\xfb":
            return 'mp3'

        # Anything unrecognized is treated as raw PCM.
        return 'pcm'

    @staticmethod
    def slice_audio_data(
        audio_data: bytes,
        chunk_size: int
    ) -> Generator[Tuple[bytes, bool], None, None]:
        """Yield the data in fixed-size chunks, flagging the last one.

        Args:
            audio_data: Bytes to split.
            chunk_size: Chunk size in bytes; must be positive.

        Yields:
            ``(chunk, is_last)`` pairs. Empty input yields nothing. When
            the length is an exact multiple of ``chunk_size`` the final
            full-size chunk is the one flagged as last.

        Raises:
            ValueError: If ``chunk_size`` is not positive (raised on first
                iteration); a non-positive size would otherwise loop
                forever without advancing.
        """
        if chunk_size <= 0:
            raise ValueError(f"分片大小必须为正整数: {chunk_size}")

        data_len = len(audio_data)
        offset = 0

        # Strict '<' keeps the final chunk out of this loop so it can be
        # flagged as the last one below.
        while offset + chunk_size < data_len:
            yield audio_data[offset:offset + chunk_size], False
            offset += chunk_size

        if offset < data_len:
            yield audio_data[offset:data_len], True

    @staticmethod
    def calculate_segment_size(
        audio_format: str,
        sample_rate: int = 16000,
        channels: int = 1,
        bits: int = 16,
        segment_duration_ms: int = 200,
        mp3_seg_size: int = 1000
    ) -> int:
        """Return the chunk size in bytes for one segment.

        Args:
            audio_format: ``'mp3'``, ``'wav'`` or ``'pcm'``.
            sample_rate: Samples per second.
            channels: Number of channels.
            bits: Bits per sample.
            segment_duration_ms: Segment duration in milliseconds.
            mp3_seg_size: Fixed chunk size returned for MP3.

        Returns:
            ``mp3_seg_size`` for MP3; for WAV/PCM the number of bytes that
            ``segment_duration_ms`` of audio occupies.

        Raises:
            ValueError: For any other format.
        """
        if audio_format == 'mp3':
            return mp3_seg_size

        if audio_format in ('wav', 'pcm'):
            bytes_per_second = channels * (bits // 8) * sample_rate
            return int(bytes_per_second * segment_duration_ms / 1000)

        raise ValueError(f"不支持的音频格式: {audio_format}")

    @staticmethod
    def extract_wav_metadata(audio_data: bytes) -> Dict[str, Any]:
        """Collect WAV metadata without raising.

        Args:
            audio_data: Complete WAV file content.

        Returns:
            A dict with ``format``, ``channels``, ``sample_width`` (bytes),
            ``sample_rate``, ``frames``, ``duration`` (seconds) and
            ``size``. If parsing fails, a dict with ``format``, ``error``
            and ``size`` only.
        """
        try:
            nchannels, sampwidth, framerate, nframes, _ = AudioProcessor.read_wav_info(audio_data)
            duration = nframes / framerate

            return {
                'format': 'wav',
                'channels': nchannels,
                'sample_width': sampwidth,
                'sample_rate': framerate,
                'frames': nframes,
                'duration': duration,
                'size': len(audio_data)
            }
        except Exception as e:
            return {
                'format': 'wav',
                'error': str(e),
                'size': len(audio_data)
            }

    @staticmethod
    def validate_audio_params(
        audio_format: str,
        sample_rate: int,
        channels: int,
        bits: int
    ) -> bool:
        """Check audio parameters against the supported ranges.

        Args:
            audio_format: Must be ``'wav'``, ``'mp3'`` or ``'pcm'``.
            sample_rate: Must be within 8000-48000.
            channels: Must be 1 or 2.
            bits: Must be 8, 16, 24 or 32.

        Returns:
            True when every parameter is acceptable, otherwise False.
        """
        supported_formats = ['wav', 'mp3', 'pcm']
        if audio_format not in supported_formats:
            return False

        if sample_rate < 8000 or sample_rate > 48000:
            return False

        if channels < 1 or channels > 2:
            return False

        if bits not in [8, 16, 24, 32]:
            return False

        return True

    @staticmethod
    def prepare_audio_for_recognition(
        audio_data: bytes,
        target_format: str = 'wav',
        segment_duration_ms: int = 200
    ) -> Tuple[str, int, Dict[str, Any]]:
        """Detect the format and derive chunk size and metadata.

        Args:
            audio_data: Raw audio bytes.
            target_format: Accepted for interface compatibility; not used.
            segment_duration_ms: Segment duration in milliseconds.

        Returns:
            ``(detected_format, segment_size, metadata)``. For WAV the
            chunk size is computed from the parsed header, falling back to
            16 kHz / mono / 2-byte samples for fields the metadata lacks.
            For other formats default parameters are used and the metadata
            holds only ``format`` and ``size``.

        Raises:
            ValueError: If the detected format is ``'unknown'``.
        """
        detected_format = AudioProcessor.detect_audio_format(audio_data)

        if detected_format == 'wav':
            metadata = AudioProcessor.extract_wav_metadata(audio_data)
            segment_size = AudioProcessor.calculate_segment_size(
                detected_format,
                metadata.get('sample_rate', 16000),
                metadata.get('channels', 1),
                metadata.get('sample_width', 2) * 8,
                segment_duration_ms
            )
        else:
            metadata = {
                'format': detected_format,
                'size': len(audio_data)
            }
            segment_size = AudioProcessor.calculate_segment_size(
                detected_format,
                segment_duration_ms=segment_duration_ms
            )

        return detected_format, segment_size, metadata
asr/doubao/config.json
0 → 100644
| 1 | +{ | ||
| 2 | + "asr_config": { | ||
| 3 | + "ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", | ||
| 4 | + "ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream", | ||
| 5 | + "resource_id": "volc.bigasr.sauc.duration", | ||
| 6 | + "resource_id_concurrent": "volc.bigasr.sauc.concurrent", | ||
| 7 | + "model_name": "bigmodel", | ||
| 8 | + "enable_punc": true, | ||
| 9 | + "streaming_mode": true, | ||
| 10 | + "seg_duration": 200, | ||
| 11 | + "mp3_seg_size": 1000 | ||
| 12 | + }, | ||
| 13 | + "auth_config": { | ||
| 14 | + "app_key": "1549099156", | ||
| 15 | + "access_key": "0GcKVco6j09bThrIgQWTWa3g1nA91_9C" | ||
| 16 | + }, | ||
| 17 | + "audio_config": { | ||
| 18 | + "default_format": "wav", | ||
| 19 | + "default_rate": 16000, | ||
| 20 | + "default_bits": 16, | ||
| 21 | + "default_channel": 1, | ||
| 22 | + "default_codec": "raw", | ||
| 23 | + "supported_formats": ["wav", "mp3", "pcm"] | ||
| 24 | + }, | ||
| 25 | + "connection_config": { | ||
| 26 | + "max_size": 1000000000, | ||
| 27 | + "timeout": 30, | ||
| 28 | + "retry_times": 3, | ||
| 29 | + "retry_delay": 1 | ||
| 30 | + }, | ||
| 31 | + "logging_config": { | ||
| 32 | + "enable_debug": false, | ||
| 33 | + "log_requests": true, | ||
| 34 | + "log_responses": true | ||
| 35 | + } | ||
| 36 | +} |
asr/doubao/config_manager.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR配置管理模块 | ||
| 4 | +提供配置文件加载、验证、合并和环境变量支持 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import json | ||
| 8 | +import os | ||
| 9 | +from pathlib import Path | ||
| 10 | +from typing import Dict, Any, Optional | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +class ConfigManager: | ||
| 14 | + """配置管理器""" | ||
| 15 | + | ||
| 16 | + DEFAULT_CONFIG = { | ||
| 17 | + "asr_config": { | ||
| 18 | + "ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel", | ||
| 19 | + "ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream", | ||
| 20 | + "resource_id": "volc.bigasr.sauc.duration", | ||
| 21 | + "resource_id_concurrent": "volc.bigasr.sauc.concurrent", | ||
| 22 | + "model_name": "bigmodel", | ||
| 23 | + "enable_punc": True, | ||
| 24 | + "streaming_mode": True, | ||
| 25 | + "seg_duration": 200, | ||
| 26 | + "mp3_seg_size": 1000 | ||
| 27 | + }, | ||
| 28 | + "auth_config": { | ||
| 29 | + "app_key": "", | ||
| 30 | + "access_key": "" | ||
| 31 | + }, | ||
| 32 | + "audio_config": { | ||
| 33 | + "default_format": "wav", | ||
| 34 | + "default_rate": 16000, | ||
| 35 | + "default_bits": 16, | ||
| 36 | + "default_channel": 1, | ||
| 37 | + "default_codec": "raw", | ||
| 38 | + "supported_formats": ["wav", "mp3", "pcm"] | ||
| 39 | + }, | ||
| 40 | + "connection_config": { | ||
| 41 | + "max_size": 1000000000, | ||
| 42 | + "timeout": 30, | ||
| 43 | + "retry_times": 3, | ||
| 44 | + "retry_delay": 1 | ||
| 45 | + }, | ||
| 46 | + "logging_config": { | ||
| 47 | + "enable_debug": False, | ||
| 48 | + "log_requests": True, | ||
| 49 | + "log_responses": True | ||
| 50 | + } | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | + def __init__(self, config_path: Optional[str] = None): | ||
| 54 | + """ | ||
| 55 | + 初始化配置管理器 | ||
| 56 | + | ||
| 57 | + Args: | ||
| 58 | + config_path: 配置文件路径 | ||
| 59 | + """ | ||
| 60 | + self.config_path = config_path | ||
| 61 | + self.config = self.DEFAULT_CONFIG.copy() | ||
| 62 | + | ||
| 63 | + if config_path: | ||
| 64 | + self.load_config(config_path) | ||
| 65 | + | ||
| 66 | + # 从环境变量加载配置 | ||
| 67 | + self._load_from_env() | ||
| 68 | + | ||
| 69 | + def load_config(self, config_path: str) -> Dict[str, Any]: | ||
| 70 | + """ | ||
| 71 | + 加载配置文件 | ||
| 72 | + | ||
| 73 | + Args: | ||
| 74 | + config_path: 配置文件路径 | ||
| 75 | + | ||
| 76 | + Returns: | ||
| 77 | + Dict: 配置字典 | ||
| 78 | + """ | ||
| 79 | + try: | ||
| 80 | + config_file = Path(config_path) | ||
| 81 | + if not config_file.exists(): | ||
| 82 | + raise FileNotFoundError(f"配置文件不存在: {config_path}") | ||
| 83 | + | ||
| 84 | + with open(config_file, 'r', encoding='utf-8') as f: | ||
| 85 | + file_config = json.load(f) | ||
| 86 | + | ||
| 87 | + # 合并配置 | ||
| 88 | + self.config = self._merge_config(self.config, file_config) | ||
| 89 | + | ||
| 90 | + # 验证配置 | ||
| 91 | + self._validate_config() | ||
| 92 | + | ||
| 93 | + return self.config | ||
| 94 | + | ||
| 95 | + except Exception as e: | ||
| 96 | + raise ValueError(f"加载配置文件失败: {e}") | ||
| 97 | + | ||
| 98 | + def _merge_config(self, base_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]: | ||
| 99 | + """ | ||
| 100 | + 合并配置字典 | ||
| 101 | + | ||
| 102 | + Args: | ||
| 103 | + base_config: 基础配置 | ||
| 104 | + new_config: 新配置 | ||
| 105 | + | ||
| 106 | + Returns: | ||
| 107 | + Dict: 合并后的配置 | ||
| 108 | + """ | ||
| 109 | + merged = base_config.copy() | ||
| 110 | + | ||
| 111 | + for key, value in new_config.items(): | ||
| 112 | + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): | ||
| 113 | + merged[key] = self._merge_config(merged[key], value) | ||
| 114 | + else: | ||
| 115 | + merged[key] = value | ||
| 116 | + | ||
| 117 | + return merged | ||
| 118 | + | ||
| 119 | + def _load_from_env(self): | ||
| 120 | + """从环境变量加载配置""" | ||
| 121 | + # ASR配置 | ||
| 122 | + if os.getenv('DOUBAO_WS_URL'): | ||
| 123 | + self.config['asr_config']['ws_url'] = os.getenv('DOUBAO_WS_URL') | ||
| 124 | + | ||
| 125 | + if os.getenv('DOUBAO_MODEL_NAME'): | ||
| 126 | + self.config['asr_config']['model_name'] = os.getenv('DOUBAO_MODEL_NAME') | ||
| 127 | + | ||
| 128 | + if os.getenv('DOUBAO_SEG_DURATION'): | ||
| 129 | + try: | ||
| 130 | + self.config['asr_config']['seg_duration'] = int(os.getenv('DOUBAO_SEG_DURATION')) | ||
| 131 | + except ValueError: | ||
| 132 | + pass | ||
| 133 | + | ||
| 134 | + # 认证配置 | ||
| 135 | + if os.getenv('DOUBAO_APP_KEY'): | ||
| 136 | + self.config['auth_config']['app_key'] = os.getenv('DOUBAO_APP_KEY') | ||
| 137 | + | ||
| 138 | + if os.getenv('DOUBAO_ACCESS_KEY'): | ||
| 139 | + self.config['auth_config']['access_key'] = os.getenv('DOUBAO_ACCESS_KEY') | ||
| 140 | + | ||
| 141 | + # 日志配置 | ||
| 142 | + if os.getenv('DOUBAO_DEBUG'): | ||
| 143 | + self.config['logging_config']['enable_debug'] = os.getenv('DOUBAO_DEBUG').lower() == 'true' | ||
| 144 | + | ||
| 145 | + def _validate_config(self): | ||
| 146 | + """验证配置""" | ||
| 147 | + # 验证必需的认证信息 | ||
| 148 | + auth_config = self.config.get('auth_config', {}) | ||
| 149 | + if not auth_config.get('app_key'): | ||
| 150 | + raise ValueError("缺少必需的配置: auth_config.app_key") | ||
| 151 | + | ||
| 152 | + if not auth_config.get('access_key'): | ||
| 153 | + raise ValueError("缺少必需的配置: auth_config.access_key") | ||
| 154 | + | ||
| 155 | + # 验证ASR配置 | ||
| 156 | + asr_config = self.config.get('asr_config', {}) | ||
| 157 | + if not asr_config.get('ws_url'): | ||
| 158 | + raise ValueError("缺少必需的配置: asr_config.ws_url") | ||
| 159 | + | ||
| 160 | + # 验证音频配置 | ||
| 161 | + audio_config = self.config.get('audio_config', {}) | ||
| 162 | + supported_formats = audio_config.get('supported_formats', []) | ||
| 163 | + default_format = audio_config.get('default_format') | ||
| 164 | + | ||
| 165 | + if default_format and default_format not in supported_formats: | ||
| 166 | + raise ValueError(f"默认音频格式 {default_format} 不在支持的格式列表中: {supported_formats}") | ||
| 167 | + | ||
| 168 | + # 验证数值范围 | ||
| 169 | + seg_duration = asr_config.get('seg_duration', 200) | ||
| 170 | + if not (50 <= seg_duration <= 1000): | ||
| 171 | + raise ValueError(f"分片时长必须在50-1000ms之间,当前值: {seg_duration}") | ||
| 172 | + | ||
| 173 | + sample_rate = audio_config.get('default_rate', 16000) | ||
| 174 | + if sample_rate not in [8000, 16000, 22050, 44100, 48000]: | ||
| 175 | + raise ValueError(f"不支持的采样率: {sample_rate}") | ||
| 176 | + | ||
| 177 | + def get_config(self) -> Dict[str, Any]: | ||
| 178 | + """ | ||
| 179 | + 获取完整配置 | ||
| 180 | + | ||
| 181 | + Returns: | ||
| 182 | + Dict: 配置字典 | ||
| 183 | + """ | ||
| 184 | + return self.config.copy() | ||
| 185 | + | ||
| 186 | + def get_asr_config(self) -> Dict[str, Any]: | ||
| 187 | + """ | ||
| 188 | + 获取ASR配置 | ||
| 189 | + | ||
| 190 | + Returns: | ||
| 191 | + Dict: ASR配置 | ||
| 192 | + """ | ||
| 193 | + return self.config.get('asr_config', {}).copy() | ||
| 194 | + | ||
| 195 | + def get_auth_config(self) -> Dict[str, Any]: | ||
| 196 | + """ | ||
| 197 | + 获取认证配置 | ||
| 198 | + | ||
| 199 | + Returns: | ||
| 200 | + Dict: 认证配置 | ||
| 201 | + """ | ||
| 202 | + return self.config.get('auth_config', {}).copy() | ||
| 203 | + | ||
| 204 | + def get_audio_config(self) -> Dict[str, Any]: | ||
| 205 | + """ | ||
| 206 | + 获取音频配置 | ||
| 207 | + | ||
| 208 | + Returns: | ||
| 209 | + Dict: 音频配置 | ||
| 210 | + """ | ||
| 211 | + return self.config.get('audio_config', {}).copy() | ||
| 212 | + | ||
| 213 | + def update_config(self, new_config: Dict[str, Any]): | ||
| 214 | + """ | ||
| 215 | + 更新配置 | ||
| 216 | + | ||
| 217 | + Args: | ||
| 218 | + new_config: 新配置 | ||
| 219 | + """ | ||
| 220 | + self.config = self._merge_config(self.config, new_config) | ||
| 221 | + self._validate_config() | ||
| 222 | + | ||
| 223 | + def save_config(self, output_path: Optional[str] = None): | ||
| 224 | + """ | ||
| 225 | + 保存配置到文件 | ||
| 226 | + | ||
| 227 | + Args: | ||
| 228 | + output_path: 输出文件路径,默认使用原配置文件路径 | ||
| 229 | + """ | ||
| 230 | + save_path = output_path or self.config_path | ||
| 231 | + if not save_path: | ||
| 232 | + raise ValueError("未指定保存路径") | ||
| 233 | + | ||
| 234 | + try: | ||
| 235 | + with open(save_path, 'w', encoding='utf-8') as f: | ||
| 236 | + json.dump(self.config, f, indent=2, ensure_ascii=False) | ||
| 237 | + except Exception as e: | ||
| 238 | + raise ValueError(f"保存配置文件失败: {e}") | ||
| 239 | + | ||
| 240 | + def create_default_config(self, output_path: str) -> Dict[str, Any]: | ||
| 241 | + """ | ||
| 242 | + 创建默认配置文件 | ||
| 243 | + | ||
| 244 | + Args: | ||
| 245 | + output_path: 输出文件路径 | ||
| 246 | + | ||
| 247 | + Returns: | ||
| 248 | + Dict[str, Any]: 默认配置字典 | ||
| 249 | + """ | ||
| 250 | + try: | ||
| 251 | + with open(output_path, 'w', encoding='utf-8') as f: | ||
| 252 | + json.dump(self.DEFAULT_CONFIG, f, indent=2, ensure_ascii=False) | ||
| 253 | + return self.DEFAULT_CONFIG.copy() | ||
| 254 | + except Exception as e: | ||
| 255 | + raise ValueError(f"创建默认配置文件失败: {e}") | ||
| 256 | + | ||
def get_env_template(self) -> str:
    """
    Return a ready-to-edit environment-variable template.

    Returns:
        str: The template text, one ``NAME=value`` entry per line, grouped
            under comment headings and without leading or trailing
            whitespace.
    """
    template_lines = (
        "# 豆包ASR环境变量配置模板",
        "",
        "# ASR服务配置",
        "DOUBAO_WS_URL=wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
        "DOUBAO_MODEL_NAME=bigmodel",
        "DOUBAO_SEG_DURATION=200",
        "",
        "# 认证配置(必需)",
        "DOUBAO_APP_KEY=your_app_key_here",
        "DOUBAO_ACCESS_KEY=your_access_key_here",
        "",
        "# 调试配置",
        "DOUBAO_DEBUG=false",
    )
    return "\n".join(template_lines)
| 280 | + | ||
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> 'ConfigManager':
    """
    Build a configuration manager from an in-memory dictionary.

    A manager is first constructed with no arguments; its configuration is
    then replaced by ``DEFAULT_CONFIG`` merged with ``config_dict`` and
    validated.

    Args:
        config_dict: Configuration values to merge over the defaults.

    Returns:
        ConfigManager: The validated manager instance.
    """
    instance = cls()
    merged = instance._merge_config(instance.DEFAULT_CONFIG, config_dict)
    instance.config = merged
    instance._validate_config()
    return instance
| 296 | + | ||
@classmethod
def from_env(cls) -> 'ConfigManager':
    """
    Build a configuration manager without a configuration file.

    The class is instantiated with no arguments.
    NOTE(review): presumably the constructor then picks its values up from
    environment variables - confirm against ``__init__``.

    Returns:
        ConfigManager: The newly constructed manager instance.
    """
    return cls()
asr/doubao/example.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR语音识别服务使用示例 | ||
| 4 | +演示各种使用场景和最佳实践 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import asyncio | ||
| 8 | +import os | ||
| 9 | +import logging | ||
| 10 | +from pathlib import Path | ||
| 11 | +from typing import Dict, Any, Optional | ||
| 12 | + | ||
| 13 | +# 导入ASR服务 - 支持相对导入和绝对导入 | ||
| 14 | +try: | ||
| 15 | + # 尝试相对导入(作为包运行时) | ||
| 16 | + from . import ( | ||
| 17 | + recognize_file, | ||
| 18 | + recognize_audio_data, | ||
| 19 | + create_asr_service, | ||
| 20 | + run_recognition, | ||
| 21 | + ConfigManager | ||
| 22 | + ) | ||
| 23 | +except ImportError: | ||
| 24 | + # 回退到绝对导入(独立运行时) | ||
| 25 | + try: | ||
| 26 | + from asr_client import ( | ||
| 27 | + recognize_file, | ||
| 28 | + recognize_audio_data, | ||
| 29 | + create_asr_service, | ||
| 30 | + run_recognition | ||
| 31 | + ) | ||
| 32 | + from config_manager import ConfigManager | ||
| 33 | + except ImportError: | ||
| 34 | + # 最后尝试直接导入 | ||
| 35 | + import sys | ||
| 36 | + from pathlib import Path | ||
| 37 | + | ||
| 38 | + # 添加当前目录到路径 | ||
| 39 | + current_dir = Path(__file__).parent | ||
| 40 | + sys.path.insert(0, str(current_dir)) | ||
| 41 | + | ||
| 42 | + from asr_client import ( | ||
| 43 | + recognize_file, | ||
| 44 | + recognize_audio_data, | ||
| 45 | + create_asr_service, | ||
| 46 | + run_recognition | ||
| 47 | + ) | ||
| 48 | + from config_manager import ConfigManager | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +# 配置日志 | ||
| 52 | +logging.basicConfig( | ||
| 53 | + level=logging.INFO, | ||
| 54 | + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' | ||
| 55 | +) | ||
| 56 | +logger = logging.getLogger(__name__) | ||
| 57 | + | ||
| 58 | + | ||
class ASRExamples:
    """
    Collection of runnable usage examples for the Doubao ASR service.

    Every ``example_*`` method shows one way of calling the recognition API
    with the credentials supplied at construction time.  Each example logs
    its outcome and returns ``None`` instead of raising when recognition
    fails.
    """

    def __init__(self, app_key: str, access_key: str):
        """
        Store the credentials used by all examples.

        Args:
            app_key: Application key passed to the ASR service.
            access_key: Access key passed to the ASR service.
        """
        self.app_key = app_key
        self.access_key = access_key

    async def example_1_simple_file_recognition(self, audio_path: str):
        """
        Example 1: recognise one audio file with the ``recognize_file`` helper.

        Args:
            audio_path: Path of the audio file to recognise.

        Returns:
            The value returned by ``recognize_file``, or ``None`` on failure.
        """
        logger.info("=== 示例1: 简单文件识别 ===")

        try:
            result = await recognize_file(
                audio_path=audio_path,
                app_key=self.app_key,
                access_key=self.access_key,
                streaming=True
            )

            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None

    async def example_2_streaming_with_callback(self, audio_path: str):
        """
        Example 2: streaming recognition with a live result callback.

        Intermediate texts are rewritten in place on a single console line;
        the final text is terminated with a newline.  The latest text and
        the number of callback invocations are kept in ``self.current_text``
        and ``self.result_count``.

        Args:
            audio_path: Path of the audio file to recognise.

        Returns:
            The value returned by ``service.recognize_file``, or ``None`` on
            failure.
        """
        logger.info("=== 示例2: 流式识别with实时回调 - 简化版流式输出演示 ===")

        # State shared with the callback below.
        self.current_text = ""
        self.result_count = 0

        def clear_line():
            """Blank out the current console line and return to column 0."""
            print('\r' + ' ' * 100 + '\r', end='', flush=True)

        def print_streaming_result(text: str, is_final: bool = False):
            """Redraw the current line with the latest text."""
            clear_line()
            status = "[最终]" if is_final else "[流式]"
            timestamp = f"[{self.result_count:02d}]"
            print(f"{timestamp}{status} {text}", end='', flush=True)
            if is_final:
                print()  # move to a new line after the final result

        def on_result(result: Dict[str, Any]):
            """Callback invoked with each parsed recognition response."""
            self.result_count += 1

            payload = result.get('payload_msg')
            if payload and 'result' in payload and 'text' in payload['result']:
                new_text = payload['result']['text']
                # Show the accumulated text reported by this response.
                print_streaming_result(new_text, False)
                self.current_text = new_text

            # Checked independently of the payload so that a final package
            # carrying no new text still terminates the output line.
            if result.get('is_last_package', False):
                print_streaming_result(self.current_text, True)
                logger.info(f"识别完成,共收到{self.result_count}次流式结果")

        print("\n观察流式文本累积更新效果:")
        print()

        service = create_asr_service(
            app_key=self.app_key,
            access_key=self.access_key,
            streaming=True,
            debug=False  # keep debug logging off so the live output stays readable
        )

        try:
            result = await service.recognize_file(
                audio_path,
                result_callback=on_result
            )

            print()
            logger.info("=== 识别结果摘要 ===")
            logger.info(f"最终文本: {self.current_text}")
            logger.info(f"流式更新次数: {self.result_count}")

            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None
        finally:
            await service.close()

    async def example_3_non_streaming_recognition(self, audio_path: str):
        """
        Example 3: non-streaming recognition of one audio file.

        Args:
            audio_path: Path of the audio file to recognise.

        Returns:
            The value returned by ``recognize_file``, or ``None`` on failure.
        """
        logger.info("=== 示例3: 非流式识别 ===")

        try:
            result = await recognize_file(
                audio_path=audio_path,
                app_key=self.app_key,
                access_key=self.access_key,
                streaming=False  # non-streaming mode
            )

            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None

    async def example_4_audio_data_recognition(self, audio_data: bytes, audio_format: str = "wav"):
        """
        Example 4: recognise audio that is already held in memory.

        Args:
            audio_data: Raw audio bytes.
            audio_format: Format name forwarded to ``recognize_audio_data``.

        Returns:
            The value returned by ``recognize_audio_data``, or ``None`` on
            failure.
        """
        logger.info("=== 示例4: 音频数据识别 ===")

        try:
            result = await recognize_audio_data(
                audio_data=audio_data,
                audio_format=audio_format,
                app_key=self.app_key,
                access_key=self.access_key,
                streaming=True
            )

            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None

    async def example_5_config_based_recognition(self, audio_path: str, config_path: str):
        """
        Example 5: recognition driven by a configuration file.

        Args:
            audio_path: Path of the audio file to recognise.
            config_path: Path of the configuration file passed to
                ``recognize_file``.

        Returns:
            The value returned by ``recognize_file``, or ``None`` on failure.
        """
        logger.info("=== 示例5: 基于配置文件的识别 ===")

        try:
            result = await recognize_file(
                audio_path=audio_path,
                config_path=config_path
            )

            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None

    async def example_6_batch_recognition(self, audio_files: list):
        """
        Example 6: recognise several files with one service instance.

        A failure on one file is recorded and does not stop the batch.

        Args:
            audio_files: Paths of the audio files to recognise, in order.

        Returns:
            list: One dict per file with keys ``file``, ``result`` and
                ``status`` (``'success'`` or ``'failed'``), plus ``error``
                for failed files.
        """
        logger.info("=== 示例6: 批量识别 ===")

        results = []

        # One service instance is shared by every file in the batch.
        service = create_asr_service(
            app_key=self.app_key,
            access_key=self.access_key,
            streaming=True
        )

        try:
            total = len(audio_files)
            for position, audio_file in enumerate(audio_files, start=1):
                logger.info(f"处理文件 {position}/{total}: {audio_file}")

                try:
                    result = await service.recognize_file(audio_file)
                    results.append({
                        'file': audio_file,
                        'result': result,
                        'status': 'success'
                    })

                except Exception as e:
                    logger.error(f"文件 {audio_file} 识别失败: {e}")
                    results.append({
                        'file': audio_file,
                        'result': None,
                        'status': 'failed',
                        'error': str(e)
                    })

            succeeded = sum(1 for r in results if r['status'] == 'success')
            logger.info(f"批量识别完成,成功: {succeeded}/{len(results)}")
            return results

        finally:
            await service.close()

    def example_7_sync_recognition(self, audio_path: str):
        """
        Example 7: synchronous recognition through ``run_recognition``.

        Args:
            audio_path: Path of the audio file to recognise.

        Returns:
            The value returned by ``run_recognition``, or ``None`` on failure.
        """
        logger.info("=== 示例7: 同步识别 ===")

        try:
            result = run_recognition(
                audio_path=audio_path,
                app_key=self.app_key,
                access_key=self.access_key,
                streaming=True
            )

            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None

    async def example_8_custom_config_recognition(self, audio_path: str):
        """
        Example 8: recognition with an inline custom configuration.

        Args:
            audio_path: Path of the audio file to recognise.

        Returns:
            The value returned by ``service.recognize_file``, or ``None`` on
            failure.
        """
        logger.info("=== 示例8: 自定义配置识别 ===")

        custom_config = {
            'asr_config': {
                'enable_punc': True,
                'seg_duration': 300,  # custom segment duration
                'streaming_mode': True
            },
            'audio_config': {
                'default_rate': 16000,
                'default_bits': 16,
                'default_channel': 1
            },
            'connection_config': {
                'timeout': 60,  # custom timeout
                'retry_times': 5
            },
            'logging_config': {
                'enable_debug': True
            }
        }

        service = create_asr_service(
            app_key=self.app_key,
            access_key=self.access_key,
            custom_config=custom_config
        )

        try:
            result = await service.recognize_file(audio_path)
            logger.info(f"识别结果: {result}")
            return result

        except Exception as e:
            logger.error(f"识别失败: {e}")
            return None
        finally:
            await service.close()
| 341 | + | ||
| 342 | + | ||
def create_sample_config(config_path: str, app_key: str, access_key: str):
    """
    Create a sample configuration file containing the given credentials.

    The default configuration is written to ``config_path`` first; the
    credentials and the debug flag are then applied through the manager and
    the file is saved again.

    Args:
        config_path: Path of the configuration file to create.
        app_key: Application key stored under ``auth_config.app_key``.
        access_key: Access key stored under ``auth_config.access_key``.
    """
    import copy

    config_manager = ConfigManager()

    # Work on a private deep copy so the credentials written below can never
    # end up in nested dicts shared with the manager's default configuration.
    config = copy.deepcopy(config_manager.create_default_config(config_path))
    config['auth_config']['app_key'] = app_key
    config['auth_config']['access_key'] = access_key
    config['logging_config']['enable_debug'] = True

    # Push the edited configuration into the manager and persist it.
    config_manager.update_config(config)
    config_manager.save_config(config_path)
    logger.info(f"示例配置文件已创建: {config_path}")
| 359 | + | ||
| 360 | + | ||
async def run_all_examples():
    """
    Run the example suite.

    Credentials are read from the ``DOUBAO_APP_KEY`` and ``DOUBAO_ACCESS_KEY``
    environment variables; no real credentials are embedded in the source.
    The audio file defaults to a sample path and can be overridden with the
    ``DOUBAO_AUDIO_PATH`` environment variable.  A temporary sample
    configuration file is created for the run and always removed afterwards,
    because it contains the credentials.
    """
    # Placeholders double as the "not configured" marker checked below.
    app_key = os.getenv('DOUBAO_APP_KEY', 'your_app_key_here')
    access_key = os.getenv('DOUBAO_ACCESS_KEY', 'your_access_key_here')

    credentials_missing = (
        not app_key
        or not access_key
        or app_key == 'your_app_key_here'
        or access_key == 'your_access_key_here'
    )
    if credentials_missing:
        logger.warning("请设置环境变量 DOUBAO_APP_KEY 和 DOUBAO_ACCESS_KEY")
        logger.info("或者直接修改代码中的密钥")
        return

    # Sample audio file; replace it or set DOUBAO_AUDIO_PATH.
    audio_path = os.getenv('DOUBAO_AUDIO_PATH', "E:\\fengyang\\eman_one\\speech.wav")

    if not Path(audio_path).exists():
        logger.warning(f"音频文件不存在: {audio_path}")
        logger.info("请替换为实际的音频文件路径")
        return

    examples = ASRExamples(app_key, access_key)
    config_path = "example_config.json"

    try:
        # Created inside the outer try so the credential-bearing file is
        # removed even when creating it fails part-way; errors from this
        # call still propagate to the caller.
        create_sample_config(config_path, app_key, access_key)

        try:
            # Enable the examples you want to run.
            # await examples.example_1_simple_file_recognition(audio_path)
            await examples.example_2_streaming_with_callback(audio_path)
            # await examples.example_3_non_streaming_recognition(audio_path)

            # In-memory audio example (needs real audio data):
            # with open(audio_path, 'rb') as f:
            #     audio_data = f.read()
            # await examples.example_4_audio_data_recognition(audio_data)

            # await examples.example_5_config_based_recognition(audio_path, config_path)

            # Batch example:
            # audio_files = [audio_path]  # add more files here
            # await examples.example_6_batch_recognition(audio_files)

            # Synchronous example:
            # examples.example_7_sync_recognition(audio_path)

            # await examples.example_8_custom_config_recognition(audio_path)

        except Exception as e:
            logger.error(f"示例运行失败: {e}")

    finally:
        # Remove the sample configuration file.
        if Path(config_path).exists():
            os.remove(config_path)
            logger.info(f"已清理示例配置文件: {config_path}")
| 419 | + | ||
| 420 | + | ||
if __name__ == "__main__":
    # Script entry point: drive the whole example suite on a new event loop.
    suite = run_all_examples()
    asyncio.run(suite)

    # To run one example only, build the examples object yourself:
    #     examples = ASRExamples("your_app_key", "your_access_key")
    #     asyncio.run(examples.example_1_simple_file_recognition("path/to/audio.wav"))
asr/doubao/example_optimized.py
0 → 100644
| 1 | +# AIfeng/2025-07-17 13:58:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR优化使用示例 | ||
| 4 | +演示如何使用结果处理器来优化ASR输出,只获取文本内容 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import asyncio | ||
| 8 | +import logging | ||
| 9 | +from pathlib import Path | ||
| 10 | +from typing import Dict, Any | ||
| 11 | + | ||
| 12 | +from .service_factory import create_asr_service | ||
| 13 | +from .result_processor import DoubaoResultProcessor, create_text_only_callback, extract_text_only | ||
| 14 | + | ||
| 15 | +# 设置日志 | ||
| 16 | +logging.basicConfig(level=logging.INFO) | ||
| 17 | +logger = logging.getLogger(__name__) | ||
| 18 | + | ||
| 19 | + | ||
async def example_optimized_streaming():
    """
    Streaming recognition that surfaces only the recognised text.

    A ``DoubaoResultProcessor`` wraps a plain text handler into the result
    callback handed to the service; after recognition the processor's
    current result is logged.
    """
    logger.info("=== 优化流式识别示例 ===")

    processor = DoubaoResultProcessor(
        text_only=True,
        enable_streaming_log=False  # keep per-update logging off
    )

    def on_text_result(text: str):
        """Receive recognised text only; put business logic here."""
        print(f"识别文本: {text}")
        # e.g. forward to a WebSocket or store in a database

    optimized_callback = processor.create_optimized_callback(on_text_result)

    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True,
        debug=False  # debug mode off to reduce log output
    )

    try:
        audio_path = "path/to/your/audio.wav"
        # The return value is not needed: results arrive through the
        # callback and the final state is read from the processor.
        await service.recognize_file(
            audio_path,
            result_callback=optimized_callback
        )

        final_status = processor.get_current_result()
        logger.info(f"识别完成: {final_status}")

    finally:
        await service.close()
| 62 | + | ||
| 63 | + | ||
async def example_simple_text_extraction():
    """
    Minimal text-only recognition using ``create_text_only_callback``.

    The helper turns a function that takes a text string into a result
    callback for the service.
    """
    logger.info("=== 简单文本提取示例 ===")

    def print_text(text: str):
        """Print one recognised text update."""
        print(f">> {text}")

    text_callback = create_text_only_callback(
        user_callback=print_text,
        enable_streaming_log=False,
    )
    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
    )

    try:
        audio_path = "path/to/your/audio.wav"
        await service.recognize_file(audio_path, result_callback=text_callback)
    finally:
        # Always release the service, even when recognition fails.
        await service.close()
| 90 | + | ||
| 91 | + | ||
async def example_manual_text_extraction():
    """
    Recognition with a hand-written callback that extracts the text itself.

    The callback receives the full result dictionary and uses
    ``extract_text_only`` to pull the text out of it.
    """
    logger.info("=== 手动文本提取示例 ===")

    def manual_callback(result: Dict[str, Any]):
        """Print extracted text and flag the final result."""
        text = extract_text_only(result)
        if not text:
            return

        print(f"提取的文本: {text}")

        # The final package is additionally reported on its own line.
        if result.get('is_last_package', False):
            print(f"[最终结果] {text}")

    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
    )

    try:
        audio_path = "path/to/your/audio.wav"
        await service.recognize_file(audio_path, result_callback=manual_callback)
    finally:
        await service.close()
| 121 | + | ||
| 122 | + | ||
async def example_streaming_with_websocket():
    """
    Forward streaming recognition text to a (mock) WebSocket.

    Each text update schedules an asynchronous send.  The scheduled tasks
    are kept in a list and awaited before the message history is printed,
    so no send is still pending (or dropped) when the example reports.
    """
    logger.info("=== 流式识别WebSocket示例 ===")

    class MockWebSocket:
        """Stand-in for a WebSocket connection that records sent texts."""

        def __init__(self):
            self.messages = []

        async def send_text(self, text: str):
            self.messages.append(text)
            print(f"WebSocket发送: {text}")

    websocket = MockWebSocket()

    processor = DoubaoResultProcessor(
        text_only=True,
        enable_streaming_log=False
    )

    async def websocket_handler(text: str):
        """Send one text update over the WebSocket."""
        await websocket.send_text(text)

    # The event loop keeps only weak references to tasks, so hold strong
    # references here until the sends are awaited below.
    pending_sends = []

    def schedule_send(text: str):
        """Schedule a send from the synchronous result callback."""
        task = asyncio.create_task(websocket_handler(text))
        pending_sends.append(task)
        return task

    callback = processor.create_optimized_callback(schedule_send)

    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key"
    )

    try:
        audio_path = "path/to/your/audio.wav"
        await service.recognize_file(
            audio_path,
            result_callback=callback
        )

        # Let every scheduled send finish before reporting the history.
        if pending_sends:
            await asyncio.gather(*pending_sends)

        print(f"WebSocket消息历史: {websocket.messages}")

    finally:
        await service.close()
| 169 | + | ||
| 170 | + | ||
async def example_comparison():
    """
    Contrast the raw result callback with the text-only callback.

    Both callback styles are defined for illustration; only the text-only
    one is passed to the service.
    """
    logger.info("=== 对比示例 ===")

    def original_callback(result: Dict[str, Any]):
        """Raw style: receives the complete result dictionary (not wired in)."""
        print(f"完整结果: {result}")

    def optimized_callback(text: str):
        """Text-only style: receives just the recognised text."""
        print(f"文本: {text}")

    print("\n1. 原始完整输出:")
    print("\n2. 优化文本输出:")

    # Demonstration only: in real use pick one of the two styles.
    text_callback = create_text_only_callback(optimized_callback)

    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
    )

    try:
        audio_path = "path/to/your/audio.wav"
        # Run recognition with the text-only callback.
        await service.recognize_file(audio_path, result_callback=text_callback)
    finally:
        await service.close()
| 202 | + | ||
| 203 | + | ||
def run_example(example_name: str = "optimized"):
    """
    Run one of the examples by name.

    Args:
        example_name: One of ``optimized``, ``simple``, ``manual``,
            ``websocket`` or ``comparison``.  For any other name the list
            of available names is printed and nothing is run.
    """
    registry = {
        "optimized": example_optimized_streaming,
        "simple": example_simple_text_extraction,
        "manual": example_manual_text_extraction,
        "websocket": example_streaming_with_websocket,
        "comparison": example_comparison,
    }

    selected = registry.get(example_name)
    if selected is None:
        print(f"可用示例: {list(registry)}")
        return

    asyncio.run(selected())
| 219 | + | ||
| 220 | + | ||
if __name__ == "__main__":
    # Script entry point: run the text-only streaming example.
    run_example(example_name="optimized")

    # Other available names: "simple", "manual", "websocket", "comparison",
    # e.g. run_example("simple")
asr/doubao/protocol.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包语音识别WebSocket协议处理模块 | ||
| 4 | +实现二进制协议的编解码、消息类型定义和数据包处理 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import gzip | ||
| 8 | +import json | ||
| 9 | +from typing import Dict, Any, Tuple, Optional | ||
| 10 | + | ||
| 11 | + | ||
# Protocol version and header size. The header size is expressed in 4-byte
# units (parse_response multiplies it by 4 to find the payload offset).
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001


class MessageType:
    """Message type codes stored in the high nibble of header byte 1."""
    FULL_CLIENT_REQUEST = 0b0001
    AUDIO_ONLY_REQUEST = 0b0010
    FULL_SERVER_RESPONSE = 0b1001
    SERVER_ACK = 0b1011
    SERVER_ERROR_RESPONSE = 0b1111


class MessageFlags:
    """Message-type-specific flags stored in the low nibble of header byte 1.

    Bit 0 marks that a sequence number follows the header; bit 1 marks the
    last package (this is how ``DoubaoProtocol.parse_response`` reads them).
    """
    NO_SEQUENCE = 0b0000
    POS_SEQUENCE = 0b0001
    NEG_SEQUENCE = 0b0010
    NEG_WITH_SEQUENCE = 0b0011


class SerializationMethod:
    """Payload serialisation codes stored in the high nibble of header byte 2."""
    NO_SERIALIZATION = 0b0000
    JSON = 0b0001


class CompressionType:
    """Payload compression codes stored in the low nibble of header byte 2."""
    NO_COMPRESSION = 0b0000
    GZIP = 0b0001


class DoubaoProtocol:
    """Encoder/decoder for the Doubao ASR WebSocket binary frames.

    A frame built here consists of a 4-byte header, a 4-byte signed
    big-endian sequence number, a 4-byte big-endian payload length and the
    (optionally gzip-compressed) payload.
    """

    @staticmethod
    def generate_header(
        message_type: int = MessageType.FULL_CLIENT_REQUEST,
        message_type_specific_flags: int = MessageFlags.NO_SEQUENCE,
        serial_method: int = SerializationMethod.JSON,
        compression_type: int = CompressionType.GZIP,
        reserved_data: int = 0x00
    ) -> bytearray:
        """
        Build the 4-byte frame header.

        Args:
            message_type: Message type code (high nibble of byte 1).
            message_type_specific_flags: Flags (low nibble of byte 1).
            serial_method: Serialisation code (high nibble of byte 2).
            compression_type: Compression code (low nibble of byte 2).
            reserved_data: Value of the reserved fourth byte.

        Returns:
            bytearray: The 4-byte header.
        """
        header = bytearray()
        header.append((PROTOCOL_VERSION << 4) | DEFAULT_HEADER_SIZE)
        header.append((message_type << 4) | message_type_specific_flags)
        header.append((serial_method << 4) | compression_type)
        header.append(reserved_data)
        return header

    @staticmethod
    def generate_sequence_payload(sequence: int) -> bytearray:
        """
        Encode a sequence number.

        Args:
            sequence: Sequence number; may be negative.

        Returns:
            bytearray: The sequence as 4 signed big-endian bytes.
        """
        return bytearray(sequence.to_bytes(4, 'big', signed=True))

    @staticmethod
    def _build_frame(header: bytearray, sequence: int, payload_bytes: bytes) -> bytearray:
        """Concatenate header, sequence number, payload length and payload."""
        frame = bytearray(header)
        frame.extend(DoubaoProtocol.generate_sequence_payload(sequence))
        frame.extend(len(payload_bytes).to_bytes(4, 'big'))
        frame.extend(payload_bytes)
        return frame

    @staticmethod
    def parse_response(response_data: bytes) -> Dict[str, Any]:
        """
        Parse one binary frame received from the server.

        Args:
            response_data: Raw bytes of the frame.

        Returns:
            Dict: Header fields plus ``is_last_package``, ``payload_msg`` and
                ``payload_size``.  Depending on the frame it may also hold
                ``payload_sequence``, ``seq`` or ``code``.  When the payload
                cannot be decoded, ``payload_msg`` stays ``None`` and one of
                ``decompress_error``, ``json_parse_error`` or ``decode_error``
                describes the failure.

        Raises:
            ValueError: If the data is shorter than the 4-byte header or the
                header declares a size of zero.
        """
        if len(response_data) < 4:
            raise ValueError("响应数据长度不足")

        # Each header byte packs two 4-bit fields.
        protocol_version = response_data[0] >> 4
        header_size = response_data[0] & 0x0f
        message_type = response_data[1] >> 4
        message_type_specific_flags = response_data[1] & 0x0f
        serialization_method = response_data[2] >> 4
        message_compression = response_data[2] & 0x0f

        if header_size < 1:
            # With a zero size the slice below would start at offset 0 and
            # the header bytes themselves would be treated as payload.
            raise ValueError("响应头部长度无效")

        # Everything after the (possibly extended) header is payload.
        payload = response_data[header_size * 4:]

        result = {
            'protocol_version': protocol_version,
            'header_size': header_size,
            'message_type': message_type,
            'message_type_specific_flags': message_type_specific_flags,
            'serialization_method': serialization_method,
            'message_compression': message_compression,
            'is_last_package': False,
            'payload_msg': None,
            'payload_size': 0
        }

        # Bit 0: a 4-byte sequence number precedes the rest of the payload.
        if message_type_specific_flags & 0x01:
            if len(payload) >= 4:
                result['payload_sequence'] = int.from_bytes(payload[:4], "big", signed=True)
                payload = payload[4:]

        # Bit 1: this is the last package.
        if message_type_specific_flags & 0x02:
            result['is_last_package'] = True

        payload_msg = None
        payload_size = 0

        if message_type == MessageType.FULL_SERVER_RESPONSE:
            if len(payload) >= 4:
                # Sizes are read as unsigned in every branch.
                payload_size = int.from_bytes(payload[:4], "big", signed=False)
                payload_msg = payload[4:]
        elif message_type == MessageType.SERVER_ACK:
            if len(payload) >= 4:
                result['seq'] = int.from_bytes(payload[:4], "big", signed=True)
                if len(payload) >= 8:
                    payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                    payload_msg = payload[8:]
        elif message_type == MessageType.SERVER_ERROR_RESPONSE:
            if len(payload) >= 8:
                result['code'] = int.from_bytes(payload[:4], "big", signed=False)
                payload_size = int.from_bytes(payload[4:8], "big", signed=False)
                payload_msg = payload[8:]

        # Recorded before decoding so it is reported even if decoding fails.
        result['payload_size'] = payload_size

        if payload_msg is None:
            return result

        if message_compression == CompressionType.GZIP:
            try:
                payload_msg = gzip.decompress(payload_msg)
            except Exception as e:
                result['decompress_error'] = str(e)
                return result

        if serialization_method == SerializationMethod.JSON:
            try:
                payload_msg = json.loads(payload_msg.decode('utf-8'))
            except Exception as e:
                result['json_parse_error'] = str(e)
                return result
        elif serialization_method != SerializationMethod.NO_SERIALIZATION:
            try:
                payload_msg = payload_msg.decode('utf-8')
            except UnicodeDecodeError as e:
                # Reported like the other decode failures instead of raising.
                result['decode_error'] = str(e)
                return result

        result['payload_msg'] = payload_msg
        return result

    @staticmethod
    def build_full_request(
        request_params: Dict[str, Any],
        sequence: int = 1,
        compression: bool = True
    ) -> bytearray:
        """
        Build a full client request frame carrying JSON parameters.

        Args:
            request_params: Request parameters, serialised as JSON.
            sequence: Sequence number written after the header.
            compression: Whether to gzip-compress the JSON payload.

        Returns:
            bytearray: The complete request frame.
        """
        payload_bytes = json.dumps(request_params).encode('utf-8')

        compression_type = CompressionType.GZIP if compression else CompressionType.NO_COMPRESSION
        if compression:
            payload_bytes = gzip.compress(payload_bytes)

        header = DoubaoProtocol.generate_header(
            message_type=MessageType.FULL_CLIENT_REQUEST,
            message_type_specific_flags=MessageFlags.POS_SEQUENCE,
            compression_type=compression_type
        )

        return DoubaoProtocol._build_frame(header, sequence, payload_bytes)

    @staticmethod
    def build_audio_request(
        audio_data: bytes,
        sequence: int,
        is_last: bool = False,
        compression: bool = True
    ) -> bytearray:
        """
        Build an audio-only request frame.

        Args:
            audio_data: Audio bytes for this frame.
            sequence: Sequence number of the frame.
            is_last: Whether this is the final audio frame.  The final frame
                uses the ``NEG_WITH_SEQUENCE`` flag and a negated sequence
                number.
            compression: Whether to gzip-compress the audio bytes.

        Returns:
            bytearray: The complete audio request frame.
        """
        compression_type = CompressionType.GZIP if compression else CompressionType.NO_COMPRESSION
        payload_bytes = gzip.compress(audio_data) if compression else audio_data

        if is_last:
            flags = MessageFlags.NEG_WITH_SEQUENCE
            sequence = -abs(sequence)
        else:
            flags = MessageFlags.POS_SEQUENCE

        header = DoubaoProtocol.generate_header(
            message_type=MessageType.AUDIO_ONLY_REQUEST,
            message_type_specific_flags=flags,
            compression_type=compression_type
        )

        return DoubaoProtocol._build_frame(header, sequence, payload_bytes)
asr/doubao/result_processor.py
0 → 100644
| 1 | +# AIfeng/2025-07-17 13:58:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR识别结果处理器 | ||
| 4 | +专门处理豆包ASR流式识别结果,提取关键信息并优化日志输出 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import logging | ||
| 8 | +from typing import Dict, Any, Optional, Callable | ||
| 9 | +from dataclasses import dataclass | ||
| 10 | +from datetime import datetime | ||
| 11 | + | ||
| 12 | + | ||
@dataclass
class ASRResult:
    """A single recognition result extracted from a Doubao ASR response."""
    text: str
    is_final: bool
    confidence: float = 0.0
    # Filled with the creation time in __post_init__ when not supplied.
    timestamp: Optional[datetime] = None
    sequence: int = 0

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.now()


class DoubaoResultProcessor:
    """Extracts text from Doubao ASR streaming results and tracks the latest text."""

    def __init__(self,
                 text_only: bool = True,
                 log_level: str = 'INFO',
                 enable_streaming_log: bool = False):
        """
        Initialize the result processor.

        Args:
            text_only: Stored on the instance; not consulted by this class itself.
            log_level: Name of a ``logging`` level (case-insensitive).
            enable_streaming_log: Whether every intermediate (non-final)
                result is logged; this can be very chatty.
        """
        self.text_only = text_only
        self.enable_streaming_log = enable_streaming_log
        self.logger = self._setup_logger(log_level)

        # Streaming state: latest text, its sequence number, and how many
        # non-empty results have been processed.
        self.current_text = ""
        self.last_sequence = 0
        self.result_count = 0

    def _setup_logger(self, log_level: str) -> logging.Logger:
        """Create the per-instance logger with a stream handler.

        NOTE(review): the logger name embeds ``id(self)``, so every processor
        gets its own named logger; the ``logging`` module keeps named loggers
        registered for the life of the process.
        """
        logger = logging.getLogger(f"DoubaoResultProcessor_{id(self)}")
        logger.setLevel(getattr(logging, log_level.upper()))

        # Guard against stacking handlers if a logger with this name exists.
        if not logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        return logger

    def extract_text_from_result(self, result: Dict[str, Any]) -> Optional[ASRResult]:
        """
        Extract the recognized text from a full Doubao ASR result.

        Args:
            result: Parsed response dictionary; the text is read from
                ``result['payload_msg']['result']['text']``.

        Returns:
            ASRResult: The extracted result, or None when the response
            carries no usable text.
        """
        try:
            payload_msg = result.get('payload_msg')
            # The payload may be absent or a non-dict value; neither
            # carries a recognition result.
            if not isinstance(payload_msg, dict):
                return None

            result_data = payload_msg.get('result')
            # Presumably some API versions return `result` as a list of
            # result objects; use the first entry in that case.
            if isinstance(result_data, list):
                result_data = result_data[0] if result_data else None
            if not isinstance(result_data, dict) or not result_data:
                return None

            text = result_data.get('text', '')
            if not isinstance(text, str):
                return None
            text = text.strip()
            if not text:
                return None

            return ASRResult(
                text=text,
                is_final=result.get('is_last_package', False),
                confidence=result_data.get('confidence', 0.0),
                sequence=result.get('payload_sequence', 0)
            )

        except Exception as e:
            # Best effort: a malformed response must not break the stream.
            self.logger.error(f"提取ASR结果文本失败: {e}")
            return None

    def process_streaming_result(self, result: Dict[str, Any]) -> Optional[str]:
        """
        Process one streaming recognition result.

        Each result replaces the previously stored text.

        Args:
            result: Parsed response dictionary from Doubao ASR.

        Returns:
            str: The current text, or None when there is no text or the
            text is unchanged since the previous result.
        """
        asr_result = self.extract_text_from_result(result)
        if not asr_result:
            return None

        self.result_count += 1

        # Streaming semantics: the newest result overwrites the previous one.
        previous_text = self.current_text
        self.current_text = asr_result.text
        self.last_sequence = asr_result.sequence

        if asr_result.is_final:
            self.logger.info(f"[最终结果] {asr_result.text}")
        elif self.enable_streaming_log:
            # Logged at INFO: the flag is an explicit opt-in, and DEBUG would
            # be filtered out by the default 'INFO' logger level.
            self.logger.info(f"[流式更新 #{self.result_count}] {asr_result.text}")

        return asr_result.text if asr_result.text != previous_text else None

    def create_optimized_callback(self,
                                  user_callback: Optional[Callable[[str], None]] = None) -> Callable[[Dict[str, Any]], None]:
        """
        Build a result callback that forwards only changed text.

        Args:
            user_callback: Optional function called with the new text.

        Returns:
            Callable: A callback accepting the full result dictionary.
        """
        def optimized_callback(result: Dict[str, Any]):
            """Process the result and forward changed text to the user callback."""
            try:
                text = self.process_streaming_result(result)

                if text and user_callback:
                    user_callback(text)

            except Exception as e:
                # Errors in the user callback are logged, not propagated.
                self.logger.error(f"处理ASR回调失败: {e}")

        return optimized_callback

    def get_current_result(self) -> Dict[str, Any]:
        """
        Return a snapshot of the current recognition state.

        Returns:
            Dict: ``current_text``, ``last_sequence``, ``result_count`` and
            ``text_length``.
        """
        return {
            'current_text': self.current_text,
            'last_sequence': self.last_sequence,
            'result_count': self.result_count,
            'text_length': len(self.current_text)
        }

    def reset(self):
        """Clear the stored text, sequence number and result counter."""
        self.current_text = ""
        self.last_sequence = 0
        self.result_count = 0
        self.logger.info("结果处理器已重置")
| 184 | + | ||
| 185 | + | ||
| 186 | +# 便捷函数 | ||
def create_text_only_callback(user_callback: Optional[Callable[[str], None]] = None,
                              enable_streaming_log: bool = False) -> Callable[[Dict[str, Any]], None]:
    """
    Build a text-only result callback backed by a new processor.

    A fresh ``DoubaoResultProcessor`` is created on every call.

    Args:
        user_callback: Optional function called with the recognized text.
        enable_streaming_log: Whether intermediate results are logged.

    Returns:
        Callable: A callback accepting the full result dictionary.
    """
    return DoubaoResultProcessor(
        text_only=True,
        enable_streaming_log=enable_streaming_log
    ).create_optimized_callback(user_callback)
| 204 | + | ||
| 205 | + | ||
def extract_text_only(result: Dict[str, Any]) -> Optional[str]:
    """
    Return just the recognized text from a Doubao ASR result.

    Args:
        result: Full result dictionary; the text is read from
            ``result['payload_msg']['result']['text']``.

    Returns:
        str: The stripped text, or None when it is missing, blank, or the
        result cannot be traversed.
    """
    try:
        payload = result.get('payload_msg', {})
        recognition = payload.get('result', {})
        stripped = recognition.get('text', '').strip()
    except Exception:
        # Any malformed structure is treated as "no text".
        return None

    if stripped:
        return stripped
    return None
asr/doubao/service_factory.py
0 → 100644
| 1 | +# AIfeng/2025-07-11 13:36:00 | ||
| 2 | +""" | ||
| 3 | +豆包ASR服务工厂模块 | ||
| 4 | +提供简化的API接口和服务实例管理 | ||
| 5 | +""" | ||
| 6 | + | ||
| 7 | +import asyncio | ||
| 8 | +from pathlib import Path | ||
| 9 | +from typing import Dict, Any, Optional, Callable, Union | ||
| 10 | + | ||
| 11 | +from .config_manager import ConfigManager | ||
| 12 | +from .asr_client import DoubaoASRClient | ||
| 13 | + | ||
| 14 | + | ||
class DoubaoASRService:
    """Doubao ASR service wrapper with a registry of named instances."""

    # Registry of instances created through create_service(), keyed by name.
    _instances = {}

    def __init__(self, config: Union[str, Path, Dict[str, Any], "ConfigManager"]):
        """
        Initialize the ASR service.

        Args:
            config: Path to a configuration file (``str`` or ``Path``), a
                configuration dictionary, or a ``ConfigManager`` instance.

        Raises:
            ValueError: If ``config`` is none of the supported types.
        """
        if isinstance(config, (str, Path)):
            self.config_manager = ConfigManager(str(config))
        elif isinstance(config, dict):
            self.config_manager = ConfigManager.from_dict(config)
        elif isinstance(config, ConfigManager):
            self.config_manager = config
        else:
            raise ValueError("配置参数类型错误")

        self.client = DoubaoASRClient(self.config_manager.get_config())

    async def recognize_file(
        self,
        audio_path: str,
        streaming: bool = True,
        result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Recognize an audio file.

        Args:
            audio_path: Path of the audio file.
            streaming: Whether to use streaming recognition.
            result_callback: Optional callback receiving result dictionaries.
            **kwargs: Forwarded unchanged to the client.

        Returns:
            Dict: The recognition result returned by the client.
        """
        return await self.client.recognize_file(
            audio_path,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )

    async def recognize_audio_data(
        self,
        audio_data: bytes,
        streaming: bool = True,
        result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Recognize in-memory audio data.

        Args:
            audio_data: Audio bytes.
            streaming: Whether to use streaming recognition.
            result_callback: Optional callback receiving result dictionaries.
            **kwargs: Forwarded unchanged to the client.

        Returns:
            Dict: The recognition result returned by the client.
        """
        return await self.client.recognize_audio_data(
            audio_data,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )

    def get_status(self) -> Dict[str, Any]:
        """
        Return the client's status.

        Returns:
            Dict: Whatever ``client.get_status()`` reports.
        """
        return self.client.get_status()

    async def close(self):
        """Close the underlying client."""
        await self.client.close()

    @classmethod
    def create_service(
        cls,
        config: Union[str, Path, Dict[str, Any], "ConfigManager"],
        instance_name: str = 'default'
    ) -> 'DoubaoASRService':
        """
        Create a named service instance, or return the existing one.

        When an instance with ``instance_name`` already exists it is returned
        as-is and ``config`` is ignored.

        Args:
            config: Configuration used only when a new instance is created.
            instance_name: Registry key of the instance.

        Returns:
            DoubaoASRService: The registered service instance.
        """
        if instance_name not in cls._instances:
            cls._instances[instance_name] = cls(config)
        return cls._instances[instance_name]

    @classmethod
    def get_service(cls, instance_name: str = 'default') -> Optional['DoubaoASRService']:
        """
        Look up a previously created service instance.

        Args:
            instance_name: Registry key of the instance.

        Returns:
            DoubaoASRService: The instance, or None if it was never created.
        """
        return cls._instances.get(instance_name)

    @classmethod
    async def close_all_services(cls):
        """
        Close every registered service and empty the registry.

        All services are closed even if some of them fail; the first error
        is re-raised after the remaining services have been closed.
        """
        # Snapshot and clear first so the registry never keeps references to
        # services that are closed or being closed.
        services = list(cls._instances.values())
        cls._instances.clear()

        first_error = None
        for service in services:
            try:
                await service.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc

        if first_error is not None:
            raise first_error
| 142 | + | ||
| 143 | + | ||
| 144 | +# 便捷函数 | ||
def create_asr_service(
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    **kwargs
) -> DoubaoASRService:
    """
    Convenience constructor for an ASR service.

    When ``config_path`` is given it wins and every other argument is
    ignored. Otherwise a configuration dictionary is built from the
    credentials and the recognized keyword options; unrecognized keyword
    options are ignored.

    Args:
        config_path: Path to a configuration file.
        app_key: Application key (empty string when omitted).
        access_key: Access key (empty string when omitted).
        **kwargs: Optional settings: ``streaming``, ``seg_duration``,
            ``model_name``, ``enable_punc``, ``sample_rate``,
            ``audio_format``, ``timeout``, ``debug``.

    Returns:
        DoubaoASRService: The new service instance.
    """
    if config_path:
        return DoubaoASRService(config_path)

    settings = {
        'auth_config': {
            'app_key': app_key or '',
            'access_key': access_key or ''
        }
    }

    if kwargs:
        # All four sections are created as soon as any option is passed,
        # even if none of the options maps into a given section.
        for section_name in ('asr_config', 'audio_config',
                             'connection_config', 'logging_config'):
            settings.setdefault(section_name, {})

        # (keyword option, target section, key inside that section)
        option_table = (
            ('streaming', 'asr_config', 'streaming_mode'),
            ('seg_duration', 'asr_config', 'seg_duration'),
            ('model_name', 'asr_config', 'model_name'),
            ('enable_punc', 'asr_config', 'enable_punc'),
            ('sample_rate', 'audio_config', 'default_rate'),
            ('audio_format', 'audio_config', 'default_format'),
            ('timeout', 'connection_config', 'timeout'),
            ('debug', 'logging_config', 'enable_debug'),
        )
        for option, section_name, key in option_table:
            if option in kwargs:
                settings[section_name][key] = kwargs[option]

    return DoubaoASRService(settings)
| 202 | + | ||
| 203 | + | ||
async def recognize_file(
    audio_path: str,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    One-shot helper: create a service, recognize a file, close the service.

    Args:
        audio_path: Path of the audio file.
        config_path: Path to a configuration file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback receiving result dictionaries.
        **kwargs: Passed to ``create_asr_service`` (service configuration),
            not to the recognition call.

    Returns:
        Dict: The recognition result.
    """
    asr = create_asr_service(
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        **kwargs
    )

    # The service is closed whether recognition succeeds or raises.
    try:
        outcome = await asr.recognize_file(
            audio_path,
            streaming=streaming,
            result_callback=result_callback
        )
        return outcome
    finally:
        await asr.close()
| 243 | + | ||
| 244 | + | ||
async def recognize_audio_data(
    audio_data: bytes,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    One-shot helper: create a service, recognize audio bytes, close the service.

    Args:
        audio_data: Audio bytes.
        config_path: Path to a configuration file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback receiving result dictionaries.
        **kwargs: Passed to ``create_asr_service`` (service configuration),
            not to the recognition call.

    Returns:
        Dict: The recognition result.
    """
    asr = create_asr_service(
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        **kwargs
    )

    # The service is closed whether recognition succeeds or raises.
    try:
        outcome = await asr.recognize_audio_data(
            audio_data,
            streaming=streaming,
            result_callback=result_callback
        )
        return outcome
    finally:
        await asr.close()
| 284 | + | ||
| 285 | + | ||
def run_recognition(
    audio_path: str,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Synchronous wrapper around the module-level ``recognize_file`` coroutine.

    The coroutine is driven with ``asyncio.run``.

    Args:
        audio_path: Path of the audio file.
        config_path: Path to a configuration file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback receiving result dictionaries.
        **kwargs: Forwarded to ``recognize_file``.

    Returns:
        Dict: The recognition result.
    """
    job = recognize_file(
        audio_path,
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        streaming=streaming,
        result_callback=result_callback,
        **kwargs
    )
    return asyncio.run(job)
This diff could not be displayed because it is too large.
funasr_asr_sync.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | +""" | ||
| 3 | +AIfeng/2025-01-02 10:27:06 | ||
| 4 | +FunASR语音识别模块 - 同步版本 | ||
| 5 | +基于eman-Fay-main-copy项目的同步实现模式 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +from threading import Thread | ||
| 9 | +import websocket | ||
| 10 | +import json | ||
| 11 | +import time | ||
| 12 | +import ssl | ||
| 13 | +import _thread as thread | ||
| 14 | +import os | ||
| 15 | +import asyncio | ||
| 16 | +import threading | ||
| 17 | + | ||
| 18 | +from core import get_web_instance, get_instance | ||
| 19 | +from utils import config_util as cfg | ||
| 20 | +from utils import util | ||
| 21 | + | ||
| 22 | +class FunASRSync: | ||
| 23 | + """FunASR同步客户端 - 基于参考项目实现""" | ||
| 24 | + | ||
def __init__(self, username):
    """Set up client state for ``username``; no connection is opened here."""
    # Server address comes from the project configuration.
    self.__URL = f"ws://{cfg.local_asr_ip}:{cfg.local_asr_port}"

    # Connection state.
    self.__ws = None
    self.__connected = False
    self.__frames = []  # outgoing frames waiting to be sent
    self.__state = 0
    self.__closing = False
    self.__task_id = ''

    # Latest recognition outcome.
    self.done = False
    self.finalResults = ""

    # Reconnect bookkeeping.
    self.__reconnect_delay = 1
    self.__reconnecting = False

    self.username = username
    self.started = True
    self.__result_callback = None  # invoked with each recognition result

    util.log(1, f"FunASR同步客户端初始化完成,用户: {username}")
| 42 | + | ||
def on_message(self, ws, message):
    """Handle a message received from the FunASR WebSocket.

    Messages are classified in this order:
      1. JSON objects with a ``status`` of ``ready``, ``processing``,
         ``chunk_received`` or ``error`` are logged and ignored.
      2. JSON objects with a ``text`` key are structured results; a
         non-blank string is stored and reported through the result
         callback, anything else is ignored.
      3. Any other non-blank string is treated as a plain-text result,
         unless it contains one of the status keywords.

    Args:
        ws: The WebSocket the message arrived on (unused).
        message: Raw message payload.
    """
    try:
        util.log(1, f"收到FunASR消息: {message}")

        # Try JSON first to tell status messages from recognition results.
        try:
            parsed_message = json.loads(message)
        except json.JSONDecodeError:
            # Not JSON: possibly a plain-text recognition result.
            parsed_message = None

        if isinstance(parsed_message, dict):
            if 'status' in parsed_message:
                status = parsed_message.get('status')
                if status == 'ready':
                    util.log(1, f"收到分块准备状态: {parsed_message.get('message', '')}")
                    return  # status messages never trigger the callback
                if status in ('processing', 'chunk_received'):
                    util.log(1, f"收到处理状态: {status}")
                    return
                if status == 'error':
                    util.log(3, f"收到错误状态: {parsed_message.get('message', '')}")
                    return

            if 'text' in parsed_message:
                recognition_text = parsed_message.get('text', '')
                # Only non-blank string results are reported.
                if isinstance(recognition_text, str) and recognition_text.strip():
                    self.done = True
                    self.finalResults = recognition_text
                    util.log(1, f"收到结构化识别结果: {recognition_text}")
                    self._trigger_result_callback()
                # A structured message is never re-read as plain text, so a
                # blank result cannot surface as its own raw JSON string.
                return

        # Plain-text recognition result.
        if isinstance(message, str) and message.strip():
            # Heuristic filter for status-like text.
            if any(keyword in message.lower() for keyword in ['status', 'ready', '准备接收', 'processing', 'chunk']):
                util.log(1, f"跳过状态消息: {message}")
                return

            self.done = True
            self.finalResults = message
            util.log(1, f"收到文本识别结果: {message}")
            self._trigger_result_callback()

    except Exception as e:
        util.log(3, f"处理识别结果时出错: {e}")

    # A pending close request is honoured once a message has been handled.
    if self.__closing:
        try:
            self.__ws.close()
        except Exception as e:
            util.log(2, f"关闭WebSocket时出错: {e}")
| 102 | + | ||
def _trigger_result_callback(self):
    """Report ``finalResults`` through the result callback, if one is set.

    The callback receives a ``chat_message`` dictionary. Errors raised by
    the callback are logged and swallowed.
    """
    # NOTE: an earlier direct push to the web client (via get_web_instance)
    # is disabled; results are delivered only through the callback.
    notify = self.__result_callback
    if notify:
        try:
            notify({
                "type":"chat_message",
                "sender":"回音",
                "text": self.finalResults,
                "Username": self.username,
                "model_info":"Funasr"
            })
            util.log(1, f"已触发结果回调: {self.finalResults}")
        except Exception as e:
            util.log(3, f"调用结果回调时出错: {e}")

    # Honour a pending close request after the result has been delivered.
    if self.__closing:
        try:
            self.__ws.close()
        except Exception as e:
            util.log(2, f"关闭WebSocket时出错: {e}")
| 154 | + | ||
def on_close(self, ws, code, msg):
    """WebSocket close handler: mark the client disconnected and drop the socket."""
    self.__connected = False
    util.log(2, f"FunASR连接关闭: {msg}")
    # Forget the closed socket so later sends see "not connected".
    self.__ws = None
| 160 | + | ||
def on_error(self, ws, error):
    """WebSocket error handler: mark the client disconnected and drop the socket."""
    self.__connected = False
    util.log(3, f"FunASR连接错误: {error}")
    # Forget the failed socket so later sends see "not connected".
    self.__ws = None
| 166 | + | ||
def __attempt_reconnect(self):
    """Retry ``start()`` with exponential back-off until connected.

    Does nothing when a reconnect loop is already running. The delay
    doubles after every attempt up to a fixed ceiling, and the
    reconnect state is reset even when ``start()`` raises, so a failed
    attempt cannot block all future reconnects.
    """
    if self.__reconnecting:
        return

    max_delay = 60  # ceiling for the back-off delay, in seconds

    self.__reconnecting = True
    util.log(1, "尝试重连FunASR...")
    try:
        while not self.__connected:
            time.sleep(self.__reconnect_delay)
            # NOTE(review): start() is defined outside this block;
            # presumably it (re)opens the connection.
            self.start()
            self.__reconnect_delay = min(self.__reconnect_delay * 2, max_delay)
    finally:
        self.__reconnect_delay = 1
        self.__reconnecting = False
| 178 | + | ||
def on_open(self, ws):
    """WebSocket open handler: mark connected and start the sender thread.

    The sender thread drains the frame queue while the client is
    connected: dict frames are sent as JSON text, bytes-like frames as
    binary. Frames of any other type are dropped with a warning.
    """
    self.__connected = True
    util.log(1, "FunASR WebSocket连接建立")

    def run(*args):
        while self.__connected:
            try:
                if self.__frames:
                    frame = self.__frames.pop(0)

                    if isinstance(frame, dict):
                        ws.send(json.dumps(frame))
                    elif isinstance(frame, (bytes, bytearray, memoryview)):
                        # Normalise to bytes so every bytes-like audio
                        # buffer goes out as a binary frame.
                        ws.send(bytes(frame), websocket.ABNF.OPCODE_BINARY)
                    else:
                        util.log(2, f"丢弃不支持的帧类型: {type(frame).__name__}")

            except Exception as e:
                util.log(3, f"发送帧数据时出错: {e}")
            # Pause between loop iterations, whether or not a frame was sent.
            time.sleep(0.02)

    thread.start_new_thread(run, ())
| 202 | + | ||
def __connect(self):
    """Reset per-connection state, then open the WebSocket and run it.

    ``run_forever`` is called here, so this method does not return until
    that call returns.
    """
    # Start every connection with a clean result and an empty send queue.
    self.done = False
    self.finalResults = ""
    self.__frames.clear()
    websocket.enableTrace(False)

    handlers = {
        'on_message': self.on_message,
        'on_close': self.on_close,
        'on_error': self.on_error,
    }
    self.__ws = websocket.WebSocketApp(self.__URL, **handlers)
    self.__ws.on_open = self.on_open
    self.__ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
| 218 | + | ||
def add_frame(self, frame):
    """Append ``frame`` to the outgoing frame queue."""
    self.__frames.append(frame)
| 222 | + | ||
def send(self, buf):
    """Queue an audio buffer for sending; nothing is transmitted here."""
    self.__frames.append(buf)
| 226 | + | ||
def send_url(self, url):
    """Send the absolute path of an audio file to FunASR.

    The path is made absolute first (the original author notes that
    relative paths do not work for the FunASR service). If the socket is
    not connected, the request is logged and dropped.
    """
    absolute_url = os.path.abspath(url)

    if not (self.__ws and self.__connected):
        util.log(2, f"WebSocket未连接,无法发送URL: {absolute_url}")
        return

    frame = {'url': absolute_url}
    util.log(1, f"发送音频文件URL到FunASR: {absolute_url}")
    self.__ws.send(json.dumps(frame))
    util.log(1, f"音频文件URL已发送: {frame}")
| 238 | + | ||
def send_audio_data(self, audio_bytes, filename="audio.wav"):
    """Send audio data, switching to chunked transfer for large payloads.

    Args:
        audio_bytes: Audio payload; a memoryview or any object exposing
            ``tobytes()`` is converted to ``bytes`` first.
        filename: File name reported to the server.

    Returns:
        bool: The result of the selected send routine, or False when an
        error occurs.
    """
    try:
        # Copy out of buffer-exporting objects up front; the original code
        # did this to avoid "BufferError: memoryview has 1 exported buffer".
        if isinstance(audio_bytes, memoryview):
            audio_bytes = audio_bytes.tobytes()
        elif hasattr(audio_bytes, 'tobytes'):
            audio_bytes = bytes(audio_bytes.tobytes())

        total_size = len(audio_bytes)

        # Chunking threshold: 512 KiB. Per the original author's note, this
        # stays below a 1 MB aiohttp default limit once Base64 adds ~33%.
        large_file_threshold = 512 * 1024

        if total_size > large_file_threshold:
            util.log(1, f"检测到大文件({total_size} bytes),使用分块发送模式")
            return self._send_audio_data_chunked(audio_bytes, filename)

        # Small payloads go out in a single frame.
        return self._send_audio_data_simple(audio_bytes, filename)

    except Exception as e:
        util.log(3, f"发送音频数据时出错: {e}")
        return False
| 266 | + | ||
def _send_audio_data_simple(self, audio_bytes, filename):
    """Send a small audio payload as a single Base64-encoded JSON frame.

    Returns:
        bool: True when the frame was sent, False when the socket is not
        connected, sending failed, or an error occurred.
    """
    import base64

    try:
        # Frame layout: Base64 audio plus the file name. The original
        # comment says it matches the FunASR service's process_audio_data.
        frame = {
            'audio_data': base64.b64encode(audio_bytes).decode('utf-8'),
            'filename': filename
        }

        if not (self.__ws and self.__connected):
            util.log(2, f"WebSocket未连接,无法发送音频数据: {filename}")
            return False

        util.log(1, f"发送音频数据到FunASR: {filename}, 大小: {len(audio_bytes)} bytes")
        if self._send_frame_with_retry(frame):
            util.log(1, f"音频数据已发送: {filename}")
            return True

        util.log(3, f"音频数据发送失败: {filename}")
        return False

    except Exception as e:
        util.log(3, f"简单发送音频数据时出错: {e}")
        return False
| 297 | + | ||
def _send_audio_data_chunked(self, audio_bytes, filename, chunk_size=512*1024):
    """Send a large audio payload as a sequence of Base64-encoded chunks.

    Frames are sent in this order: one ``audio_start`` frame carrying the
    totals, one ``audio_chunk`` frame per chunk, then one ``audio_end``
    frame. The transfer stops at the first frame that fails to send.

    Args:
        audio_bytes: The complete audio payload.
        filename: File name reported in every frame.
        chunk_size: Number of raw bytes per chunk, before Base64 encoding.

    Returns:
        bool: True when every frame was sent, False otherwise.
    """
    import base64
    import math

    try:
        total_size = len(audio_bytes)
        total_chunks = math.ceil(total_size / chunk_size)

        util.log(1, f"开始分块发送: {filename}, 总大小: {total_size} bytes, 分块数: {total_chunks}")

        # Announce the transfer so the server knows what to expect.
        announced = self._send_frame_with_retry({
            'type': 'audio_start',
            'filename': filename,
            'total_size': total_size,
            'total_chunks': total_chunks,
            'chunk_size': chunk_size
        })
        if not announced:
            util.log(3, f"发送开始信号失败: {filename}")
            return False

        final_index = total_chunks - 1
        for index in range(total_chunks):
            offset = index * chunk_size
            piece = audio_bytes[offset:min(offset + chunk_size, total_size)]

            delivered = self._send_frame_with_retry({
                'type': 'audio_chunk',
                'filename': filename,
                'chunk_index': index,
                'chunk_data': base64.b64encode(piece).decode('utf-8'),
                'is_last': (index == final_index)
            })
            sent_count = index + 1
            if not delivered:
                util.log(3, f"分块 {sent_count}/{total_chunks} 发送失败")
                return False

            # Report progress every 10 chunks and on the final chunk.
            if sent_count % 10 == 0 or index == final_index:
                progress = (sent_count / total_chunks) * 100
                util.log(1, f"发送进度: {progress:.1f}% ({sent_count}/{total_chunks})")

            # Short pause between chunks as flow control.
            time.sleep(0.01)

        # Tell the server that all chunks have been sent.
        if self._send_frame_with_retry({'type': 'audio_end', 'filename': filename}):
            util.log(1, f"音频数据分块发送完成: {filename}")
            return True

        util.log(3, f"发送结束信号失败: {filename}")
        return False

    except Exception as e:
        util.log(3, f"分块发送音频数据时出错: {e}")
        return False
| 369 | + | ||
| 370 | + def _send_frame_with_retry(self, frame, max_retries=3, timeout=10): | ||
| 371 | + """带重试的帧发送""" | ||
| 372 | + for attempt in range(max_retries): | ||
| 373 | + try: | ||
| 374 | + if self.__ws and self.__connected: | ||
| 375 | + # 设置发送超时 | ||
| 376 | + start_time = time.time() | ||
| 377 | + self.__ws.send(json.dumps(frame)) | ||
| 378 | + | ||
| 379 | + # 简单的发送确认检查 | ||
| 380 | + time.sleep(0.05) # 等待发送完成 | ||
| 381 | + | ||
| 382 | + if time.time() - start_time < timeout: | ||
| 383 | + return True | ||
| 384 | + else: | ||
| 385 | + util.log(2, f"发送超时,尝试 {attempt + 1}/{max_retries}") | ||
| 386 | + else: | ||
| 387 | + util.log(2, f"连接不可用,尝试 {attempt + 1}/{max_retries}") | ||
| 388 | + | ||
| 389 | + except Exception as e: | ||
| 390 | + util.log(2, f"发送失败,尝试 {attempt + 1}/{max_retries}: {e}") | ||
| 391 | + | ||
| 392 | + if attempt < max_retries - 1: | ||
| 393 | + time.sleep(0.5 * (attempt + 1)) # 指数退避 | ||
| 394 | + | ||
| 395 | + return False | ||
| 396 | + | ||
| 397 | + def set_result_callback(self, callback): | ||
| 398 | + """设置结果回调函数""" | ||
| 399 | + self.__result_callback = callback | ||
| 400 | + util.log(1, f"已设置结果回调函数") | ||
| 401 | + | ||
| 402 | + | ||
| 403 | + def connect(self): | ||
| 404 | + """连接到FunASR服务(同步版本)""" | ||
| 405 | + try: | ||
| 406 | + if not self.__connected: | ||
| 407 | + self.start() # 调用现有的start方法 | ||
| 408 | + | ||
| 409 | + # 等待连接建立,最多等待30秒(针对大文件处理优化) | ||
| 410 | + max_wait_time = 30.0 | ||
| 411 | + wait_interval = 0.1 | ||
| 412 | + waited_time = 0.0 | ||
| 413 | + | ||
| 414 | + while not self.__connected and waited_time < max_wait_time: | ||
| 415 | + time.sleep(wait_interval) | ||
| 416 | + waited_time += wait_interval | ||
| 417 | + | ||
| 418 | + # 每5秒输出一次等待日志 | ||
| 419 | + if waited_time % 5.0 < wait_interval: | ||
| 420 | + util.log(1, f"等待FunASR连接中... {waited_time:.1f}s/{max_wait_time}s") | ||
| 421 | + | ||
| 422 | + if self.__connected: | ||
| 423 | + util.log(1, f"FunASR连接成功,耗时: {waited_time:.2f}秒") | ||
| 424 | + else: | ||
| 425 | + util.log(3, f"FunASR连接超时,等待了{waited_time:.2f}秒") | ||
| 426 | + | ||
| 427 | + return self.__connected | ||
| 428 | + return True | ||
| 429 | + except Exception as e: | ||
| 430 | + util.log(3, f"连接FunASR服务时出错: {e}") | ||
| 431 | + return False | ||
| 432 | + | ||
| 433 | + def start(self): | ||
| 434 | + """启动FunASR客户端""" | ||
| 435 | + Thread(target=self.__connect, args=[]).start() | ||
| 436 | + data = { | ||
| 437 | + 'vad_need': False, | ||
| 438 | + 'state': 'StartTranscription' | ||
| 439 | + } | ||
| 440 | + self.add_frame(data) | ||
| 441 | + util.log(1, "FunASR客户端启动") | ||
| 442 | + | ||
| 443 | + def is_connected(self): | ||
| 444 | + """检查连接状态""" | ||
| 445 | + return self.__connected | ||
| 446 | + | ||
| 447 | + def end(self): | ||
| 448 | + """结束FunASR客户端""" | ||
| 449 | + if self.__connected: | ||
| 450 | + try: | ||
| 451 | + # 发送剩余帧 | ||
| 452 | + for frame in self.__frames: | ||
| 453 | + self.__frames.pop(0) | ||
| 454 | + if type(frame) == dict: | ||
| 455 | + self.__ws.send(json.dumps(frame)) | ||
| 456 | + elif type(frame) == bytes: | ||
| 457 | + self.__ws.send(frame, websocket.ABNF.OPCODE_BINARY) | ||
| 458 | + | ||
| 459 | + self.__frames.clear() | ||
| 460 | + | ||
| 461 | + # 发送停止信号 | ||
| 462 | + frame = {'vad_need': False, 'state': 'StopTranscription'} | ||
| 463 | + self.__ws.send(json.dumps(frame)) | ||
| 464 | + | ||
| 465 | + except Exception as e: | ||
| 466 | + util.log(3, f"结束FunASR时出错: {e}") | ||
| 467 | + | ||
| 468 | + self.__closing = True | ||
| 469 | + self.__connected = False | ||
| 470 | + util.log(1, "FunASR客户端结束") | ||
| 471 | + | ||
| 472 | + |
scheduler/__init__.py
0 → 100644
scheduler/thread_manager.py
0 → 100644
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# -*- coding: utf-8 -*- | ||
| 3 | +""" | ||
| 4 | +AIfeng/2025-07-02 11:24:08 | ||
| 5 | +线程管理器 - 提供增强的线程功能 | ||
| 6 | +""" | ||
| 7 | + | ||
| 8 | +import threading | ||
| 9 | +import time | ||
| 10 | +import traceback | ||
| 11 | +from typing import Callable, Any, Optional | ||
| 12 | +from utils.util import log | ||
| 13 | + | ||
| 14 | + | ||
class MyThread(threading.Thread):
    """Thread subclass that records its result, exception and timing."""

    def __init__(self, target: Optional[Callable] = None, name: Optional[str] = None,
                 args: tuple = (), kwargs: Optional[dict] = None, daemon: bool = True):
        """Initialise the thread.

        Args:
            target: Callable executed by ``run``; nothing is executed when None.
            name: Thread name.
            args: Positional arguments for ``target``.
            kwargs: Keyword arguments for ``target``; None means no keywords.
            daemon: Whether the thread is a daemon thread (defaults to True).
        """
        super().__init__(target=target, name=name, args=args, kwargs=kwargs or {}, daemon=daemon)
        # Stored explicitly because run() below reads these attributes.
        self._target = target
        self._args = args
        self._kwargs = kwargs or {}
        self._result = None       # return value of target, once it has finished
        self._exception = None    # exception raised by target, if any
        self._start_time = None   # time.time() when run() began
        self._end_time = None     # time.time() when run() ended
        self._running = False     # True only while run() is executing

    def run(self):
        """Run the target, capturing its return value, exception and timing.

        An ``Exception`` raised by the target is stored and logged instead of
        propagating out of the thread.
        """
        self._start_time = time.time()
        self._running = True

        try:
            if self._target:
                log(1, f"线程 {self.name} 开始执行")
                self._result = self._target(*self._args, **self._kwargs)
                log(1, f"线程 {self.name} 执行完成")
        except Exception as e:
            self._exception = e
            log(3, f"线程 {self.name} 执行出错: {e}")
            log(3, f"错误详情: {traceback.format_exc()}")
        finally:
            self._end_time = time.time()
            self._running = False
            duration = self._end_time - self._start_time
            log(1, f"线程 {self.name} 运行时长: {duration:.2f}秒")

    def get_result(self) -> Any:
        """Return the target's return value.

        Returns:
            The stored result; None if the target returned None or the thread
            has not run.

        Raises:
            RuntimeError: If the thread is still alive.
            Exception: The exception captured from the target, if any.
        """
        if self.is_alive():
            raise RuntimeError("线程仍在运行中")

        if self._exception:
            raise self._exception

        return self._result

    def get_exception(self) -> Optional[Exception]:
        """Return the exception captured during ``run``, or None."""
        return self._exception

    def get_duration(self) -> Optional[float]:
        """Return the run time in seconds.

        Returns:
            None if ``run`` has not begun; the time elapsed so far while
            running; otherwise the total run time.
        """
        if self._start_time is None:
            return None

        end_time = self._end_time or time.time()
        return end_time - self._start_time

    def is_running(self) -> bool:
        """Return True while ``run`` is executing and the thread is alive."""
        return self._running and self.is_alive()

    def stop_gracefully(self, timeout: float = 5.0) -> bool:
        """Wait for the thread to finish on its own.

        No stop signal is sent; this only joins the thread with a timeout.

        Args:
            timeout: Maximum time to wait, in seconds.

        Returns:
            True if the thread is no longer alive, False if it is still running
            after ``timeout`` seconds.
        """
        if not self.is_alive():
            return True

        log(1, f"正在停止线程 {self.name}")

        # Wait for the target to return by itself.
        self.join(timeout=timeout)

        if self.is_alive():
            log(2, f"线程 {self.name} 在 {timeout} 秒内未能自然结束")
            return False

        log(1, f"线程 {self.name} 已成功停止")
        return True

    def __str__(self) -> str:
        """Return a short description with name, status and run time."""
        status = "运行中" if self.is_running() else "已停止"
        duration = self.get_duration()
        # Compare against None so that a 0.0-second duration is still shown.
        duration_str = f", 运行时长: {duration:.2f}秒" if duration is not None else ""
        return f"MyThread(name={self.name}, status={status}{duration_str})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()
| 120 | + | ||
| 121 | + | ||
class ThreadManager:
    """Registry of named ``MyThread`` objects guarded by a lock."""

    def __init__(self):
        self._threads = {}              # thread name -> MyThread
        self._lock = threading.Lock()   # protects self._threads

    def create_thread(self, name: str, target: Callable, args: tuple = (),
                      kwargs: Optional[dict] = None, daemon: bool = True) -> "MyThread":
        """Create and register a new thread without starting it.

        Args:
            name: Unique thread name, used as the registry key.
            target: Callable the thread will execute.
            args: Positional arguments for ``target``.
            kwargs: Keyword arguments for ``target``.
            daemon: Whether the thread is a daemon thread.

        Returns:
            The newly created thread object.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        with self._lock:
            if name in self._threads:
                raise ValueError(f"线程名称 '{name}' 已存在")

            thread = MyThread(target=target, name=name, args=args,
                              kwargs=kwargs or {}, daemon=daemon)
            self._threads[name] = thread
            return thread

    def start_thread(self, name: str) -> bool:
        """Start the registered thread called ``name``.

        Returns:
            True if the thread was started; False if the name is unknown, the
            thread is already alive, or ``start()`` raised.
        """
        with self._lock:
            if name not in self._threads:
                log(3, f"线程 '{name}' 不存在")
                return False

            thread = self._threads[name]
            if thread.is_alive():
                log(2, f"线程 '{name}' 已在运行中")
                return False

            try:
                thread.start()
                log(1, f"线程 '{name}' 启动成功")
                return True
            except Exception as e:
                log(3, f"启动线程 '{name}' 失败: {e}")
                return False

    def stop_thread(self, name: str, timeout: float = 5.0) -> bool:
        """Wait for the thread called ``name`` to finish.

        Args:
            name: Thread name.
            timeout: Maximum time to wait, in seconds.

        Returns:
            True if the thread stopped within ``timeout``; False if it is still
            running or the name is unknown.
        """
        with self._lock:
            thread = self._threads.get(name)

        if thread is None:
            log(3, f"线程 '{name}' 不存在")
            return False

        # Join outside the lock so other manager calls are not blocked for
        # up to ``timeout`` seconds while this thread winds down.
        return thread.stop_gracefully(timeout)

    def stop_all_threads(self, timeout: float = 5.0) -> bool:
        """Wait for every registered thread that is alive to finish.

        Args:
            timeout: Maximum time to wait per thread, in seconds.

        Returns:
            True if every live thread stopped within its timeout.
        """
        log(1, "正在停止所有线程...")

        # Snapshot under the lock, then join without holding it.
        with self._lock:
            threads = list(self._threads.values())

        success = True
        for thread in threads:
            if thread.is_alive() and not thread.stop_gracefully(timeout):
                success = False

        if success:
            log(1, "所有线程已成功停止")
        else:
            log(2, "部分线程未能在指定时间内停止")

        return success

    def get_thread_status(self) -> dict:
        """Return a status snapshot for every registered thread.

        Returns:
            Mapping of thread name to a dict with the keys ``alive``,
            ``running``, ``duration`` and ``exception`` (string or None).
        """
        status = {}
        with self._lock:
            for name, thread in self._threads.items():
                error = thread.get_exception()
                status[name] = {
                    'alive': thread.is_alive(),
                    'running': thread.is_running(),
                    'duration': thread.get_duration(),
                    'exception': str(error) if error else None
                }
        return status

    def cleanup_finished_threads(self):
        """Remove threads that were started and have since finished.

        Threads that were created but never started are kept: ``ident`` is
        None until a thread has been started, so they are not "finished".
        """
        with self._lock:
            finished_threads = [
                name for name, thread in self._threads.items()
                if thread.ident is not None and not thread.is_alive()
            ]

            for name in finished_threads:
                del self._threads[name]
                log(1, f"已清理完成的线程: {name}")

    def __len__(self) -> int:
        """Return the number of registered threads."""
        with self._lock:
            return len(self._threads)

    def __contains__(self, name: str) -> bool:
        """Return True if a thread called ``name`` is registered."""
        with self._lock:
            return name in self._threads
| 258 | + | ||
| 259 | + | ||
# Global thread manager instance, created when this module is imported.
thread_manager = ThreadManager()
-
Please register or login to post a comment