# service_factory.py (8.55 KB)
# AIfeng/2025-07-11 13:36:00
"""
豆包ASR服务工厂模块
提供简化的API接口和服务实例管理
"""

import asyncio
from pathlib import Path
from typing import Dict, Any, Optional, Callable, Union

from .config_manager import ConfigManager
from .asr_client import DoubaoASRClient


class DoubaoASRService:
    """Facade over ``DoubaoASRClient`` plus a registry of named instances.

    Each instance resolves its configuration through a ``ConfigManager`` and
    forwards recognition, status and shutdown calls to one
    ``DoubaoASRClient``. The class-level registry lets callers share
    instances by name via ``create_service`` / ``get_service`` /
    ``close_all_services``.
    """

    # Registry of named instances; stored on the class, so it is shared by
    # every caller rather than being per-instance state.
    _instances: Dict[str, 'DoubaoASRService'] = {}

    def __init__(self, config: Union[str, Path, Dict[str, Any], 'ConfigManager']):
        """Build the service from a config source.

        Args:
            config: Path to a config file (``str`` or ``pathlib.Path``), a
                config dictionary, or an existing ``ConfigManager``.

        Raises:
            ValueError: If ``config`` is none of the supported types.
        """
        if isinstance(config, (str, Path)):
            # Path objects are normalised to str so ConfigManager keeps
            # receiving the same argument type it always has.
            self.config_manager = ConfigManager(str(config))
        elif isinstance(config, dict):
            self.config_manager = ConfigManager.from_dict(config)
        elif isinstance(config, ConfigManager):
            self.config_manager = config
        else:
            raise ValueError("配置参数类型错误")

        self.client = DoubaoASRClient(self.config_manager.get_config())

    async def recognize_file(
        self,
        audio_path: str,
        streaming: bool = True,
        result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Recognize an audio file by delegating to the client.

        Args:
            audio_path: Path of the audio file.
            streaming: Whether to use streaming recognition.
            result_callback: Optional callback passed through to the client.
            **kwargs: Extra keyword arguments forwarded to the client as-is.

        Returns:
            Dict: Whatever ``self.client.recognize_file`` returns.
        """
        return await self.client.recognize_file(
            audio_path,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )

    async def recognize_audio_data(
        self,
        audio_data: bytes,
        streaming: bool = True,
        result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Recognize in-memory audio data by delegating to the client.

        Args:
            audio_data: Raw audio bytes.
            streaming: Whether to use streaming recognition.
            result_callback: Optional callback passed through to the client.
            **kwargs: Extra keyword arguments forwarded to the client as-is.

        Returns:
            Dict: Whatever ``self.client.recognize_audio_data`` returns.
        """
        return await self.client.recognize_audio_data(
            audio_data,
            streaming=streaming,
            result_callback=result_callback,
            **kwargs
        )

    def get_status(self) -> Dict[str, Any]:
        """Return the status reported by the underlying client.

        Returns:
            Dict: The result of ``self.client.get_status()``.
        """
        return self.client.get_status()

    async def close(self):
        """Close the underlying client.

        The instance is not removed from the class registry by this call.
        """
        await self.client.close()

    @classmethod
    def create_service(
        cls,
        config: Union[str, Path, Dict[str, Any], 'ConfigManager'],
        instance_name: str = 'default'
    ) -> 'DoubaoASRService':
        """Create a named service instance, or return the existing one.

        Args:
            config: Config source used only when no instance is registered
                under ``instance_name`` yet; otherwise it is ignored.
            instance_name: Registry key for the instance.

        Returns:
            DoubaoASRService: The instance registered under ``instance_name``.
        """
        if instance_name not in cls._instances:
            cls._instances[instance_name] = cls(config)
        return cls._instances[instance_name]

    @classmethod
    def get_service(cls, instance_name: str = 'default') -> Optional['DoubaoASRService']:
        """Look up a previously created service instance.

        Args:
            instance_name: Registry key for the instance.

        Returns:
            DoubaoASRService: The registered instance, or ``None`` if there
            is none under that name.
        """
        return cls._instances.get(instance_name)

    @classmethod
    async def close_all_services(cls):
        """Close every registered service and empty the registry.

        Every service gets a close attempt even if an earlier one fails, and
        the registry is emptied regardless.

        Raises:
            Exception: The first error raised by a service's ``close()``,
                re-raised after all services have been attempted.
        """
        # Snapshot and detach before awaiting: the registry may be mutated
        # by other coroutines while we are suspended, and it must end up
        # empty even when a close fails.
        services = list(cls._instances.values())
        cls._instances.clear()

        first_error = None
        for service in services:
            try:
                await service.close()
            except Exception as exc:
                # Keep going so one broken service cannot leak the others.
                if first_error is None:
                    first_error = exc

        if first_error is not None:
            raise first_error


# 便捷函数
def create_asr_service(
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    **kwargs
) -> DoubaoASRService:
    """Convenience constructor for a ``DoubaoASRService``.

    When ``config_path`` is truthy the service is built from that file and
    every other argument is ignored. Otherwise a config dictionary is
    assembled from the credentials and the recognised keyword options.

    Args:
        config_path: Path to a config file.
        app_key: Application key; a missing/empty value is stored as ``''``.
        access_key: Access key; a missing/empty value is stored as ``''``.
        **kwargs: Optional settings. Recognised names are ``streaming``,
            ``seg_duration``, ``model_name``, ``enable_punc``,
            ``sample_rate``, ``audio_format``, ``timeout`` and ``debug``;
            any other name is ignored.

    Returns:
        DoubaoASRService: The newly created service.
    """
    if config_path:
        return DoubaoASRService(config_path)

    settings = {
        'auth_config': {
            'app_key': app_key or '',
            'access_key': access_key or ''
        }
    }

    if not kwargs:
        # Credentials only: no optional sections are added.
        return DoubaoASRService(settings)

    # As soon as any keyword option is supplied, all four optional sections
    # exist in the config (possibly empty).
    for section_name in ('asr_config', 'audio_config',
                         'connection_config', 'logging_config'):
        settings[section_name] = {}

    # (keyword name, target section, key inside that section)
    option_table = (
        ('streaming', 'asr_config', 'streaming_mode'),
        ('seg_duration', 'asr_config', 'seg_duration'),
        ('model_name', 'asr_config', 'model_name'),
        ('enable_punc', 'asr_config', 'enable_punc'),
        ('sample_rate', 'audio_config', 'default_rate'),
        ('audio_format', 'audio_config', 'default_format'),
        ('timeout', 'connection_config', 'timeout'),
        ('debug', 'logging_config', 'enable_debug'),
    )
    for option, section_name, target_key in option_table:
        if option in kwargs:
            settings[section_name][target_key] = kwargs[option]

    return DoubaoASRService(settings)


async def recognize_file(
    audio_path: str,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """One-shot helper: create a service, recognize a file, close the service.

    Args:
        audio_path: Path of the audio file.
        config_path: Path to a config file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback passed to the service.
        **kwargs: Forwarded to ``create_asr_service`` (not to the
            recognition call itself).

    Returns:
        Dict: The recognition result returned by the service.
    """
    asr = create_asr_service(
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        **kwargs
    )

    try:
        outcome = await asr.recognize_file(
            audio_path,
            streaming=streaming,
            result_callback=result_callback
        )
    finally:
        # The temporary service is always closed, success or failure.
        await asr.close()
    return outcome


async def recognize_audio_data(
    audio_data: bytes,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """One-shot helper: create a service, recognize audio bytes, close it.

    Args:
        audio_data: Raw audio bytes.
        config_path: Path to a config file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback passed to the service.
        **kwargs: Forwarded to ``create_asr_service`` (not to the
            recognition call itself).

    Returns:
        Dict: The recognition result returned by the service.
    """
    asr = create_asr_service(
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        **kwargs
    )

    try:
        outcome = await asr.recognize_audio_data(
            audio_data,
            streaming=streaming,
            result_callback=result_callback
        )
    finally:
        # The temporary service is always closed, success or failure.
        await asr.close()
    return outcome


def run_recognition(
    audio_path: str,
    config_path: Optional[str] = None,
    app_key: Optional[str] = None,
    access_key: Optional[str] = None,
    streaming: bool = True,
    result_callback: Optional[Callable[[Dict[str, Any]], None]] = None,
    **kwargs
) -> Dict[str, Any]:
    """Synchronous wrapper around the module-level ``recognize_file``.

    The coroutine is driven with ``asyncio.run``.

    Args:
        audio_path: Path of the audio file.
        config_path: Path to a config file.
        app_key: Application key.
        access_key: Access key.
        streaming: Whether to use streaming recognition.
        result_callback: Optional callback passed along to ``recognize_file``.
        **kwargs: Extra keyword arguments passed along to ``recognize_file``.

    Returns:
        Dict: The recognition result.
    """
    pending = recognize_file(
        audio_path,
        config_path=config_path,
        app_key=app_key,
        access_key=access_key,
        streaming=streaming,
        result_callback=result_callback,
        **kwargs
    )
    return asyncio.run(pending)