Feng Yang

ASR WebSocket service implementation: 1. FunASR local deployment 2. Doubao

1. Audio-file processing is implemented, including direct conversion for small files and split-and-recognize for large files.
2. Doubao streaming recognition is connected, but the wrapper still needs revision.
3. Thread management
4. Centralized WebSocket management
... ... @@ -3,6 +3,8 @@ build/
*.egg-info/
*.so
*.mp4
*.mp3
*.wav
tmp*
trial*/
... ... @@ -12,7 +14,6 @@ data_utils/face_tracking/3DMM/*
data_utils/face_parsing/79999_iter.pth
pretrained
*.mp4
.DS_Store
workspace/log_ngp.txt
.idea
... ... @@ -21,3 +22,4 @@ models/
*.log
dist
.vscode/launch.json
/speech.wav
... ...
# AIfeng/2025-07-11 13:36:00
# Doubao ASR Speech Recognition Service
A general-purpose ASR service built on the Doubao (Volcano Engine) speech recognition API. It supports streaming and non-streaming recognition and exposes a simple Python interface.
## 🚀 Features
- **Multiple recognition modes**: streaming and non-streaming speech recognition
- **Multiple formats**: WAV, MP3, and PCM audio
- **Flexible configuration**: config file, environment variables, or in-code configuration
- **Async support**: asyncio-based API suitable for high concurrency
- **Real-time callbacks**: streaming recognition supports live result callbacks
- **Error handling**: error handling with a retry mechanism
- **Easy integration**: a compact API that drops into existing projects
## 📦 Installation
### Dependencies
```bash
pip install websockets aiofiles
```
### Project structure
```
asr/doubao/
├── __init__.py          # Module init and public API
├── config.json          # Default configuration
├── config_manager.py    # Configuration manager
├── protocol.py          # Doubao protocol handling
├── audio_utils.py       # Audio processing utilities
├── asr_client.py        # Core ASR client
├── service_factory.py   # Service factory and convenience API
├── example.py           # Usage examples
└── README.md            # Project documentation
```
## 🔧 Configuration
### 1. Obtain API keys
Visit the [Doubao open platform](https://www.volcengine.com/docs/6561/1354869) to obtain:
- `app_key`: application key
- `access_key`: access key
### 2. Configuration methods
#### Method 1: environment variables (recommended)
```bash
export DOUBAO_APP_KEY="your_app_key"
export DOUBAO_ACCESS_KEY="your_access_key"
```
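The `ConfigManager` reads these variables at startup (the variable names below match the source; the helper function and its fallback behavior are an illustrative sketch, not the library's API):

```python
import os

def load_auth_from_env(defaults=None):
    """Read Doubao credentials from the environment, falling back to defaults."""
    defaults = defaults or {}
    return {
        "app_key": os.getenv("DOUBAO_APP_KEY", defaults.get("app_key", "")),
        "access_key": os.getenv("DOUBAO_ACCESS_KEY", defaults.get("access_key", "")),
    }
```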
#### Method 2: configuration file
Create `config.json`:
```json
{
"auth_config": {
"app_key": "your_app_key",
"access_key": "your_access_key"
},
"asr_config": {
"streaming_mode": true,
"enable_punc": true,
"seg_duration": 200
}
}
```
#### Method 3: in-code configuration
```python
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key"
)
```
## 🎯 Quick start
### 1. Simple file recognition
```python
import asyncio
from asr.doubao import recognize_file

async def simple_recognition():
    result = await recognize_file(
        audio_path="path/to/your/audio.wav",
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True
    )
    print(f"Result: {result}")

# Run it
asyncio.run(simple_recognition())
```
### 2. Streaming recognition with real-time callbacks
```python
import asyncio
from asr.doubao import create_asr_service

async def streaming_recognition():
    # Result callback
    def on_result(result):
        if result.get('payload_msg'):
            print(f"Interim result: {result['payload_msg']}")

    # Create a service instance
    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True
    )
    try:
        result = await service.recognize_file(
            "path/to/your/audio.wav",
            result_callback=on_result
        )
        print(f"Final result: {result}")
    finally:
        await service.close()

# Run it
asyncio.run(streaming_recognition())
```
### 2.1 Optimized text-only output (recommended)
Doubao ASR returns the full message payload even when only the text is needed, so a dedicated result processor is provided:
```python
import asyncio
from asr.doubao import create_asr_service
from asr.doubao.result_processor import create_text_only_callback

async def optimized_streaming():
    # Callback that receives only the text content
    def on_text(text: str):
        print(f"Recognized text: {text}")
        # Streaming results overwrite earlier ones, so the "final"
        # text keeps refreshing until the last package arrives.

    # Build the optimized callback (extracts the text field automatically)
    optimized_callback = create_text_only_callback(
        user_callback=on_text,
        enable_streaming_log=False  # silence interim-result logging
    )

    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key",
        streaming=True
    )
    try:
        await service.recognize_file(
            "path/to/your/audio.wav",
            result_callback=optimized_callback
        )
    finally:
        await service.close()

# Run it
asyncio.run(optimized_streaming())
```
### 3. Recognizing raw audio data
```python
import asyncio
from asr.doubao import recognize_audio_data

async def data_recognition():
    # Read the audio bytes
    with open("path/to/your/audio.wav", "rb") as f:
        audio_data = f.read()

    result = await recognize_audio_data(
        audio_data=audio_data,
        audio_format="wav",
        app_key="your_app_key",
        access_key="your_access_key"
    )
    print(f"Result: {result}")

# Run it
asyncio.run(data_recognition())
```
### 4. Synchronous usage (simple scenarios)
```python
from asr.doubao import run_recognition

# Synchronous recognition
result = run_recognition(
    audio_path="path/to/your/audio.wav",
    app_key="your_app_key",
    access_key="your_access_key"
)
print(f"Result: {result}")
```
## 📚 Usage in detail
### Managing a service instance
```python
from asr.doubao import DoubaoASRService, create_asr_service

# Create a service instance
service = create_asr_service(
    app_key="your_app_key",
    access_key="your_access_key",
    streaming=True,
    debug=True
)

# Run several recognitions (the connection is reused)
result1 = await service.recognize_file("audio1.wav")
result2 = await service.recognize_file("audio2.wav")

# Shut the service down
await service.close()
```
### Batch recognition
```python
async def batch_recognition(audio_files):
    service = create_asr_service(
        app_key="your_app_key",
        access_key="your_access_key"
    )
    results = []
    try:
        for audio_file in audio_files:
            result = await service.recognize_file(audio_file)
            results.append({
                'file': audio_file,
                'result': result
            })
    finally:
        await service.close()
    return results

# Usage
audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"]
results = await batch_recognition(audio_files)
```
### Custom configuration
```python
custom_config = {
    'asr_config': {
        'enable_punc': True,
        'seg_duration': 300,  # custom segment duration
        'streaming_mode': True
    },
    'connection_config': {
        'timeout': 60,  # custom timeout
        'retry_times': 5
    },
    'logging_config': {
        'enable_debug': True
    }
}
service = create_asr_service(
    app_key="your_app_key",
    access_key="your_access_key",
    custom_config=custom_config
)
```
### Using a configuration file
```python
# With a config file
result = await recognize_file(
    audio_path="audio.wav",
    config_path="asr/doubao/config.json"
)

# Or:
service = create_asr_service(config_path="asr/doubao/config.json")
```
## 🔧 Configuration parameters
### ASR settings (asr_config)
| Parameter | Type | Default | Description |
|------|------|--------|------|
| `ws_url` | str | wss://openspeech.bytedance.com/api/v3/sauc/bigmodel | Streaming recognition WebSocket URL |
| `ws_url_nostream` | str | wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream | Non-streaming recognition WebSocket URL |
| `resource_id` | str | volc.bigasr.sauc.duration | Resource ID |
| `model_name` | str | bigmodel | Model name |
| `enable_punc` | bool | true | Enable punctuation |
| `streaming_mode` | bool | true | Enable streaming mode |
| `seg_duration` | int | 200 | Audio segment duration (ms) |
| `mp3_seg_size` | int | 1000 | MP3 segment size (bytes) |
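For uncompressed audio, `seg_duration` maps to a chunk size in bytes. A sketch of the arithmetic (the same formula `calculate_segment_size` in `audio_utils.py` uses for WAV/PCM):

```python
def segment_size_bytes(sample_rate: int, bits: int, channels: int, seg_duration_ms: int) -> int:
    """Bytes of PCM/WAV audio covering seg_duration_ms of playback."""
    bytes_per_second = sample_rate * (bits // 8) * channels
    return int(bytes_per_second * seg_duration_ms / 1000)

# Defaults (16 kHz, 16-bit, mono) at 200 ms give 6400-byte chunks
print(segment_size_bytes(16000, 16, 1, 200))
```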
### Authentication (auth_config)
| Parameter | Type | Description |
|------|------|------|
| `app_key` | str | Application key |
| `access_key` | str | Access key |
### Audio settings (audio_config)
| Parameter | Type | Default | Description |
|------|------|--------|------|
| `default_format` | str | wav | Default audio format |
| `default_rate` | int | 16000 | Default sample rate |
| `default_bits` | int | 16 | Default bit depth |
| `default_channel` | int | 1 | Default channel count |
| `supported_formats` | list | ["wav", "mp3", "pcm"] | Supported audio formats |
### Connection settings (connection_config)
| Parameter | Type | Default | Description |
|------|------|--------|------|
| `max_size` | int | 1000000000 | Maximum message size (bytes) |
| `timeout` | int | 30 | Connection timeout (seconds) |
| `retry_times` | int | 3 | Retry count |
| `retry_delay` | int | 1 | Delay between retries (seconds) |
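The `retry_times` / `retry_delay` settings imply a simple retry loop. A generic sketch of that behavior (this is an illustration of the semantics, not the library's actual implementation):

```python
import asyncio

async def with_retry(coro_factory, retry_times: int = 3, retry_delay: float = 1.0):
    """Retry an async operation up to retry_times times, sleeping retry_delay between attempts."""
    last_exc = None
    for attempt in range(retry_times):
        try:
            return await coro_factory()
        except Exception as exc:  # in practice, catch connection errors only
            last_exc = exc
            if attempt < retry_times - 1:
                await asyncio.sleep(retry_delay)
    raise last_exc
```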
## 🎵 Supported audio formats
| Format | Extension | Notes |
|------|--------|------|
| WAV | .wav | Lossless; recommended |
| MP3 | .mp3 | Compressed |
| PCM | .pcm | Raw audio data |
### Audio requirements
- **Sample rate**: 16 kHz recommended; 8 kHz, 16 kHz, 24 kHz, and 48 kHz supported
- **Bit depth**: 16-bit recommended
- **Channels**: mono recommended
- **Encoding**: PCM
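A quick stdlib check that a WAV file matches the recommended parameters above (a sketch; the helper name is illustrative):

```python
import wave

def check_wav_params(path: str) -> dict:
    """Inspect a WAV file and report whether it matches the recommended parameters."""
    with wave.open(path, "rb") as wf:
        rate = wf.getframerate()
        channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
    return {
        "sample_rate": rate,
        "channels": channels,
        "bits": sampwidth * 8,
        # 16 kHz, mono, 16-bit is the recommended combination
        "recommended": rate == 16000 and channels == 1 and sampwidth == 2,
    }
```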
## 🔍 API reference
### Convenience functions
#### `recognize_file()`
```python
async def recognize_file(
    audio_path: str,
    app_key: str = None,
    access_key: str = None,
    config_path: str = None,
    streaming: bool = True,
    result_callback: callable = None,
    **kwargs
) -> dict:
```
Recognize an audio file.

**Parameters:**
- `audio_path`: path to the audio file
- `app_key`: application key
- `access_key`: access key
- `config_path`: path to a configuration file
- `streaming`: whether to use streaming recognition
- `result_callback`: result callback function
- `**kwargs`: additional configuration options

**Returns:** a result dictionary
#### `recognize_audio_data()`
```python
async def recognize_audio_data(
    audio_data: bytes,
    audio_format: str,
    app_key: str = None,
    access_key: str = None,
    config_path: str = None,
    streaming: bool = True,
    result_callback: callable = None,
    **kwargs
) -> dict:
```
Recognize in-memory audio data.

**Parameters:**
- `audio_data`: audio bytes
- `audio_format`: audio format ("wav", "mp3", "pcm")
- remaining parameters are the same as for `recognize_file()`
#### `run_recognition()`
```python
def run_recognition(
    audio_path: str = None,
    audio_data: bytes = None,
    audio_format: str = None,
    **kwargs
) -> dict:
```
Run recognition synchronously.
#### `create_asr_service()`
```python
def create_asr_service(
    app_key: str = None,
    access_key: str = None,
    config_path: str = None,
    custom_config: dict = None,
    **kwargs
) -> DoubaoASRService:
```
Create an ASR service instance.
### Core classes
#### `DoubaoASRService`
The main service class; exposes the high-level API.

**Methods:**
- `recognize_file(audio_path, result_callback=None)`: recognize a file
- `recognize_audio_data(audio_data, audio_format, result_callback=None)`: recognize audio bytes
- `get_status()`: get the service status
- `close()`: shut the service down
#### `DoubaoASRClient`
Low-level client class; handles the WebSocket communication.
#### `ConfigManager`
Configuration manager; loads, validates, and merges configuration.

**Methods:**
- `load_config(config_path)`: load a configuration file
- `save_config(config, config_path)`: save a configuration file
- `validate_config(config)`: validate configuration
- `merge_configs(base_config, override_config)`: merge configurations
- `create_default_config()`: create the default configuration
#### `DoubaoResultProcessor`
Result processor for Doubao streaming recognition; extracts just the text from the full response payload.

**Features:**
- Automatically extracts the `payload_msg.result.text` field
- Handles the overwrite-on-update behavior of streaming results
- Configurable logging level
- Supports custom text callbacks

**Methods:**
- `extract_text_from_result(result)`: extract the text from a full result
- `process_streaming_result(result)`: process a streaming result
- `create_optimized_callback(user_callback)`: build an optimized callback
- `get_current_result()`: get the current recognition state
- `reset()`: reset the processor state

**Convenience functions:**
- `create_text_only_callback(user_callback, enable_streaming_log)`: build a text-only callback
- `extract_text_only(result)`: quickly extract the text content
**Example:**
```python
from asr.doubao.result_processor import DoubaoResultProcessor, create_text_only_callback

# Option 1: use the processor class
processor = DoubaoResultProcessor(text_only=True, enable_streaming_log=False)
callback = processor.create_optimized_callback(lambda text: print(f"Text: {text}"))

# Option 2: use the convenience function
callback = create_text_only_callback(
    user_callback=lambda text: print(f"Text: {text}"),
    enable_streaming_log=False
)

# Use it with the ASR service
service = create_asr_service(...)
await service.recognize_file("audio.wav", result_callback=callback)
```
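The heart of the processor can be approximated in a few lines. This standalone sketch assumes the payload shape implied by the `payload_msg.result.text` field path above, and keeps only the latest text, since streaming results overwrite earlier ones (`extract_text` and `LatestText` are illustrative names, not the library's API):

```python
from typing import Optional

def extract_text(result: dict) -> Optional[str]:
    """Pull payload_msg.result.text out of a full response dict, if present."""
    payload = result.get("payload_msg") or {}
    return (payload.get("result") or {}).get("text")

class LatestText:
    """Track the most recent text; streaming updates replace, not append."""
    def __init__(self):
        self.text = ""

    def update(self, result: dict) -> str:
        text = extract_text(result)
        if text is not None:
            self.text = text
        return self.text
```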
## 🧪 Testing
Run the test suite:
```bash
python -m pytest test/test_doubao_asr.py -v
```
Or run the test file directly:
```bash
python test/test_doubao_asr.py
```
The tests cover:
- unit tests
- integration tests
- performance tests
- error-handling tests
## 🔧 Troubleshooting
### Common issues
#### 1. Authentication failure
```
AuthenticationError: Invalid app_key or access_key
```
**Fixes:**
- Verify that `app_key` and `access_key` are correct
- Confirm that the keys have been activated
- Check the network connection
#### 2. Unsupported audio format
```
AudioFormatError: Unsupported audio format
```
**Fixes:**
- Make sure the audio is WAV, MP3, or PCM
- Check whether the file is corrupted
- Convert the audio to a supported format
#### 3. Connection timeout
```
ConnectionTimeoutError: Connection timeout
```
**Fixes:**
- Check the network connection
- Increase the `timeout` setting
- Check firewall rules
#### 4. Audio file too large
```
FileSizeError: Audio file too large
```
**Fixes:**
- Split the audio file
- Reduce the audio quality
- Use streaming recognition
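One way to split an oversized WAV into fixed-length pieces with the stdlib (a sketch; the helper name and the 30-second chunk length are illustrative):

```python
import wave

def split_wav(path: str, chunk_seconds: float = 30.0):
    """Split a WAV file into raw frame blobs of at most chunk_seconds each.

    Returns (params, chunks), where params are the original wave parameters
    needed to rewrap each blob as a standalone WAV.
    """
    with wave.open(path, "rb") as wf:
        params = wf.getparams()
        frames_per_chunk = int(wf.getframerate() * chunk_seconds)
        chunks = []
        while True:
            frames = wf.readframes(frames_per_chunk)
            if not frames:
                break
            chunks.append(frames)
    return params, chunks
```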
### Debug mode
Enable debug mode for verbose logging:
```python
service = create_asr_service(
    app_key="your_app_key",
    access_key="your_access_key",
    debug=True
)
```
Or set it in the configuration file:
```json
{
    "logging_config": {
        "enable_debug": true,
        "log_requests": true,
        "log_responses": true
    }
}
```
## 📈 Performance tips
### 1. Connection reuse
For batch recognition, reuse one service instance:
```python
service = create_asr_service(...)
try:
    for audio_file in audio_files:
        result = await service.recognize_file(audio_file)
finally:
    await service.close()
```
### 2. Concurrency
Use asyncio to run recognitions concurrently:
```python
import asyncio

async def concurrent_recognition(audio_files):
    tasks = []
    for audio_file in audio_files:
        task = recognize_file(audio_file, ...)
        tasks.append(task)
    results = await asyncio.gather(*tasks)
    return results
```
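Unbounded `gather` over a large file list can exhaust connections or hit server-side rate limits. A semaphore-bounded variant (a sketch; `recognize` stands in for the real recognition coroutine):

```python
import asyncio

async def bounded_recognition(audio_files, recognize, max_concurrency: int = 5):
    """Run recognize(path) for every file with at most max_concurrency in flight."""
    sem = asyncio.Semaphore(max_concurrency)

    async def one(path):
        async with sem:  # blocks while max_concurrency tasks are running
            return await recognize(path)

    return await asyncio.gather(*(one(p) for p in audio_files))
```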
### 3. Audio preprocessing
- Use appropriate audio formats and parameters
- Split large files ahead of time
- Strip silent segments
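Stripping silence can be as simple as an amplitude gate on the raw samples. A minimal trim of leading/trailing silence for 16-bit PCM (a sketch; the threshold value is arbitrary and would need tuning):

```python
import array

def trim_silence(pcm: bytes, threshold: int = 500) -> bytes:
    """Drop leading/trailing 16-bit samples whose absolute amplitude is below threshold."""
    samples = array.array("h")  # signed 16-bit
    samples.frombytes(pcm)
    start, end = 0, len(samples)
    while start < end and abs(samples[start]) < threshold:
        start += 1
    while end > start and abs(samples[end - 1]) < threshold:
        end -= 1
    return samples[start:end].tobytes()
```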
## 🤝 Contributing
Issues and pull requests are welcome!
### Development setup
1. Clone the project
2. Install dependencies: `pip install -r requirements.txt`
3. Run the tests: `python -m pytest`
4. Make sure the tests pass before committing
## 📄 License
MIT License
## 🔗 Links
- [Doubao speech recognition API docs](https://www.volcengine.com/docs/6561/1354869)
- [Doubao open platform](https://www.volcengine.com/)
- [The WebSocket protocol](https://tools.ietf.org/html/rfc6455)
## 📞 Support
If you run into problems:
1. Check the troubleshooting section of this document
2. Search the existing issues
3. Open a new issue with details
4. Contact technical support
---
**Author**: AIfeng
**Version**: 1.0.0
**Last updated**: 2025-07-11
\ No newline at end of file
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR语音识别服务模块
提供完整的语音识别功能,支持流式和非流式识别
"""
__version__ = "1.0.0"
__author__ = "AIfeng"
__description__ = "豆包ASR语音识别服务模块"
# 导入核心类和函数
from .asr_client import DoubaoASRClient
from .config_manager import ConfigManager
from .service_factory import (
DoubaoASRService,
create_asr_service,
recognize_file,
recognize_audio_data,
run_recognition
)
from .protocol import DoubaoProtocol, MessageType, MessageFlags, SerializationMethod, CompressionType
from .audio_utils import AudioProcessor
from .result_processor import (
DoubaoResultProcessor,
ASRResult,
create_text_only_callback,
extract_text_only
)
# 公共API
__all__ = [
# 核心类
'DoubaoASRClient',
'DoubaoASRService',
'ConfigManager',
'DoubaoProtocol',
'AudioProcessor',
'DoubaoResultProcessor',
'ASRResult',
# 便捷函数
'create_asr_service',
'recognize_file',
'recognize_audio_data',
'run_recognition',
'create_text_only_callback',
'extract_text_only',
# 协议常量
'MessageType',
'MessageFlags',
'SerializationMethod',
'CompressionType',
# 版本信息
'__version__',
'__author__',
'__description__'
]
# 快速开始示例
def get_quick_start_example() -> str:
"""
获取快速开始示例代码
Returns:
str: 示例代码
"""
return '''
# 豆包ASR快速开始示例
import asyncio
from asr.doubao import recognize_file, create_asr_service
# 方式1: 使用便捷函数(推荐用于简单场景)
async def simple_recognition():
result = await recognize_file(
audio_path="path/to/your/audio.wav",
app_key="your_app_key",
access_key="your_access_key",
streaming=True
)
print(result)
# 方式2: 使用服务实例(推荐用于复杂场景)
async def advanced_recognition():
# 创建服务实例
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key",
streaming=True,
debug=True
)
# 定义结果回调函数
def on_result(result):
if result.get('payload_msg'):
print(f"实时结果: {result['payload_msg']}")
try:
# 执行识别
result = await service.recognize_file(
"path/to/your/audio.wav",
result_callback=on_result
)
print(f"最终结果: {result}")
finally:
await service.close()
# 方式3: 使用配置文件
async def config_based_recognition():
result = await recognize_file(
audio_path="path/to/your/audio.wav",
config_path="path/to/config.json"
)
print(result)
# 同步方式(简单场景)
def sync_recognition():
from asr.doubao import run_recognition
result = run_recognition(
audio_path="path/to/your/audio.wav",
app_key="your_app_key",
access_key="your_access_key"
)
print(result)
# 运行示例
if __name__ == "__main__":
# 选择一种方式运行
asyncio.run(simple_recognition())
# asyncio.run(advanced_recognition())
# asyncio.run(config_based_recognition())
# sync_recognition()
'''
def get_config_template() -> str:
"""
获取配置文件模板
Returns:
str: 配置文件模板
"""
return '''
{
"asr_config": {
"ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
"ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
"resource_id": "volc.bigasr.sauc.duration",
"model_name": "bigmodel",
"enable_punc": true,
"streaming_mode": true,
"seg_duration": 200,
"mp3_seg_size": 1000
},
"auth_config": {
"app_key": "your_app_key_here",
"access_key": "your_access_key_here"
},
"audio_config": {
"default_format": "wav",
"default_rate": 16000,
"default_bits": 16,
"default_channel": 1,
"default_codec": "raw",
"supported_formats": ["wav", "mp3", "pcm"]
},
"connection_config": {
"max_size": 1000000000,
"timeout": 30,
"retry_times": 3,
"retry_delay": 1
},
"logging_config": {
"enable_debug": false,
"log_requests": true,
"log_responses": true
}
}
'''
def print_info():
"""
打印模块信息
"""
print(f"豆包ASR语音识别服务模块 v{__version__}")
print(f"作者: {__author__}")
print(f"描述: {__description__}")
print("\n支持的功能:")
print("- 流式语音识别")
print("- 非流式语音识别")
print("- 多种音频格式支持 (WAV, MP3, PCM)")
print("- 灵活的配置管理")
print("- 异步和同步API")
print("- 实时结果回调")
print("\n快速开始:")
print("from asr.doubao import recognize_file")
print("result = await recognize_file('audio.wav', app_key='...', access_key='...')")
if __name__ == "__main__":
print_info()
\ No newline at end of file
... ...
# AIfeng/2025-07-11 14:15:00
"""
豆包ASR模块主入口
支持通过 python -m asr.doubao 运行示例
"""
if __name__ == '__main__':
from .example import run_all_examples
import asyncio
print("=== 豆包ASR语音识别服务示例 ===")
print("正在运行所有示例...")
try:
asyncio.run(run_all_examples())
except KeyboardInterrupt:
print("\n用户中断执行")
except Exception as e:
print(f"执行失败: {e}")
print("请确保已设置环境变量: DOUBAO_APP_KEY, DOUBAO_ACCESS_KEY")
print("并准备好测试音频文件")
\ No newline at end of file
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR客户端核心模块
提供完整的语音识别服务接口,支持流式和非流式识别
"""
import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from typing import Dict, Any, Optional, Callable, AsyncGenerator
import aiofiles
import websockets
from websockets.exceptions import ConnectionClosedError, WebSocketException
from .protocol import DoubaoProtocol, MessageType
from .audio_utils import AudioProcessor
class DoubaoASRClient:
"""豆包ASR客户端"""
def __init__(self, config: Dict[str, Any]):
"""
初始化ASR客户端
Args:
config: 配置字典
"""
self.config = config
self.asr_config = config.get('asr_config', {})
self.auth_config = config.get('auth_config', {})
self.audio_config = config.get('audio_config', {})
self.connection_config = config.get('connection_config', {})
self.logging_config = config.get('logging_config', {})
# 设置日志
self.logger = self._setup_logger()
# 协议处理器
self.protocol = DoubaoProtocol()
# 音频处理器
self.audio_processor = AudioProcessor()
# 连接状态
self.is_connected = False
self.current_session_id = None
def _setup_logger(self) -> logging.Logger:
"""设置日志记录器"""
logger = logging.getLogger('doubao_asr')
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
if self.logging_config.get('enable_debug', False):
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.INFO)
return logger
def _get_ws_url(self, streaming: bool = True) -> str:
"""获取WebSocket URL"""
if streaming:
return self.asr_config.get('ws_url', 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel')
else:
return self.asr_config.get('ws_url_nostream', 'wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream')
def _build_auth_headers(self, request_id: str) -> Dict[str, str]:
"""构建认证头部"""
headers = {
'X-Api-Resource-Id': self.asr_config.get('resource_id', 'volc.bigasr.sauc.duration'),
'X-Api-Access-Key': self.auth_config.get('access_key', ''),
'X-Api-App-Key': self.auth_config.get('app_key', ''),
'X-Api-Request-Id': request_id
}
return headers
def _build_request_params(
self,
request_id: str,
audio_format: str = 'wav',
sample_rate: int = 16000,
bits: int = 16,
channels: int = 1,
uid: str = 'default_user'
) -> Dict[str, Any]:
"""构建请求参数"""
return {
'user': {
'uid': uid
},
'audio': {
'format': audio_format,
'sample_rate': sample_rate,
'bits': bits,
'channel': channels,
'codec': self.audio_config.get('default_codec', 'raw')
},
'request': {
'model_name': self.asr_config.get('model_name', 'bigmodel'),
'enable_punc': self.asr_config.get('enable_punc', True)
}
}
async def recognize_file(
self,
audio_path: str,
streaming: bool = True,
result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
识别音频文件
Args:
audio_path: 音频文件路径
streaming: 是否使用流式识别
result_callback: 结果回调函数
**kwargs: 其他参数
Returns:
Dict: 识别结果
"""
try:
# 读取音频文件
async with aiofiles.open(audio_path, mode='rb') as f:
audio_data = await f.read()
self.logger.info(f"开始识别音频文件: {audio_path}, 大小: {len(audio_data)} 字节")
# 识别音频数据
return await self.recognize_audio_data(
audio_data,
streaming=streaming,
result_callback=result_callback,
**kwargs
)
except Exception as e:
self.logger.error(f"识别音频文件失败: {e}")
return {
'success': False,
'error': str(e),
'audio_path': audio_path
}
async def recognize_audio_data(
self,
audio_data: bytes,
streaming: bool = True,
result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
识别音频数据
Args:
audio_data: 音频数据
streaming: 是否使用流式识别
result_callback: 结果回调函数
**kwargs: 其他参数
Returns:
Dict: 识别结果
"""
request_id = str(uuid.uuid4())
self.current_session_id = request_id
try:
# 准备音频数据
audio_format, segment_size, metadata = self.audio_processor.prepare_audio_for_recognition(
audio_data,
segment_duration_ms=self.asr_config.get('seg_duration', 200)
)
self.logger.info(f"音频格式: {audio_format}, 分片大小: {segment_size}, 元数据: {metadata}")
# 构建请求参数
request_params = self._build_request_params(
request_id,
audio_format=audio_format,
sample_rate=metadata.get('sample_rate', 16000),
bits=metadata.get('sample_width', 2) * 8,
channels=metadata.get('channels', 1),
uid=kwargs.get('uid', 'default_user')
)
# 执行识别
if streaming:
return await self._streaming_recognize(
audio_data,
request_params,
segment_size,
request_id,
result_callback
)
else:
return await self._non_streaming_recognize(
audio_data,
request_params,
request_id
)
except Exception as e:
self.logger.error(f"识别音频数据失败: {e}")
return {
'success': False,
'error': str(e),
'request_id': request_id
}
    async def _streaming_recognize(
        self,
        audio_data: bytes,
        request_params: Dict[str, Any],
        segment_size: int,
        request_id: str,
        result_callback: Optional[Callable[[Dict[str, Any]], None]] = None
    ) -> Dict[str, Any]:
        """Streaming recognition."""
        ws_url = self._get_ws_url(streaming=True)
        headers = self._build_auth_headers(request_id)
        results = []
        final_result = None
        try:
            # Support different websockets versions
            connect_kwargs = {
                'uri': ws_url,
                'max_size': self.connection_config.get('max_size', 1000000000)
            }
            try:
                # Newer websockets releases take additional_headers
                async with websockets.connect(**connect_kwargs, additional_headers=headers) as ws:
                    final_result = await self._handle_streaming_connection(
                        ws, audio_data, request_params, segment_size,
                        request_id, result_callback, results
                    )
            except TypeError:
                # Older releases expect extra_headers
                async with websockets.connect(**connect_kwargs, extra_headers=headers) as ws:
                    final_result = await self._handle_streaming_connection(
                        ws, audio_data, request_params, segment_size,
                        request_id, result_callback, results
                    )
            return {
                'success': True,
                'request_id': request_id,
                'results': results,
                'final_result': final_result,
                'total_results': len(results)
            }
        except ConnectionClosedError as e:
            self.logger.error(f"WebSocket connection closed: {e.code} - {e.reason}")
            return {
                'success': False,
                'error': f"Connection closed: {e.reason}",
                'error_code': e.code,
                'request_id': request_id
            }
        except WebSocketException as e:
            self.logger.error(f"WebSocket error: {e}")
            return {
                'success': False,
                'error': str(e),
                'request_id': request_id
            }
        except Exception as e:
            self.logger.error(f"Streaming recognition failed: {e}")
            return {
                'success': False,
                'error': str(e),
                'request_id': request_id
            }
        finally:
            self.is_connected = False
    async def _handle_streaming_connection(
        self,
        ws,
        audio_data: bytes,
        request_params: Dict[str, Any],
        segment_size: int,
        request_id: str,
        result_callback: Optional[Callable[[Dict[str, Any]], None]],
        results: list
    ) -> Optional[Dict[str, Any]]:
        """Core streaming-connection logic.

        Returns the final result package (or None). The final result must be
        returned rather than assigned to a parameter: rebinding a parameter
        inside this coroutine is not visible to the caller.
        """
        self.is_connected = True
        self.logger.info("WebSocket connection established")
        final_result = None
        # Send the initial request
        seq = 1
        full_request = self.protocol.build_full_request(request_params, seq)
        await ws.send(full_request)
        # Receive the initial response
        response = await ws.recv()
        result = self.protocol.parse_response(response)
        if self.logging_config.get('log_responses', True):
            self.logger.debug(f"Initial response: {result}")
        # Send the audio data chunk by chunk
        for chunk, is_last in self.audio_processor.slice_audio_data(audio_data, segment_size):
            seq += 1
            if is_last:
                seq = -seq
            start_time = time.time()
            # Build and send the audio request
            audio_request = self.protocol.build_audio_request(chunk, seq, is_last)
            await ws.send(audio_request)
            # Receive the response
            response = await ws.recv()
            result = self.protocol.parse_response(response)
            # Collect the result
            if result.get('payload_msg'):
                results.append(result)
                # Invoke the callback, if any
                if result_callback:
                    try:
                        result_callback(result)
                    except Exception as e:
                        self.logger.warning(f"Result callback failed: {e}")
            if result.get('is_last_package'):
                final_result = result
                break
            # Pace the stream to roughly real time
            if self.asr_config.get('streaming_mode', True):
                elapsed = time.time() - start_time
                sleep_time = max(0, (self.asr_config.get('seg_duration', 200) / 1000.0) - elapsed)
                if sleep_time > 0:
                    await asyncio.sleep(sleep_time)
        return final_result
async def _non_streaming_recognize(
self,
audio_data: bytes,
request_params: Dict[str, Any],
request_id: str
) -> Dict[str, Any]:
"""非流式识别处理"""
ws_url = self._get_ws_url(streaming=False)
headers = self._build_auth_headers(request_id)
try:
# 兼容不同版本的websockets库
connect_kwargs = {
'uri': ws_url,
'max_size': self.connection_config.get('max_size', 1000000000)
}
# 尝试使用新版本的additional_headers参数
try:
async with websockets.connect(
**connect_kwargs,
additional_headers=headers
) as ws:
return await self._handle_non_streaming_connection(ws, audio_data, request_params, request_id)
except TypeError:
# 回退到旧版本的extra_headers参数
async with websockets.connect(
**connect_kwargs,
extra_headers=headers
) as ws:
return await self._handle_non_streaming_connection(ws, audio_data, request_params, request_id)
except Exception as e:
self.logger.error(f"非流式识别异常: {e}")
return {
'success': False,
'error': str(e),
'request_id': request_id
}
finally:
self.is_connected = False
async def _handle_non_streaming_connection(
self,
ws,
audio_data: bytes,
request_params: Dict[str, Any],
request_id: str
) -> Dict[str, Any]:
"""处理非流式连接的核心逻辑"""
self.is_connected = True
self.logger.info(f"WebSocket连接建立成功")
# 发送完整请求(包含音频数据)
full_request = self.protocol.build_full_request(request_params, 1)
await ws.send(full_request)
# 发送音频数据
audio_request = self.protocol.build_audio_request(
audio_data, -1, is_last=True
)
await ws.send(audio_request)
# 接收最终结果
response = await ws.recv()
result = self.protocol.parse_response(response)
self.is_connected = False
return {
'success': True,
'request_id': request_id,
'result': result
}
async def close(self):
"""关闭客户端"""
self.is_connected = False
self.current_session_id = None
self.logger.info("ASR客户端已关闭")
def get_status(self) -> Dict[str, Any]:
"""获取客户端状态"""
return {
'is_connected': self.is_connected,
'current_session_id': self.current_session_id,
'config': {
'ws_url': self._get_ws_url(),
'model_name': self.asr_config.get('model_name'),
'streaming_mode': self.asr_config.get('streaming_mode')
}
}
\ No newline at end of file
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR音频处理工具模块
提供音频格式检测、分片处理、元数据提取等功能
"""
import wave
from io import BytesIO
from typing import Tuple, Generator, Dict, Any
class AudioProcessor:
"""音频处理器"""
@staticmethod
def read_wav_info(audio_data: bytes) -> Tuple[int, int, int, int, bytes]:
"""
读取WAV文件信息
Args:
audio_data: WAV音频数据
Returns:
Tuple: (声道数, 采样宽度, 采样率, 帧数, 音频字节数据)
"""
try:
with BytesIO(audio_data) as audio_io:
with wave.open(audio_io, 'rb') as wave_fp:
nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
wave_bytes = wave_fp.readframes(nframes)
return nchannels, sampwidth, framerate, nframes, wave_bytes
except Exception as e:
raise ValueError(f"读取WAV文件失败: {e}")
@staticmethod
def is_wav_format(audio_data: bytes) -> bool:
"""
检查是否为WAV格式
Args:
audio_data: 音频数据
Returns:
bool: 是否为WAV格式
"""
if len(audio_data) < 44:
return False
return audio_data[0:4] == b"RIFF" and audio_data[8:12] == b"WAVE"
    @staticmethod
    def detect_audio_format(audio_data: bytes) -> str:
        """
        Detect the audio format.
        Args:
            audio_data: audio bytes
        Returns:
            str: one of 'wav', 'mp3', 'pcm', 'unknown'
        """
        if len(audio_data) < 4:
            return 'unknown'
        # WAV: RIFF/WAVE header
        if AudioProcessor.is_wav_format(audio_data):
            return 'wav'
        # MP3: ID3 tag or an MPEG frame-sync header
        if audio_data[0:3] == b"ID3" or audio_data[0:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"):
            return 'mp3'
        # Anything else is treated as raw PCM
        return 'pcm'
@staticmethod
def slice_audio_data(
audio_data: bytes,
chunk_size: int
) -> Generator[Tuple[bytes, bool], None, None]:
"""
将音频数据分片
Args:
audio_data: 音频数据
chunk_size: 分片大小
Yields:
Tuple[bytes, bool]: (音频片段, 是否为最后一片)
"""
data_len = len(audio_data)
offset = 0
while offset + chunk_size < data_len:
yield audio_data[offset:offset + chunk_size], False
offset += chunk_size
# 最后一片
if offset < data_len:
yield audio_data[offset:data_len], True
@staticmethod
def calculate_segment_size(
audio_format: str,
sample_rate: int = 16000,
channels: int = 1,
bits: int = 16,
segment_duration_ms: int = 200,
mp3_seg_size: int = 1000
) -> int:
"""
计算音频分片大小
Args:
audio_format: 音频格式
sample_rate: 采样率
channels: 声道数
bits: 位深度
segment_duration_ms: 分片时长(毫秒)
mp3_seg_size: MP3分片大小
Returns:
int: 分片大小(字节)
"""
if audio_format == 'mp3':
return mp3_seg_size
elif audio_format == 'wav':
# 计算每秒字节数
bytes_per_second = channels * (bits // 8) * sample_rate
return int(bytes_per_second * segment_duration_ms / 1000)
elif audio_format == 'pcm':
# PCM格式计算
return int(sample_rate * (bits // 8) * channels * segment_duration_ms / 1000)
else:
raise ValueError(f"不支持的音频格式: {audio_format}")
@staticmethod
def extract_wav_metadata(audio_data: bytes) -> Dict[str, Any]:
"""
提取WAV文件元数据
Args:
audio_data: WAV音频数据
Returns:
Dict: 音频元数据
"""
try:
nchannels, sampwidth, framerate, nframes, _ = AudioProcessor.read_wav_info(audio_data)
duration = nframes / framerate
return {
'format': 'wav',
'channels': nchannels,
'sample_width': sampwidth,
'sample_rate': framerate,
'frames': nframes,
'duration': duration,
'size': len(audio_data)
}
except Exception as e:
return {
'format': 'wav',
'error': str(e),
'size': len(audio_data)
}
@staticmethod
def validate_audio_params(
audio_format: str,
sample_rate: int,
channels: int,
bits: int
) -> bool:
"""
验证音频参数
Args:
audio_format: 音频格式
sample_rate: 采样率
channels: 声道数
bits: 位深度
Returns:
bool: 参数是否有效
"""
# 支持的格式
supported_formats = ['wav', 'mp3', 'pcm']
if audio_format not in supported_formats:
return False
# 采样率范围
if sample_rate < 8000 or sample_rate > 48000:
return False
# 声道数
if channels < 1 or channels > 2:
return False
# 位深度
if bits not in [8, 16, 24, 32]:
return False
return True
@staticmethod
def prepare_audio_for_recognition(
audio_data: bytes,
target_format: str = 'wav',
segment_duration_ms: int = 200
) -> Tuple[str, int, Dict[str, Any]]:
"""
为识别准备音频数据
Args:
audio_data: 原始音频数据
target_format: 目标格式
segment_duration_ms: 分片时长
Returns:
Tuple: (检测到的格式, 分片大小, 音频元数据)
"""
# 检测音频格式
detected_format = AudioProcessor.detect_audio_format(audio_data)
# 提取元数据
if detected_format == 'wav':
metadata = AudioProcessor.extract_wav_metadata(audio_data)
segment_size = AudioProcessor.calculate_segment_size(
detected_format,
metadata.get('sample_rate', 16000),
metadata.get('channels', 1),
metadata.get('sample_width', 2) * 8,
segment_duration_ms
)
else:
# 对于非WAV格式,使用默认参数
metadata = {
'format': detected_format,
'size': len(audio_data)
}
segment_size = AudioProcessor.calculate_segment_size(
detected_format,
segment_duration_ms=segment_duration_ms
)
return detected_format, segment_size, metadata
\ No newline at end of file
... ...
{
"asr_config": {
"ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
"ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
"resource_id": "volc.bigasr.sauc.duration",
"resource_id_concurrent": "volc.bigasr.sauc.concurrent",
"model_name": "bigmodel",
"enable_punc": true,
"streaming_mode": true,
"seg_duration": 200,
"mp3_seg_size": 1000
},
"auth_config": {
"app_key": "your_app_key_here",
"access_key": "your_access_key_here"
},
"audio_config": {
"default_format": "wav",
"default_rate": 16000,
"default_bits": 16,
"default_channel": 1,
"default_codec": "raw",
"supported_formats": ["wav", "mp3", "pcm"]
},
"connection_config": {
"max_size": 1000000000,
"timeout": 30,
"retry_times": 3,
"retry_delay": 1
},
"logging_config": {
"enable_debug": false,
"log_requests": true,
"log_responses": true
}
}
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR配置管理模块
提供配置文件加载、验证、合并和环境变量支持
"""
import json
import os
from pathlib import Path
from typing import Dict, Any, Optional
class ConfigManager:
"""配置管理器"""
DEFAULT_CONFIG = {
"asr_config": {
"ws_url": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel",
"ws_url_nostream": "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream",
"resource_id": "volc.bigasr.sauc.duration",
"resource_id_concurrent": "volc.bigasr.sauc.concurrent",
"model_name": "bigmodel",
"enable_punc": True,
"streaming_mode": True,
"seg_duration": 200,
"mp3_seg_size": 1000
},
"auth_config": {
"app_key": "",
"access_key": ""
},
"audio_config": {
"default_format": "wav",
"default_rate": 16000,
"default_bits": 16,
"default_channel": 1,
"default_codec": "raw",
"supported_formats": ["wav", "mp3", "pcm"]
},
"connection_config": {
"max_size": 1000000000,
"timeout": 30,
"retry_times": 3,
"retry_delay": 1
},
"logging_config": {
"enable_debug": False,
"log_requests": True,
"log_responses": True
}
}
def __init__(self, config_path: Optional[str] = None):
"""
初始化配置管理器
Args:
config_path: 配置文件路径
"""
self.config_path = config_path
self.config = self.DEFAULT_CONFIG.copy()
if config_path:
self.load_config(config_path)
# 从环境变量加载配置
self._load_from_env()
def load_config(self, config_path: str) -> Dict[str, Any]:
"""
加载配置文件
Args:
config_path: 配置文件路径
Returns:
Dict: 配置字典
"""
try:
config_file = Path(config_path)
if not config_file.exists():
raise FileNotFoundError(f"配置文件不存在: {config_path}")
with open(config_file, 'r', encoding='utf-8') as f:
file_config = json.load(f)
# 合并配置
self.config = self._merge_config(self.config, file_config)
# 验证配置
self._validate_config()
return self.config
except Exception as e:
raise ValueError(f"加载配置文件失败: {e}")
def _merge_config(self, base_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
"""
合并配置字典
Args:
base_config: 基础配置
new_config: 新配置
Returns:
Dict: 合并后的配置
"""
merged = base_config.copy()
for key, value in new_config.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
merged[key] = self._merge_config(merged[key], value)
else:
merged[key] = value
return merged
def _load_from_env(self):
"""从环境变量加载配置"""
# ASR配置
if os.getenv('DOUBAO_WS_URL'):
self.config['asr_config']['ws_url'] = os.getenv('DOUBAO_WS_URL')
if os.getenv('DOUBAO_MODEL_NAME'):
self.config['asr_config']['model_name'] = os.getenv('DOUBAO_MODEL_NAME')
if os.getenv('DOUBAO_SEG_DURATION'):
try:
self.config['asr_config']['seg_duration'] = int(os.getenv('DOUBAO_SEG_DURATION'))
except ValueError:
pass
# 认证配置
if os.getenv('DOUBAO_APP_KEY'):
self.config['auth_config']['app_key'] = os.getenv('DOUBAO_APP_KEY')
if os.getenv('DOUBAO_ACCESS_KEY'):
self.config['auth_config']['access_key'] = os.getenv('DOUBAO_ACCESS_KEY')
# 日志配置
if os.getenv('DOUBAO_DEBUG'):
self.config['logging_config']['enable_debug'] = os.getenv('DOUBAO_DEBUG').lower() == 'true'
def _validate_config(self):
"""验证配置"""
# 验证必需的认证信息
auth_config = self.config.get('auth_config', {})
if not auth_config.get('app_key'):
raise ValueError("缺少必需的配置: auth_config.app_key")
if not auth_config.get('access_key'):
raise ValueError("缺少必需的配置: auth_config.access_key")
# 验证ASR配置
asr_config = self.config.get('asr_config', {})
if not asr_config.get('ws_url'):
raise ValueError("缺少必需的配置: asr_config.ws_url")
# 验证音频配置
audio_config = self.config.get('audio_config', {})
supported_formats = audio_config.get('supported_formats', [])
default_format = audio_config.get('default_format')
if default_format and default_format not in supported_formats:
raise ValueError(f"默认音频格式 {default_format} 不在支持的格式列表中: {supported_formats}")
# 验证数值范围
seg_duration = asr_config.get('seg_duration', 200)
if not (50 <= seg_duration <= 1000):
raise ValueError(f"分片时长必须在50-1000ms之间,当前值: {seg_duration}")
sample_rate = audio_config.get('default_rate', 16000)
if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
raise ValueError(f"不支持的采样率: {sample_rate}")
def get_config(self) -> Dict[str, Any]:
"""
获取完整配置
Returns:
Dict: 配置字典
"""
return self.config.copy()
def get_asr_config(self) -> Dict[str, Any]:
"""
获取ASR配置
Returns:
Dict: ASR配置
"""
return self.config.get('asr_config', {}).copy()
def get_auth_config(self) -> Dict[str, Any]:
"""
获取认证配置
Returns:
Dict: 认证配置
"""
return self.config.get('auth_config', {}).copy()
def get_audio_config(self) -> Dict[str, Any]:
"""
获取音频配置
Returns:
Dict: 音频配置
"""
return self.config.get('audio_config', {}).copy()
def update_config(self, new_config: Dict[str, Any]):
"""
更新配置
Args:
new_config: 新配置
"""
self.config = self._merge_config(self.config, new_config)
self._validate_config()
def save_config(self, output_path: Optional[str] = None):
"""
保存配置到文件
Args:
output_path: 输出文件路径,默认使用原配置文件路径
"""
save_path = output_path or self.config_path
if not save_path:
raise ValueError("未指定保存路径")
try:
with open(save_path, 'w', encoding='utf-8') as f:
json.dump(self.config, f, indent=2, ensure_ascii=False)
except Exception as e:
raise ValueError(f"保存配置文件失败: {e}")
def create_default_config(self, output_path: str) -> Dict[str, Any]:
"""
创建默认配置文件
Args:
output_path: 输出文件路径
Returns:
Dict[str, Any]: 默认配置字典
"""
try:
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(self.DEFAULT_CONFIG, f, indent=2, ensure_ascii=False)
return self.DEFAULT_CONFIG.copy()
except Exception as e:
raise ValueError(f"创建默认配置文件失败: {e}")
def get_env_template(self) -> str:
"""
获取环境变量模板
Returns:
str: 环境变量模板
"""
template = """
# 豆包ASR环境变量配置模板
# ASR服务配置
DOUBAO_WS_URL=wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
DOUBAO_MODEL_NAME=bigmodel
DOUBAO_SEG_DURATION=200
# 认证配置(必需)
DOUBAO_APP_KEY=your_app_key_here
DOUBAO_ACCESS_KEY=your_access_key_here
# 调试配置
DOUBAO_DEBUG=false
"""
return template.strip()
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> 'ConfigManager':
"""
从字典创建配置管理器
Args:
config_dict: 配置字典
Returns:
ConfigManager: 配置管理器实例
"""
manager = cls()
manager.config = manager._merge_config(manager.DEFAULT_CONFIG, config_dict)
manager._validate_config()
return manager
@classmethod
def from_env(cls) -> 'ConfigManager':
"""
仅从环境变量创建配置管理器
Returns:
ConfigManager: 配置管理器实例
"""
manager = cls()
return manager
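`_merge_config` 的递归合并语义是整个配置体系的核心:嵌套字典逐层深度合并,其余键直接覆盖。下面是一个独立的小示例(不依赖本模块,仅示意相同的合并逻辑):

```python
# 递归合并两个配置字典:嵌套 dict 深度合并,其余键直接覆盖
# (独立示意,语义与 ConfigManager._merge_config 一致)
def merge_config(base: dict, new: dict) -> dict:
    merged = base.copy()
    for key, value in new.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged

default = {"asr_config": {"streaming_mode": True, "seg_duration": 200}}
user = {"asr_config": {"seg_duration": 300}, "auth_config": {"app_key": "k"}}
merged = merge_config(default, user)
# 嵌套键被深度合并:streaming_mode 保留,seg_duration 被覆盖
print(merged["asr_config"])  # {'streaming_mode': True, 'seg_duration': 300}
```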
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR语音识别服务使用示例
演示各种使用场景和最佳实践
"""
import asyncio
import os
import logging
from pathlib import Path
from typing import Dict, Any, Optional
# 导入ASR服务 - 支持相对导入和绝对导入
try:
# 尝试相对导入(作为包运行时)
from . import (
recognize_file,
recognize_audio_data,
create_asr_service,
run_recognition,
ConfigManager
)
except ImportError:
# 回退到绝对导入(独立运行时)
try:
from asr_client import (
recognize_file,
recognize_audio_data,
create_asr_service,
run_recognition
)
from config_manager import ConfigManager
except ImportError:
# 最后尝试直接导入
import sys
from pathlib import Path
# 添加当前目录到路径
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))
from asr_client import (
recognize_file,
recognize_audio_data,
create_asr_service,
run_recognition
)
from config_manager import ConfigManager
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ASRExamples:
"""
ASR使用示例集合
"""
def __init__(self, app_key: str, access_key: str):
self.app_key = app_key
self.access_key = access_key
"""
示例1: 简单文件识别
"""
async def example_1_simple_file_recognition(self, audio_path: str):
logger.info("=== 示例1: 简单文件识别 ===")
try:
result = await recognize_file(
audio_path=audio_path,
app_key=self.app_key,
access_key=self.access_key,
streaming=True
)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
"""
示例2: 流式识别with实时回调 - 简化版流式输出演示
"""
async def example_2_streaming_with_callback(self, audio_path: str):
logger.info("=== 示例2: 流式识别with实时回调 - 简化版流式输出演示 ===")
# 流式输出状态
self.current_text = ""
self.result_count = 0
def clear_line():
"""清除当前行"""
print('\r' + ' ' * 100 + '\r', end='', flush=True)
def print_streaming_result(text: str, is_final: bool = False):
"""打印流式结果"""
clear_line()
status = "[最终]" if is_final else "[流式]"
timestamp = f"[{self.result_count:02d}]"
print(f"{timestamp}{status} {text}", end='', flush=True)
if is_final:
print() # 最终结果换行
# 定义结果回调函数 - 演示实时文本更新
def on_result(result: Dict[str, Any]):
self.result_count += 1
if result.get('payload_msg'):
payload = result['payload_msg']
# 检查是否有识别结果
if 'result' in payload and 'text' in payload['result']:
new_text = payload['result']['text']
# 显示累积文本更新
print_streaming_result(new_text, False)
self.current_text = new_text
# 检查是否为最终结果
if result.get('is_last_package', False):
print_streaming_result(self.current_text, True)
logger.info(f"识别完成,共收到{self.result_count}次流式结果")
print("\n观察流式文本累积更新效果:")
print()
# 创建服务实例
service = create_asr_service(
app_key=self.app_key,
access_key=self.access_key,
streaming=True,
debug=False # 关闭调试日志以便观察输出
)
try:
result = await service.recognize_file(
audio_path,
result_callback=on_result
)
print()
logger.info(f"=== 识别结果摘要 ===")
logger.info(f"最终文本: {self.current_text}")
logger.info(f"流式更新次数: {self.result_count}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
finally:
await service.close()
"""
示例3: 非流式识别
"""
async def example_3_non_streaming_recognition(self, audio_path: str):
logger.info("=== 示例3: 非流式识别 ===")
try:
result = await recognize_file(
audio_path=audio_path,
app_key=self.app_key,
access_key=self.access_key,
streaming=False # 非流式
)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
"""
示例4: 音频数据识别
"""
async def example_4_audio_data_recognition(self, audio_data: bytes, audio_format: str = "wav"):
logger.info("=== 示例4: 音频数据识别 ===")
try:
result = await recognize_audio_data(
audio_data=audio_data,
audio_format=audio_format,
app_key=self.app_key,
access_key=self.access_key,
streaming=True
)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
"""
示例5: 基于配置文件的识别
"""
async def example_5_config_based_recognition(self, audio_path: str, config_path: str):
"""
示例5: 基于配置文件的识别
"""
logger.info("=== 示例5: 基于配置文件的识别 ===")
try:
result = await recognize_file(
audio_path=audio_path,
config_path=config_path
)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
"""
示例6: 批量识别
"""
async def example_6_batch_recognition(self, audio_files: list):
logger.info("=== 示例6: 批量识别 ===")
results = []
# 创建服务实例(复用连接)
service = create_asr_service(
app_key=self.app_key,
access_key=self.access_key,
streaming=True
)
try:
for i, audio_file in enumerate(audio_files):
logger.info(f"处理文件 {i+1}/{len(audio_files)}: {audio_file}")
try:
result = await service.recognize_file(audio_file)
results.append({
'file': audio_file,
'result': result,
'status': 'success'
})
except Exception as e:
logger.error(f"文件 {audio_file} 识别失败: {e}")
results.append({
'file': audio_file,
'result': None,
'status': 'failed',
'error': str(e)
})
logger.info(f"批量识别完成,成功: {sum(1 for r in results if r['status'] == 'success')}/{len(results)}")
return results
finally:
await service.close()
"""
示例7: 同步识别(简单场景)
"""
def example_7_sync_recognition(self, audio_path: str):
logger.info("=== 示例7: 同步识别 ===")
try:
result = run_recognition(
audio_path=audio_path,
app_key=self.app_key,
access_key=self.access_key,
streaming=True
)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
"""
示例8: 自定义配置识别
"""
async def example_8_custom_config_recognition(self, audio_path: str):
logger.info("=== 示例8: 自定义配置识别 ===")
# 自定义配置
custom_config = {
'asr_config': {
'enable_punc': True,
'seg_duration': 300, # 自定义分段时长
'streaming_mode': True
},
'audio_config': {
'default_rate': 16000,
'default_bits': 16,
'default_channel': 1
},
'connection_config': {
'timeout': 60, # 自定义超时时间
'retry_times': 5
},
'logging_config': {
'enable_debug': True
}
}
service = create_asr_service(
app_key=self.app_key,
access_key=self.access_key,
custom_config=custom_config
)
try:
result = await service.recognize_file(audio_path)
logger.info(f"识别结果: {result}")
return result
except Exception as e:
logger.error(f"识别失败: {e}")
return None
finally:
await service.close()
def create_sample_config(config_path: str, app_key: str, access_key: str):
"""
创建示例配置文件
"""
config_manager = ConfigManager()
# 创建配置
config = config_manager.create_default_config(config_path)
config['auth_config']['app_key'] = app_key
config['auth_config']['access_key'] = access_key
config['logging_config']['enable_debug'] = True
# 更新配置管理器的配置并保存
config_manager.update_config(config)
config_manager.save_config(config_path)
logger.info(f"示例配置文件已创建: {config_path}")
async def run_all_examples():
"""
运行所有示例
"""
# 从环境变量获取密钥
app_key = os.getenv('DOUBAO_APP_KEY', 'your_app_key_here')
access_key = os.getenv('DOUBAO_ACCESS_KEY', 'your_access_key_here')
if app_key == 'your_app_key_here' or access_key == 'your_access_key_here':
logger.warning("请设置环境变量 DOUBAO_APP_KEY 和 DOUBAO_ACCESS_KEY")
logger.info("或者直接修改代码中的密钥")
return
# 示例音频文件路径(请替换为实际路径)
audio_path = "E:\\fengyang\\eman_one\\speech.wav"
if not Path(audio_path).exists():
logger.warning(f"音频文件不存在: {audio_path}")
logger.info("请替换为实际的音频文件路径")
return
# 创建示例实例
examples = ASRExamples(app_key, access_key)
# 创建示例配置文件
config_path = "example_config.json"
create_sample_config(config_path, app_key, access_key)
try:
# 运行示例
# await examples.example_1_simple_file_recognition(audio_path)
await examples.example_2_streaming_with_callback(audio_path)
# await examples.example_3_non_streaming_recognition(audio_path)
# 音频数据示例(需要实际音频数据)
# with open(audio_path, 'rb') as f:
# audio_data = f.read()
# await examples.example_4_audio_data_recognition(audio_data)
# await examples.example_5_config_based_recognition(audio_path, config_path)
# 批量识别示例
# audio_files = [audio_path] # 添加更多文件
# await examples.example_6_batch_recognition(audio_files)
# 同步识别示例
# examples.example_7_sync_recognition(audio_path)
# await examples.example_8_custom_config_recognition(audio_path)
except Exception as e:
logger.error(f"示例运行失败: {e}")
finally:
# 清理示例配置文件
if Path(config_path).exists():
os.remove(config_path)
logger.info(f"已清理示例配置文件: {config_path}")
if __name__ == "__main__":
# 运行所有示例
asyncio.run(run_all_examples())
# 或者运行单个示例
# app_key = "your_app_key"
# access_key = "your_access_key"
# audio_path = "path/to/audio.wav"
#
# examples = ASRExamples(app_key, access_key)
# asyncio.run(examples.example_1_simple_file_recognition(audio_path))
... ...
# AIfeng/2025-07-17 13:58:00
"""
豆包ASR优化使用示例
演示如何使用结果处理器来优化ASR输出,只获取文本内容
"""
import asyncio
import logging
from pathlib import Path
from typing import Dict, Any
from .service_factory import create_asr_service
from .result_processor import DoubaoResultProcessor, create_text_only_callback, extract_text_only
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def example_optimized_streaming():
"""优化的流式识别示例 - 只输出文本内容"""
logger.info("=== 优化流式识别示例 ===")
# 创建结果处理器
processor = DoubaoResultProcessor(
text_only=True,
enable_streaming_log=False # 关闭流式日志,减少输出
)
# 用户自定义文本处理函数
def on_text_result(text: str):
"""只处理文本内容的回调函数"""
print(f"识别文本: {text}")
# 这里可以添加你的业务逻辑
# 比如:发送到WebSocket、保存到数据库等
# 创建优化的回调函数
optimized_callback = processor.create_optimized_callback(on_text_result)
# 创建ASR服务
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key",
streaming=True,
debug=False # 关闭调试模式,减少日志输出
)
try:
# 识别音频文件
audio_path = "path/to/your/audio.wav"
result = await service.recognize_file(
audio_path,
result_callback=optimized_callback
)
# 获取最终状态
final_status = processor.get_current_result()
logger.info(f"识别完成: {final_status}")
finally:
await service.close()
async def example_simple_text_extraction():
"""简单文本提取示例"""
logger.info("=== 简单文本提取示例 ===")
# 使用便捷函数创建只处理文本的回调
def print_text(text: str):
print(f">> {text}")
text_callback = create_text_only_callback(
user_callback=print_text,
enable_streaming_log=False
)
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key"
)
try:
audio_path = "path/to/your/audio.wav"
await service.recognize_file(
audio_path,
result_callback=text_callback
)
finally:
await service.close()
async def example_manual_text_extraction():
"""手动文本提取示例"""
logger.info("=== 手动文本提取示例 ===")
def manual_callback(result: Dict[str, Any]):
"""手动提取文本的回调函数"""
# 使用便捷函数提取文本
text = extract_text_only(result)
if text:
print(f"提取的文本: {text}")
# 检查是否为最终结果
is_final = result.get('is_last_package', False)
if is_final:
print(f"[最终结果] {text}")
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key"
)
try:
audio_path = "path/to/your/audio.wav"
await service.recognize_file(
audio_path,
result_callback=manual_callback
)
finally:
await service.close()
async def example_streaming_with_websocket():
"""流式识别结合WebSocket示例"""
logger.info("=== 流式识别WebSocket示例 ===")
# 模拟WebSocket连接
class MockWebSocket:
def __init__(self):
self.messages = []
async def send_text(self, text: str):
self.messages.append(text)
print(f"WebSocket发送: {text}")
websocket = MockWebSocket()
# 创建处理器
processor = DoubaoResultProcessor(
text_only=True,
enable_streaming_log=False
)
async def websocket_handler(text: str):
"""WebSocket文本处理器"""
await websocket.send_text(text)
# 创建回调:回调在事件循环内被同步触发,这里用 create_task 调度异步的 WebSocket 发送
callback = processor.create_optimized_callback(
lambda text: asyncio.create_task(websocket_handler(text))
)
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key"
)
try:
audio_path = "path/to/your/audio.wav"
await service.recognize_file(
audio_path,
result_callback=callback
)
print(f"WebSocket消息历史: {websocket.messages}")
finally:
await service.close()
async def example_comparison():
"""对比示例:原始输出 vs 优化输出"""
logger.info("=== 对比示例 ===")
print("\n1. 原始完整输出:")
def original_callback(result: Dict[str, Any]):
print(f"完整结果: {result}")
print("\n2. 优化文本输出:")
def optimized_callback(text: str):
print(f"文本: {text}")
# 这里只是演示,实际使用时选择其中一种即可
text_callback = create_text_only_callback(optimized_callback)
service = create_asr_service(
app_key="your_app_key",
access_key="your_access_key"
)
try:
audio_path = "path/to/your/audio.wav"
# 使用优化回调
await service.recognize_file(
audio_path,
result_callback=text_callback
)
finally:
await service.close()
def run_example(example_name: str = "optimized"):
"""运行指定示例"""
examples = {
"optimized": example_optimized_streaming,
"simple": example_simple_text_extraction,
"manual": example_manual_text_extraction,
"websocket": example_streaming_with_websocket,
"comparison": example_comparison
}
if example_name not in examples:
print(f"可用示例: {list(examples.keys())}")
return
asyncio.run(examples[example_name]())
if __name__ == "__main__":
# 运行优化示例
run_example("optimized")
# 或者运行其他示例
# run_example("simple")
# run_example("manual")
# run_example("websocket")
# run_example("comparison")
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包语音识别WebSocket协议处理模块
实现二进制协议的编解码、消息类型定义和数据包处理
"""
import gzip
import json
from typing import Dict, Any, Tuple, Optional
# 协议版本和头部大小
PROTOCOL_VERSION = 0b0001
DEFAULT_HEADER_SIZE = 0b0001
# 消息类型定义
class MessageType:
FULL_CLIENT_REQUEST = 0b0001
AUDIO_ONLY_REQUEST = 0b0010
FULL_SERVER_RESPONSE = 0b1001
SERVER_ACK = 0b1011
SERVER_ERROR_RESPONSE = 0b1111
# 消息类型特定标志
class MessageFlags:
NO_SEQUENCE = 0b0000
POS_SEQUENCE = 0b0001
NEG_SEQUENCE = 0b0010
NEG_WITH_SEQUENCE = 0b0011
# 序列化方法
class SerializationMethod:
NO_SERIALIZATION = 0b0000
JSON = 0b0001
# 压缩方法
class CompressionType:
NO_COMPRESSION = 0b0000
GZIP = 0b0001
class DoubaoProtocol:
"""豆包ASR WebSocket协议处理器"""
@staticmethod
def generate_header(
message_type: int = MessageType.FULL_CLIENT_REQUEST,
message_type_specific_flags: int = MessageFlags.NO_SEQUENCE,
serial_method: int = SerializationMethod.JSON,
compression_type: int = CompressionType.GZIP,
reserved_data: int = 0x00
) -> bytearray:
"""
生成协议头部
Args:
message_type: 消息类型
message_type_specific_flags: 消息类型特定标志
serial_method: 序列化方法
compression_type: 压缩类型
reserved_data: 保留字段
Returns:
bytearray: 4字节协议头部
"""
header = bytearray()
header_size = 1
header.append((PROTOCOL_VERSION << 4) | header_size)
header.append((message_type << 4) | message_type_specific_flags)
header.append((serial_method << 4) | compression_type)
header.append(reserved_data)
return header
@staticmethod
def generate_sequence_payload(sequence: int) -> bytearray:
"""
生成序列号载荷
Args:
sequence: 序列号
Returns:
bytearray: 4字节序列号数据
"""
payload = bytearray()
payload.extend(sequence.to_bytes(4, 'big', signed=True))
return payload
@staticmethod
def parse_response(response_data: bytes) -> Dict[str, Any]:
"""
解析服务器响应数据
Args:
response_data: 服务器响应的二进制数据
Returns:
Dict: 解析后的响应数据
"""
if len(response_data) < 4:
raise ValueError("响应数据长度不足")
# 解析头部
protocol_version = response_data[0] >> 4
header_size = response_data[0] & 0x0f
message_type = response_data[1] >> 4
message_type_specific_flags = response_data[1] & 0x0f
serialization_method = response_data[2] >> 4
message_compression = response_data[2] & 0x0f
reserved = response_data[3]
# 解析扩展头部和载荷
header_extensions = response_data[4:header_size * 4]
payload = response_data[header_size * 4:]
result = {
'protocol_version': protocol_version,
'header_size': header_size,
'message_type': message_type,
'message_type_specific_flags': message_type_specific_flags,
'serialization_method': serialization_method,
'message_compression': message_compression,
'is_last_package': False,
'payload_msg': None,
'payload_size': 0
}
# 处理序列号
if message_type_specific_flags & 0x01:
if len(payload) >= 4:
seq = int.from_bytes(payload[:4], "big", signed=True)
result['payload_sequence'] = seq
payload = payload[4:]
# 检查是否为最后一包
if message_type_specific_flags & 0x02:
result['is_last_package'] = True
# 根据消息类型解析载荷
payload_msg = None
payload_size = 0
if message_type == MessageType.FULL_SERVER_RESPONSE:
if len(payload) >= 4:
payload_size = int.from_bytes(payload[:4], "big", signed=True)
payload_msg = payload[4:]
elif message_type == MessageType.SERVER_ACK:
if len(payload) >= 4:
seq = int.from_bytes(payload[:4], "big", signed=True)
result['seq'] = seq
if len(payload) >= 8:
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
payload_msg = payload[8:]
elif message_type == MessageType.SERVER_ERROR_RESPONSE:
if len(payload) >= 8:
code = int.from_bytes(payload[:4], "big", signed=False)
result['code'] = code
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
payload_msg = payload[8:]
# 解压缩和反序列化载荷
if payload_msg is not None:
if message_compression == CompressionType.GZIP:
try:
payload_msg = gzip.decompress(payload_msg)
except Exception as e:
result['decompress_error'] = str(e)
return result
if serialization_method == SerializationMethod.JSON:
try:
payload_msg = json.loads(payload_msg.decode('utf-8'))
except Exception as e:
result['json_parse_error'] = str(e)
return result
elif serialization_method != SerializationMethod.NO_SERIALIZATION:
payload_msg = payload_msg.decode('utf-8')
result['payload_msg'] = payload_msg
result['payload_size'] = payload_size
return result
@staticmethod
def build_full_request(
request_params: Dict[str, Any],
sequence: int = 1,
compression: bool = True
) -> bytearray:
"""
构建完整客户端请求
Args:
request_params: 请求参数字典
sequence: 序列号
compression: 是否启用压缩
Returns:
bytearray: 完整的请求数据包
"""
# 序列化请求参数
payload_bytes = json.dumps(request_params).encode('utf-8')
# 压缩载荷
compression_type = CompressionType.GZIP if compression else CompressionType.NO_COMPRESSION
if compression:
payload_bytes = gzip.compress(payload_bytes)
# 生成头部
header = DoubaoProtocol.generate_header(
message_type=MessageType.FULL_CLIENT_REQUEST,
message_type_specific_flags=MessageFlags.POS_SEQUENCE,
compression_type=compression_type
)
# 构建完整请求
request = bytearray(header)
request.extend(DoubaoProtocol.generate_sequence_payload(sequence))
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
return request
@staticmethod
def build_audio_request(
audio_data: bytes,
sequence: int,
is_last: bool = False,
compression: bool = True
) -> bytearray:
"""
构建音频数据请求
Args:
audio_data: 音频数据
sequence: 序列号
is_last: 是否为最后一包
compression: 是否启用压缩
Returns:
bytearray: 音频请求数据包
"""
# 压缩音频数据
compression_type = CompressionType.GZIP if compression else CompressionType.NO_COMPRESSION
payload_bytes = gzip.compress(audio_data) if compression else audio_data
# 确定消息标志
if is_last:
flags = MessageFlags.NEG_WITH_SEQUENCE
sequence = -abs(sequence)
else:
flags = MessageFlags.POS_SEQUENCE
# 生成头部
header = DoubaoProtocol.generate_header(
message_type=MessageType.AUDIO_ONLY_REQUEST,
message_type_specific_flags=flags,
compression_type=compression_type
)
# 构建音频请求
request = bytearray(header)
request.extend(DoubaoProtocol.generate_sequence_payload(sequence))
request.extend(len(payload_bytes).to_bytes(4, 'big'))
request.extend(payload_bytes)
return request
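上述协议的 4 字节头部按位域打包:高 4 位与低 4 位分别承载版本/头部大小、消息类型/标志、序列化方法/压缩方法。下面是一个独立的收发回环示意(常量取值与本模块定义一致,载荷结构对应 `build_full_request`):

```python
import gzip
import json

# 按协议位域构造 4 字节头部(独立示意)
PROTOCOL_VERSION = 0b0001
FULL_CLIENT_REQUEST = 0b0001
POS_SEQUENCE = 0b0001
JSON_SERIAL = 0b0001
GZIP_COMPRESS = 0b0001

header = bytearray()
header.append((PROTOCOL_VERSION << 4) | 1)                 # 版本(高4位) + 头部大小(低4位)
header.append((FULL_CLIENT_REQUEST << 4) | POS_SEQUENCE)   # 消息类型 + 消息标志
header.append((JSON_SERIAL << 4) | GZIP_COMPRESS)          # 序列化方法 + 压缩方法
header.append(0x00)                                        # 保留字段

# 解析端按相同位域取回各字段
assert header[0] >> 4 == PROTOCOL_VERSION
assert header[0] & 0x0F == 1
assert header[1] >> 4 == FULL_CLIENT_REQUEST
assert header[2] & 0x0F == GZIP_COMPRESS

# 载荷:JSON 序列化 + gzip 压缩;头部之后依次为 4 字节序列号、4 字节大端载荷长度
payload = gzip.compress(json.dumps({"user": {"uid": "demo"}}).encode("utf-8"))
packet = bytes(header) + (1).to_bytes(4, "big", signed=True) + len(payload).to_bytes(4, "big") + payload
assert int.from_bytes(packet[8:12], "big") == len(payload)
```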
... ...
# AIfeng/2025-07-17 13:58:00
"""
豆包ASR识别结果处理器
专门处理豆包ASR流式识别结果,提取关键信息并优化日志输出
"""
import logging
from typing import Dict, Any, Optional, Callable
from dataclasses import dataclass
from datetime import datetime
@dataclass
class ASRResult:
"""ASR识别结果数据类"""
text: str
is_final: bool
confidence: float = 0.0
timestamp: datetime = None
sequence: int = 0
def __post_init__(self):
if self.timestamp is None:
self.timestamp = datetime.now()
class DoubaoResultProcessor:
"""豆包ASR结果处理器"""
def __init__(self,
text_only: bool = True,
log_level: str = 'INFO',
enable_streaming_log: bool = False):
"""
初始化结果处理器
Args:
text_only: 是否只输出文本内容
log_level: 日志级别
enable_streaming_log: 是否启用流式日志(会频繁输出中间结果)
"""
self.text_only = text_only
self.enable_streaming_log = enable_streaming_log
self.logger = self._setup_logger(log_level)
# 流式结果管理
self.current_text = ""
self.last_sequence = 0
self.result_count = 0
def _setup_logger(self, log_level: str) -> logging.Logger:
"""设置日志记录器"""
logger = logging.getLogger(f"DoubaoResultProcessor_{id(self)}")
logger.setLevel(getattr(logging, log_level.upper()))
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def extract_text_from_result(self, result: Dict[str, Any]) -> Optional[ASRResult]:
"""
从豆包ASR完整结果中提取文本信息
Args:
result: 豆包ASR返回的完整结果字典
Returns:
ASRResult: 提取的结果对象,如果无有效文本则返回None
"""
try:
# 检查是否有payload_msg
payload_msg = result.get('payload_msg')
if not payload_msg:
return None
# 提取result字段
result_data = payload_msg.get('result')
if not result_data:
return None
# 提取文本
text = result_data.get('text', '').strip()
if not text:
return None
# 提取其他信息
is_final = result.get('is_last_package', False)
confidence = result_data.get('confidence', 0.0)
sequence = result.get('payload_sequence', 0)
return ASRResult(
text=text,
is_final=is_final,
confidence=confidence,
sequence=sequence
)
except Exception as e:
self.logger.error(f"提取ASR结果文本失败: {e}")
return None
def process_streaming_result(self, result: Dict[str, Any]) -> Optional[str]:
"""
处理流式识别结果
Args:
result: 豆包ASR返回的完整结果字典
Returns:
str: 当前识别文本,如果无变化则返回None
"""
asr_result = self.extract_text_from_result(result)
if not asr_result:
return None
self.result_count += 1
# 流式结果:后一次覆盖前一次
previous_text = self.current_text
self.current_text = asr_result.text
self.last_sequence = asr_result.sequence
# 根据配置决定是否记录日志
if asr_result.is_final:
self.logger.info(f"[最终结果] {asr_result.text}")
elif self.enable_streaming_log:
self.logger.debug(f"[流式更新 #{self.result_count}] {asr_result.text}")
# 返回文本(如果与上次不同)
return asr_result.text if asr_result.text != previous_text else None
def create_optimized_callback(self,
user_callback: Optional[Callable[[str], None]] = None) -> Callable[[Dict[str, Any]], None]:
"""
创建优化的回调函数
Args:
user_callback: 用户自定义回调函数,接收文本参数
Returns:
Callable: 优化后的回调函数
"""
def optimized_callback(result: Dict[str, Any]):
"""优化的回调函数,只处理文本内容"""
try:
# 处理流式结果
text = self.process_streaming_result(result)
# 如果有文本变化且用户提供了回调函数
if text and user_callback:
user_callback(text)
except Exception as e:
self.logger.error(f"处理ASR回调失败: {e}")
return optimized_callback
def get_current_result(self) -> Dict[str, Any]:
"""
获取当前识别状态
Returns:
Dict: 当前状态信息
"""
return {
'current_text': self.current_text,
'last_sequence': self.last_sequence,
'result_count': self.result_count,
'text_length': len(self.current_text)
}
def reset(self):
"""重置处理器状态"""
self.current_text = ""
self.last_sequence = 0
self.result_count = 0
self.logger.info("结果处理器已重置")
# 便捷函数
def create_text_only_callback(user_callback: Optional[Callable[[str], None]] = None,
enable_streaming_log: bool = False) -> Callable[[Dict[str, Any]], None]:
"""
创建只处理文本的回调函数
Args:
user_callback: 用户回调函数
enable_streaming_log: 是否启用流式日志
Returns:
Callable: 优化的回调函数
"""
processor = DoubaoResultProcessor(
text_only=True,
enable_streaming_log=enable_streaming_log
)
return processor.create_optimized_callback(user_callback)
def extract_text_only(result: Dict[str, Any]) -> Optional[str]:
"""
从豆包ASR结果中只提取文本
Args:
result: 豆包ASR完整结果
Returns:
str: 提取的文本,如果无文本则返回None
"""
try:
return result.get('payload_msg', {}).get('result', {}).get('text', '').strip() or None
except Exception:
return None
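`extract_text_only` 的取值路径依赖于结果字典的 `payload_msg.result.text` 嵌套结构。下面用一条模拟的流式返回结果做独立示意(字段结构为本模块假定的结构,非官方完整响应):

```python
# 从模拟的豆包流式结果中提取文本(独立示意,逻辑与 extract_text_only 一致)
def extract_text(result: dict):
    try:
        return result.get('payload_msg', {}).get('result', {}).get('text', '').strip() or None
    except Exception:
        return None

mock_result = {
    'is_last_package': False,
    'payload_msg': {'result': {'text': '今天天气不错 '}}
}
print(extract_text(mock_result))          # 今天天气不错(已去除首尾空白)
print(extract_text({'payload_msg': {}}))  # None:无 result 字段
print(extract_text({}))                   # None:空结果
```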
... ...
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR服务工厂模块
提供简化的API接口和服务实例管理
"""
import asyncio
from pathlib import Path
from typing import Dict, Any, Optional, Callable, Union
from .config_manager import ConfigManager
from .asr_client import DoubaoASRClient
class DoubaoASRService:
"""豆包ASR服务工厂"""
_instances = {}
def __init__(self, config: Union[str, Dict[str, Any], ConfigManager]):
"""
初始化ASR服务
Args:
config: 配置文件路径、配置字典或配置管理器实例
"""
if isinstance(config, str):
self.config_manager = ConfigManager(config)
elif isinstance(config, dict):
self.config_manager = ConfigManager.from_dict(config)
elif isinstance(config, ConfigManager):
self.config_manager = config
else:
raise ValueError("配置参数类型错误")
self.client = DoubaoASRClient(self.config_manager.get_config())
async def recognize_file(
self,
audio_path: str,
streaming: bool = True,
result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
识别音频文件
Args:
audio_path: 音频文件路径
streaming: 是否使用流式识别
result_callback: 结果回调函数
**kwargs: 其他参数
Returns:
Dict: 识别结果
"""
return await self.client.recognize_file(
audio_path,
streaming=streaming,
result_callback=result_callback,
**kwargs
)
async def recognize_audio_data(
self,
audio_data: bytes,
streaming: bool = True,
result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
识别音频数据
Args:
audio_data: 音频数据
streaming: 是否使用流式识别
result_callback: 结果回调函数
**kwargs: 其他参数
Returns:
Dict: 识别结果
"""
return await self.client.recognize_audio_data(
audio_data,
streaming=streaming,
result_callback=result_callback,
**kwargs
)
def get_status(self) -> Dict[str, Any]:
"""
获取服务状态
Returns:
Dict: 服务状态
"""
return self.client.get_status()
async def close(self):
"""关闭服务"""
await self.client.close()
@classmethod
def create_service(
cls,
config: Union[str, Dict[str, Any], ConfigManager],
instance_name: str = 'default'
) -> 'DoubaoASRService':
"""
创建或获取服务实例
Args:
config: 配置
instance_name: 实例名称
Returns:
DoubaoASRService: 服务实例
"""
if instance_name not in cls._instances:
cls._instances[instance_name] = cls(config)
return cls._instances[instance_name]
@classmethod
def get_service(cls, instance_name: str = 'default') -> Optional['DoubaoASRService']:
"""
获取已创建的服务实例
Args:
instance_name: 实例名称
Returns:
DoubaoASRService: 服务实例或None
"""
return cls._instances.get(instance_name)
@classmethod
async def close_all_services(cls):
"""关闭所有服务实例"""
for service in cls._instances.values():
await service.close()
cls._instances.clear()
# 便捷函数
def create_asr_service(
config_path: Optional[str] = None,
app_key: Optional[str] = None,
access_key: Optional[str] = None,
**kwargs
) -> DoubaoASRService:
"""
创建ASR服务的便捷函数
Args:
config_path: 配置文件路径
app_key: 应用密钥
access_key: 访问密钥
**kwargs: 其他配置参数
Returns:
DoubaoASRService: ASR服务实例
"""
if config_path:
return DoubaoASRService(config_path)
# 从参数构建配置
config = {
'auth_config': {
'app_key': app_key or '',
'access_key': access_key or ''
}
}
# 添加其他配置参数
if kwargs:
if 'asr_config' not in config:
config['asr_config'] = {}
if 'audio_config' not in config:
config['audio_config'] = {}
if 'connection_config' not in config:
config['connection_config'] = {}
if 'logging_config' not in config:
config['logging_config'] = {}
# 映射常用参数
param_mapping = {
'streaming': ('asr_config', 'streaming_mode'),
'seg_duration': ('asr_config', 'seg_duration'),
'model_name': ('asr_config', 'model_name'),
'enable_punc': ('asr_config', 'enable_punc'),
'sample_rate': ('audio_config', 'default_rate'),
'audio_format': ('audio_config', 'default_format'),
'timeout': ('connection_config', 'timeout'),
'debug': ('logging_config', 'enable_debug')
}
for param, (section, key) in param_mapping.items():
if param in kwargs:
config[section][key] = kwargs[param]
return DoubaoASRService(config)
async def recognize_file(
audio_path: str,
config_path: Optional[str] = None,
app_key: Optional[str] = None,
access_key: Optional[str] = None,
streaming: bool = True,
result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
**kwargs
) -> Dict[str, Any]:
"""
识别音频文件的便捷函数
Args:
audio_path: 音频文件路径
config_path: 配置文件路径
app_key: 应用密钥
access_key: 访问密钥
streaming: 是否使用流式识别
result_callback: 结果回调函数
**kwargs: 其他参数
Returns:
Dict: 识别结果
"""
service = create_asr_service(
config_path=config_path,
app_key=app_key,
access_key=access_key,
**kwargs
)
try:
return await service.recognize_file(
audio_path,
streaming=streaming,
result_callback=result_callback
)
finally:
await service.close()
async def recognize_audio_data(
    audio_data: bytes,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Convenience function for recognizing in-memory audio data.

    Args:
        audio_data: Raw audio bytes
        config_path: Path to a configuration file
        app_key: Application key
        access_key: Access key
        streaming: Whether to use streaming recognition
        result_callback: Callback invoked with intermediate results
        **kwargs: Additional parameters

    Returns:
        Dict: Recognition result
    """
    service = create_asr_service(
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        **kwargs
    )
    try:
        return await service.recognize_audio_data(
            audio_data,
            streaming=streaming,
            result_callback=result_callback
        )
    finally:
        await service.close()
def run_recognition(
    audio_path: str,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Recognize an audio file synchronously.

    Args:
        audio_path: Path to the audio file
        config_path: Path to a configuration file
        app_key: Application key
        access_key: Access key
        streaming: Whether to use streaming recognition
        result_callback: Callback invoked with intermediate results
        **kwargs: Additional parameters

    Returns:
        Dict: Recognition result
    """
    return asyncio.run(
        recognize_file(
            audio_path,
            config_path=config_path,
            app_key=app_key,
            access_key=access_key,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )
    )
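The sync-wrapper pattern used by `run_recognition` can be demonstrated without the doubao service itself. In this sketch, `recognize_stub` and `run_recognition_stub` are hypothetical stand-ins for `recognize_file` and `run_recognition`:

```python
# Minimal standalone illustration of wrapping an async API in asyncio.run()
# for callers that do not run an event loop of their own.
import asyncio
from typing import Any, Dict

async def recognize_stub(audio_path: str) -> Dict[str, Any]:
    await asyncio.sleep(0)  # placeholder for the real websocket round-trip
    return {'text': f'transcript of {audio_path}'}

def run_recognition_stub(audio_path: str) -> Dict[str, Any]:
    # asyncio.run() raises RuntimeError if an event loop is already running,
    # so this wrapper must only be called from synchronous code.
    return asyncio.run(recognize_stub(audio_path))

result = run_recognition_stub('speech.wav')
```

Callers already inside an event loop should `await recognize_file(...)` directly instead of going through the sync wrapper.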
... ...
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-02 10:27:06
FunASR speech recognition module - synchronous version
Based on the synchronous implementation pattern of the eman-Fay-main-copy project
"""
from threading import Thread
import websocket
import json
import time
import ssl
import _thread as thread
import os
import asyncio
import threading

from core import get_web_instance, get_instance
from utils import config_util as cfg
from utils import util


class FunASRSync:
    """Synchronous FunASR client, based on the reference project implementation"""

    def __init__(self, username):
        self.__URL = "ws://{}:{}".format(cfg.local_asr_ip, cfg.local_asr_port)
        self.__ws = None
        self.__connected = False
        self.__frames = []
        self.__state = 0
        self.__closing = False
        self.__task_id = ''
        self.done = False
        self.finalResults = ""
        self.__reconnect_delay = 1
        self.__reconnecting = False
        self.username = username
        self.started = True
        self.__result_callback = None  # result callback hook
        util.log(1, f"FunASR sync client initialized for user: {username}")
    def on_message(self, ws, message):
        """Handle an incoming websocket message"""
        try:
            util.log(1, f"FunASR message received: {message}")
            # Try to parse the message as JSON to distinguish status
            # messages from recognition results
            try:
                parsed_message = json.loads(message)
                # Status messages (e.g. chunk-transfer handshakes)
                if isinstance(parsed_message, dict) and 'status' in parsed_message:
                    status = parsed_message.get('status')
                    if status == 'ready':
                        util.log(1, f"Chunk-transfer ready status: {parsed_message.get('message', '')}")
                        return  # status messages do not trigger the callback
                    elif status in ['processing', 'chunk_received']:
                        util.log(1, f"Processing status: {status}")
                        return  # progress messages do not trigger the callback
                    elif status == 'error':
                        util.log(3, f"Error status: {parsed_message.get('message', '')}")
                        return
                # A dict without 'status' may be a structured recognition result
                if isinstance(parsed_message, dict) and 'text' in parsed_message:
                    recognition_text = parsed_message.get('text', '')
                    if recognition_text.strip():  # only handle non-empty results
                        self.done = True
                        self.finalResults = recognition_text
                        util.log(1, f"Structured recognition result: {recognition_text}")
                        self._trigger_result_callback()
                        return
            except json.JSONDecodeError:
                # Not JSON; may be a plain-text recognition result
                pass
            # Handle plain-text recognition results
            if isinstance(message, str) and message.strip():
                # Filter out obvious status messages
                if any(keyword in message.lower() for keyword in ['status', 'ready', '准备接收', 'processing', 'chunk']):
                    util.log(1, f"Skipping status message: {message}")
                    return
                # This is an actual recognition result
                self.done = True
                self.finalResults = message
                util.log(1, f"Plain-text recognition result: {message}")
                self._trigger_result_callback()
        except Exception as e:
            util.log(3, f"Error while handling recognition result: {e}")
        if self.__closing:
            try:
                self.__ws.close()
            except Exception as e:
                util.log(2, f"Error while closing WebSocket: {e}")
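The branching in `on_message` can be summarized as a small, testable classifier. This sketch simplifies the logic (the keyword filter on plain text and the logging are omitted); `classify` is an illustrative helper, not part of the module:

```python
# Standalone sketch of on_message's classification: status frames are
# skipped, structured results carry 'text', anything else non-empty is
# treated as a plain-text transcript.
import json

def classify(message: str) -> tuple:
    try:
        parsed = json.loads(message)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        if parsed.get('status') in ('ready', 'processing', 'chunk_received', 'error'):
            return ('status', parsed['status'])
        if parsed.get('text', '').strip():
            return ('result', parsed['text'])
        return ('ignored', None)
    return ('result', message) if message.strip() else ('ignored', None)
```

Factoring the classification out of the websocket callback like this makes the protocol handling unit-testable without a live connection.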
    def _trigger_result_callback(self):
        """Invoke the result callback, if one is registered"""
        if self.__result_callback:
            try:
                # Build a chat_message and push it directly
                chat_message = {
                    "type": "chat_message",
                    "sender": "回音",
                    "text": self.finalResults,
                    "Username": self.username,
                    "model_info": "Funasr"
                }
                self.__result_callback(chat_message)
                util.log(1, f"Result callback triggered: {self.finalResults}")
            except Exception as e:
                util.log(3, f"Error while invoking result callback: {e}")
        # Direct push to the web client is disabled (improved async variant kept for reference):
        # try:
        #     # Check that the WSA service is initialized first
        #     web_instance = get_web_instance()
        #     if web_instance and web_instance.is_connected(self.username):
        #         # Build a chat_message and push it directly
        #         chat_message = {
        #             "type": "chat_message",
        #             "sender": "回音",
        #             "content": self.finalResults,
        #             "Username": self.username,
        #             "model_info": "Funasr"
        #         }
        #         # Option 1: push wsa_command-type data via add_cmd
        #         # web_instance.add_cmd(chat_message)
        #         util.log(1, f"FunASR result pushed to web client [{self.username}]: {self.finalResults}")
        #     else:
        #         util.log(2, f"User {self.username} is not connected to the web client; skipping push")
        # except RuntimeError as e:
        #     # WSA service not initialized; expected during startup ordering
        #     util.log(2, f"WSA service not initialized; skipping web client notification: {e}")
        # except Exception as e:
        #     util.log(3, f"Error while sending to web client: {e}")
        # Human-client notification replaced by logging (avoids double-notifying this service)
        # util.log(1, f"FunASR result [{self.username}]: {self.finalResults}")
        if self.__closing:
            try:
                self.__ws.close()
            except Exception as e:
                util.log(2, f"Error while closing WebSocket: {e}")
    def on_close(self, ws, code, msg):
        """Handle websocket close"""
        self.__connected = False
        util.log(2, f"FunASR connection closed: {msg}")
        self.__ws = None

    def on_error(self, ws, error):
        """Handle websocket errors"""
        self.__connected = False
        util.log(3, f"FunASR connection error: {error}")
        self.__ws = None

    def __attempt_reconnect(self):
        """Reconnect with an increasing delay"""
        if not self.__reconnecting:
            self.__reconnecting = True
            util.log(1, "Attempting to reconnect to FunASR...")
            while not self.__connected:
                time.sleep(self.__reconnect_delay)
                self.start()
                self.__reconnect_delay *= 2
            self.__reconnect_delay = 1
            self.__reconnecting = False

    def on_open(self, ws):
        """Handle websocket connection establishment"""
        self.__connected = True
        util.log(1, "FunASR WebSocket connection established")

        def run(*args):
            while self.__connected:
                try:
                    if len(self.__frames) > 0:
                        frame = self.__frames.pop(0)
                        if type(frame) == dict:
                            ws.send(json.dumps(frame))
                        elif type(frame) == bytes:
                            ws.send(frame, websocket.ABNF.OPCODE_BINARY)
                except Exception as e:
                    util.log(3, f"Error while sending frame data: {e}")
                # Send interval reduced from 0.04s to 0.02s for throughput
                time.sleep(0.02)

        thread.start_new_thread(run, ())
    def __connect(self):
        """Establish the WebSocket connection"""
        self.finalResults = ""
        self.done = False
        self.__frames.clear()
        websocket.enableTrace(False)
        self.__ws = websocket.WebSocketApp(
            self.__URL,
            on_message=self.on_message,
            on_close=self.on_close,
            on_error=self.on_error
        )
        self.__ws.on_open = self.on_open
        self.__ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def add_frame(self, frame):
        """Append a frame to the send queue"""
        self.__frames.append(frame)

    def send(self, buf):
        """Queue audio data for sending"""
        self.__frames.append(buf)

    def send_url(self, url):
        """Send an audio file URL"""
        # Use an absolute path; relative paths are invalid for the funasr service
        absolute_url = os.path.abspath(url)
        frame = {'url': absolute_url}
        if self.__ws and self.__connected:
            util.log(1, f"Sending audio file URL to FunASR: {absolute_url}")
            self.__ws.send(json.dumps(frame))
            util.log(1, f"Audio file URL sent: {frame}")
        else:
            util.log(2, f"WebSocket not connected; cannot send URL: {absolute_url}")
    def send_audio_data(self, audio_bytes, filename="audio.wav"):
        """Send audio data (large files are sent in chunks)"""
        try:
            # Ensure audio_bytes is a bytes object to avoid memoryview buffer issues
            if hasattr(audio_bytes, 'tobytes'):
                audio_bytes = bytes(audio_bytes.tobytes())  # fixes BufferError: memoryview has 1 exported buffer
            elif isinstance(audio_bytes, memoryview):
                audio_bytes = bytes(audio_bytes)
            total_size = len(audio_bytes)
            # Chunking threshold: 512 KB. aiohttp's default message limit is
            # 1 MB, and base64 encoding adds roughly 33% overhead.
            large_file_threshold = 512 * 1024
            if total_size > large_file_threshold:
                util.log(1, f"Large file detected ({total_size} bytes); using chunked send mode")
                return self._send_audio_data_chunked(audio_bytes, filename)
            else:
                # Small files use the simple single-frame mode
                return self._send_audio_data_simple(audio_bytes, filename)
        except Exception as e:
            util.log(3, f"Error while sending audio data: {e}")
            return False
    def _send_audio_data_simple(self, audio_bytes, filename):
        """Simple send mode (small files)"""
        import base64
        try:
            # Base64-encode the raw audio bytes
            audio_data_b64 = base64.b64encode(audio_bytes).decode('utf-8')
            # Frame format compatible with the funasr service's process_audio_data
            frame = {
                'audio_data': audio_data_b64,
                'filename': filename
            }
            if self.__ws and self.__connected:
                util.log(1, f"Sending audio data to FunASR: {filename}, size: {len(audio_bytes)} bytes")
                success = self._send_frame_with_retry(frame)
                if success:
                    util.log(1, f"Audio data sent: {filename}")
                    return True
                else:
                    util.log(3, f"Failed to send audio data: {filename}")
                    return False
            else:
                util.log(2, f"WebSocket not connected; cannot send audio data: {filename}")
                return False
        except Exception as e:
            util.log(3, f"Error in simple audio send: {e}")
            return False
    def _send_audio_data_chunked(self, audio_bytes, filename, chunk_size=512*1024):
        """Chunked send mode (large files)"""
        import base64
        import math
        try:
            total_size = len(audio_bytes)
            total_chunks = math.ceil(total_size / chunk_size)
            util.log(1, f"Starting chunked send: {filename}, total size: {total_size} bytes, chunks: {total_chunks}")
            # Send the start frame
            start_frame = {
                'type': 'audio_start',
                'filename': filename,
                'total_size': total_size,
                'total_chunks': total_chunks,
                'chunk_size': chunk_size
            }
            if not self._send_frame_with_retry(start_frame):
                util.log(3, f"Failed to send start frame: {filename}")
                return False
            # Send each chunk
            for i in range(total_chunks):
                start_pos = i * chunk_size
                end_pos = min(start_pos + chunk_size, total_size)
                chunk_data = audio_bytes[start_pos:end_pos]
                # Base64-encode the chunk
                chunk_b64 = base64.b64encode(chunk_data).decode('utf-8')
                chunk_frame = {
                    'type': 'audio_chunk',
                    'filename': filename,
                    'chunk_index': i,
                    'chunk_data': chunk_b64,
                    'is_last': (i == total_chunks - 1)
                }
                # Send the chunk and check the outcome
                success = self._send_frame_with_retry(chunk_frame)
                if not success:
                    util.log(3, f"Chunk {i+1}/{total_chunks} failed to send")
                    return False
                # Progress logging
                if (i + 1) % 10 == 0 or i == total_chunks - 1:
                    progress = ((i + 1) / total_chunks) * 100
                    util.log(1, f"Send progress: {progress:.1f}% ({i+1}/{total_chunks})")
                # Flow-control delay
                time.sleep(0.01)
            # Send the end frame
            end_frame = {
                'type': 'audio_end',
                'filename': filename
            }
            if self._send_frame_with_retry(end_frame):
                util.log(1, f"Chunked audio send complete: {filename}")
                return True
            else:
                util.log(3, f"Failed to send end frame: {filename}")
                return False
        except Exception as e:
            util.log(3, f"Error in chunked audio send: {e}")
            return False
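The chunked transfer format above (start frame, base64 chunks, end frame) can be exercised end-to-end without a websocket. This sketch mirrors the frame layout; the `reassemble` receiver side is hypothetical, since the real FunASR service defines its own handler:

```python
# Standalone round-trip of the audio_start / audio_chunk / audio_end format:
# split raw bytes into base64 frames, then reassemble them.
import base64
import math

def make_frames(audio: bytes, filename: str, chunk_size: int = 4):
    total_chunks = math.ceil(len(audio) / chunk_size)
    yield {'type': 'audio_start', 'filename': filename,
           'total_size': len(audio), 'total_chunks': total_chunks}
    for i in range(total_chunks):
        chunk = audio[i * chunk_size:(i + 1) * chunk_size]
        yield {'type': 'audio_chunk', 'filename': filename, 'chunk_index': i,
               'chunk_data': base64.b64encode(chunk).decode('utf-8'),
               'is_last': i == total_chunks - 1}
    yield {'type': 'audio_end', 'filename': filename}

def reassemble(frames) -> bytes:
    # A toy receiver: concatenate decoded chunks in arrival order
    parts = [base64.b64decode(f['chunk_data'])
             for f in frames if f['type'] == 'audio_chunk']
    return b''.join(parts)

payload = b'0123456789'
frames = list(make_frames(payload, 'audio.wav'))
```

The tiny `chunk_size=4` is for illustration only; the production code uses 512 KB per chunk.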
    def _send_frame_with_retry(self, frame, max_retries=3, timeout=10):
        """Send a frame with retries"""
        for attempt in range(max_retries):
            try:
                if self.__ws and self.__connected:
                    # Crude send-timing check
                    start_time = time.time()
                    self.__ws.send(json.dumps(frame))
                    time.sleep(0.05)  # give the send a moment to complete
                    if time.time() - start_time < timeout:
                        return True
                    else:
                        util.log(2, f"Send timed out, attempt {attempt + 1}/{max_retries}")
                else:
                    util.log(2, f"Connection unavailable, attempt {attempt + 1}/{max_retries}")
            except Exception as e:
                util.log(2, f"Send failed, attempt {attempt + 1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                time.sleep(0.5 * (attempt + 1))  # linearly increasing backoff
        return False
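The retry structure above, stripped of the websocket specifics, is a plain retry loop with a linearly increasing delay. In this sketch, `send_with_retry` and `flaky` are illustrative names; the injectable `send` callable replaces `self.__ws.send`:

```python
# Standalone sketch of a retry loop with a linearly increasing delay,
# in the spirit of _send_frame_with_retry (sleep(0.5 * (attempt + 1))
# between attempts in the real code).
import time

def send_with_retry(send, frame, max_retries: int = 3, base_delay: float = 0.0):
    for attempt in range(max_retries):
        try:
            send(frame)
            return True
        except Exception:
            if attempt < max_retries - 1:
                time.sleep(base_delay * (attempt + 1))
    return False

attempts = []
def flaky(frame):
    # Simulated transport that fails twice, then succeeds
    attempts.append(frame)
    if len(attempts) < 3:
        raise ConnectionError('transient failure')

ok = send_with_retry(flaky, {'state': 'StartTranscription'})
```

Injecting the `send` callable makes the retry policy testable in isolation; `base_delay=0.0` is used here only to keep the example fast.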
    def set_result_callback(self, callback):
        """Register the result callback"""
        self.__result_callback = callback
        util.log(1, "Result callback registered")

    def connect(self):
        """Connect to the FunASR service (synchronous)"""
        try:
            if not self.__connected:
                self.start()  # reuse the existing start method
                # Wait up to 30 seconds for the connection (tuned for large files)
                max_wait_time = 30.0
                wait_interval = 0.1
                waited_time = 0.0
                while not self.__connected and waited_time < max_wait_time:
                    time.sleep(wait_interval)
                    waited_time += wait_interval
                    # Log progress roughly every 5 seconds
                    if waited_time % 5.0 < wait_interval:
                        util.log(1, f"Waiting for FunASR connection... {waited_time:.1f}s/{max_wait_time}s")
                if self.__connected:
                    util.log(1, f"FunASR connected after {waited_time:.2f}s")
                else:
                    util.log(3, f"FunASR connection timed out after {waited_time:.2f}s")
                return self.__connected
            return True
        except Exception as e:
            util.log(3, f"Error while connecting to FunASR service: {e}")
            return False
    def start(self):
        """Start the FunASR client"""
        Thread(target=self.__connect, args=[]).start()
        data = {
            'vad_need': False,
            'state': 'StartTranscription'
        }
        self.add_frame(data)
        util.log(1, "FunASR client started")

    def is_connected(self):
        """Check the connection state"""
        return self.__connected

    def end(self):
        """Shut down the FunASR client"""
        if self.__connected:
            try:
                # Flush any remaining frames (pop instead of iterating the
                # list while mutating it, which skipped every other frame)
                while self.__frames:
                    frame = self.__frames.pop(0)
                    if type(frame) == dict:
                        self.__ws.send(json.dumps(frame))
                    elif type(frame) == bytes:
                        self.__ws.send(frame, websocket.ABNF.OPCODE_BINARY)
                # Send the stop signal
                frame = {'vad_need': False, 'state': 'StopTranscription'}
                self.__ws.send(json.dumps(frame))
            except Exception as e:
                util.log(3, f"Error while shutting down FunASR: {e}")
        self.__closing = True
        self.__connected = False
        util.log(1, "FunASR client stopped")
... ...
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-02 11:24:08
Scheduler module initialization
"""
from .thread_manager import MyThread

__all__ = ['MyThread']
... ...
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIfeng/2025-07-02 11:24:08
Thread manager - enhanced threading utilities
"""
import threading
import time
import traceback
from typing import Callable, Any, Optional
from utils.util import log
class MyThread(threading.Thread):
    """Enhanced thread class with better error handling and monitoring"""

    def __init__(self, target: Optional[Callable] = None, name: Optional[str] = None,
                 args: tuple = (), kwargs: dict = None, daemon: bool = True):
        """
        Initialize the thread.

        Args:
            target: Target function
            name: Thread name
            args: Positional arguments
            kwargs: Keyword arguments
            daemon: Whether the thread is a daemon
        """
        super().__init__(target=target, name=name, args=args, kwargs=kwargs or {}, daemon=daemon)
        self._target = target
        self._args = args
        self._kwargs = kwargs or {}
        self._result = None
        self._exception = None
        self._start_time = None
        self._end_time = None
        self._running = False
    def run(self):
        """Override run() to add error handling and monitoring"""
        self._start_time = time.time()
        self._running = True
        try:
            if self._target:
                log(1, f"Thread {self.name} started")
                self._result = self._target(*self._args, **self._kwargs)
                log(1, f"Thread {self.name} finished")
        except Exception as e:
            self._exception = e
            log(3, f"Thread {self.name} raised an error: {e}")
            log(3, f"Error details: {traceback.format_exc()}")
        finally:
            self._end_time = time.time()
            self._running = False
            duration = self._end_time - self._start_time
            log(1, f"Thread {self.name} ran for {duration:.2f}s")
    def get_result(self) -> Any:
        """Return the thread's result"""
        if self.is_alive():
            raise RuntimeError("Thread is still running")
        if self._exception:
            raise self._exception
        return self._result

    def get_exception(self) -> Optional[Exception]:
        """Return any exception raised during execution"""
        return self._exception

    def get_duration(self) -> Optional[float]:
        """Return the thread's running time in seconds"""
        if self._start_time is None:
            return None
        end_time = self._end_time or time.time()
        return end_time - self._start_time

    def is_running(self) -> bool:
        """Check whether the thread is currently running"""
        return self._running and self.is_alive()

    def stop_gracefully(self, timeout: float = 5.0) -> bool:
        """Stop the thread gracefully by waiting for it to finish.

        Args:
            timeout: How long to wait, in seconds

        Returns:
            bool: Whether the thread stopped in time
        """
        if not self.is_alive():
            return True
        log(1, f"Stopping thread {self.name}")
        # Wait for the thread to finish on its own
        self.join(timeout=timeout)
        if self.is_alive():
            log(2, f"Thread {self.name} did not finish within {timeout}s")
            return False
        else:
            log(1, f"Thread {self.name} stopped successfully")
            return True

    def __str__(self) -> str:
        """String representation"""
        status = "running" if self.is_running() else "stopped"
        duration = self.get_duration()
        duration_str = f", duration: {duration:.2f}s" if duration is not None else ""
        return f"MyThread(name={self.name}, status={status}{duration_str})"

    def __repr__(self) -> str:
        """Detailed string representation"""
        return self.__str__()
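The result/exception capture that `MyThread` performs can be shown in a minimal standalone form (logging omitted, since `utils.util.log` is project-specific). `ResultThread` here is an illustrative reduction, not part of the module:

```python
# Minimal standalone version of the capture pattern MyThread implements:
# store the target's return value and any exception on the thread object
# so the caller can inspect them after join().
import threading

class ResultThread(threading.Thread):
    def __init__(self, target, *args, **kwargs):
        super().__init__(daemon=True)
        self._target_fn = target
        self._args = args
        self._kwargs = kwargs
        self.result = None
        self.exception = None

    def run(self):
        try:
            self.result = self._target_fn(*self._args, **self._kwargs)
        except Exception as e:  # capture instead of losing it in the thread
            self.exception = e

t = ResultThread(lambda x: x * 2, 21)
t.start()
t.join()
```

Without this pattern, an exception raised in a thread's target is printed to stderr and lost; capturing it lets `get_result()` re-raise in the caller's context.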
class ThreadManager:
    """Thread manager for coordinating multiple threads"""

    def __init__(self):
        self._threads = {}
        self._lock = threading.Lock()

    def create_thread(self, name: str, target: Callable, args: tuple = (),
                      kwargs: dict = None, daemon: bool = True) -> MyThread:
        """Create a new thread.

        Args:
            name: Thread name
            target: Target function
            args: Positional arguments
            kwargs: Keyword arguments
            daemon: Whether the thread is a daemon

        Returns:
            MyThread: The created thread
        """
        with self._lock:
            if name in self._threads:
                raise ValueError(f"Thread name '{name}' already exists")
            thread = MyThread(target=target, name=name, args=args,
                              kwargs=kwargs or {}, daemon=daemon)
            self._threads[name] = thread
            return thread
    def start_thread(self, name: str) -> bool:
        """Start the named thread.

        Args:
            name: Thread name

        Returns:
            bool: Whether the thread started successfully
        """
        with self._lock:
            if name not in self._threads:
                log(3, f"Thread '{name}' does not exist")
                return False
            thread = self._threads[name]
            if thread.is_alive():
                log(2, f"Thread '{name}' is already running")
                return False
            try:
                thread.start()
                log(1, f"Thread '{name}' started")
                return True
            except Exception as e:
                log(3, f"Failed to start thread '{name}': {e}")
                return False

    def stop_thread(self, name: str, timeout: float = 5.0) -> bool:
        """Stop the named thread.

        Args:
            name: Thread name
            timeout: How long to wait, in seconds

        Returns:
            bool: Whether the thread stopped in time
        """
        with self._lock:
            if name not in self._threads:
                log(3, f"Thread '{name}' does not exist")
                return False
            thread = self._threads[name]
        # Join outside the lock so other manager calls are not blocked
        return thread.stop_gracefully(timeout)

    def stop_all_threads(self, timeout: float = 5.0) -> bool:
        """Stop all managed threads.

        Args:
            timeout: How long to wait per thread, in seconds

        Returns:
            bool: Whether every thread stopped in time
        """
        log(1, "Stopping all threads...")
        success = True
        with self._lock:
            for name, thread in self._threads.items():
                if thread.is_alive():
                    if not thread.stop_gracefully(timeout):
                        success = False
        if success:
            log(1, "All threads stopped successfully")
        else:
            log(2, "Some threads did not stop within the given time")
        return success

    def get_thread_status(self) -> dict:
        """Return the status of all managed threads.

        Returns:
            dict: Per-thread status information
        """
        status = {}
        with self._lock:
            for name, thread in self._threads.items():
                status[name] = {
                    'alive': thread.is_alive(),
                    'running': thread.is_running(),
                    'duration': thread.get_duration(),
                    'exception': str(thread.get_exception()) if thread.get_exception() else None
                }
        return status

    def cleanup_finished_threads(self):
        """Remove threads that have finished"""
        with self._lock:
            finished_threads = [name for name, thread in self._threads.items()
                                if not thread.is_alive()]
            for name in finished_threads:
                del self._threads[name]
                log(1, f"Cleaned up finished thread: {name}")

    def __len__(self) -> int:
        """Number of managed threads"""
        with self._lock:
            return len(self._threads)

    def __contains__(self, name: str) -> bool:
        """Check whether a thread with the given name is managed"""
        with self._lock:
            return name in self._threads


# Global thread manager instance
thread_manager = ThreadManager()
... ...