config_util.py 8.1 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AIfeng/2025-01-02 10:27:06
配置管理模块
提供系统配置的读取和管理功能
"""

import os
import json
from typing import Dict, Any, Optional

# 默认配置
DEFAULT_CONFIG = {
    'source': {
        'wake_word_enabled': False,
        'wake_word': '小助手,你好',
        'wake_word_type': 'common'  # common 或 front
    },
    'audio': {
        'sample_rate': 16000,
        'channels': 1,
        'chunk_size': 1024,
        'device_index': None
    },
    'asr': {
        'mode': 'funasr',
        'timeout': 30,
        'reconnect_delay': 1,
        'max_reconnect_attempts': 5
    },
    'server': {
        'host': '0.0.0.0',
        'port': 5050,
        'debug': False
    },
    'logging': {
        'level': 'INFO',
        'file': 'logs/system.log',
        'max_size': 10485760,  # 10MB
        'backup_count': 5
    }
}

# 配置文件路径
CONFIG_FILE = 'config.json'

# 全局配置对象
config = {}

# ASR相关配置
ASR_mode = "funasr"  # 默认使用FunASR
local_asr_ip = "127.0.0.1"
local_asr_port = 10197
asr_timeout = 30
asr_reconnect_delay = 1
asr_max_reconnect_attempts = 5

# 服务器配置
fay_url = "http://localhost:5050"

def load_config(config_file: str = CONFIG_FILE) -> Dict[str, Any]:
    """
    加载配置文件
    
    Args:
        config_file: 配置文件路径
        
    Returns:
        配置字典
    """
    global config
    
    try:
        if os.path.exists(config_file):
            with open(config_file, 'r', encoding='utf-8') as f:
                loaded_config = json.load(f)
                
            # 合并默认配置和加载的配置
            config = merge_config(DEFAULT_CONFIG, loaded_config)
            print(f"配置文件已加载: {config_file}")
        else:
            # 使用默认配置
            config = DEFAULT_CONFIG.copy()
            print(f"配置文件不存在,使用默认配置: {config_file}")
            
            # 保存默认配置到文件
            save_config(config_file)
            
    except Exception as e:
        print(f"加载配置文件时出错: {e}")
        config = DEFAULT_CONFIG.copy()
        
    # 更新全局变量
    update_global_vars()
    
    return config

def save_config(config_file: str = CONFIG_FILE, config_data: Optional[Dict[str, Any]] = None) -> bool:
    """
    保存配置到文件
    
    Args:
        config_file: 配置文件路径
        config_data: 要保存的配置数据,如果为None则使用全局config
        
    Returns:
        是否保存成功
    """
    try:
        data_to_save = config_data if config_data is not None else config
        
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(data_to_save, f, ensure_ascii=False, indent=2)
            
        print(f"配置已保存到: {config_file}")
        return True
        
    except Exception as e:
        print(f"保存配置文件时出错: {e}")
        return False

def merge_config(default: Dict[str, Any], loaded: Dict[str, Any]) -> Dict[str, Any]:
    """
    合并配置字典
    
    Args:
        default: 默认配置
        loaded: 加载的配置
        
    Returns:
        合并后的配置
    """
    result = default.copy()
    
    for key, value in loaded.items():
        if key in result and isinstance(result[key], dict) and isinstance(value, dict):
            result[key] = merge_config(result[key], value)
        else:
            result[key] = value
            
    return result

def update_global_vars():
    """
    更新全局变量
    """
    global ASR_mode, local_asr_ip, local_asr_port
    global asr_timeout, asr_reconnect_delay, asr_max_reconnect_attempts
    global fay_url
    
    # ASR配置
    asr_config = config.get('asr', {})
    ASR_mode = asr_config.get('mode', 'funasr')
    asr_timeout = asr_config.get('timeout', 30)
    asr_reconnect_delay = asr_config.get('reconnect_delay', 1)
    asr_max_reconnect_attempts = asr_config.get('max_reconnect_attempts', 5)
    
    # 服务器配置
    server_config = config.get('server', {})
    server_port = server_config.get('port', 5050)
    fay_url = f"http://localhost:{server_port}"

def get_config(key_path: str, default_value: Any = None) -> Any:
    """
    获取配置值
    
    Args:
        key_path: 配置键路径,使用点号分隔,如 'audio.sample_rate'
        default_value: 默认值
        
    Returns:
        配置值
    """
    keys = key_path.split('.')
    value = config
    
    try:
        for key in keys:
            value = value[key]
        return value
    except (KeyError, TypeError):
        return default_value

def set_config(key_path: str, value: Any) -> bool:
    """
    设置配置值
    
    Args:
        key_path: 配置键路径,使用点号分隔
        value: 要设置的值
        
    Returns:
        是否设置成功
    """
    keys = key_path.split('.')
    
    try:
        current = config
        for key in keys[:-1]:
            if key not in current:
                current[key] = {}
            current = current[key]
            
        current[keys[-1]] = value
        
        # 更新全局变量
        update_global_vars()
        
        return True
        
    except Exception as e:
        print(f"设置配置值时出错: {e}")
        return False

def get_audio_config() -> Dict[str, Any]:
    """
    获取音频配置
    
    Returns:
        音频配置字典
    """
    return config.get('audio', DEFAULT_CONFIG['audio'])

def get_asr_config() -> Dict[str, Any]:
    """
    获取ASR配置
    
    Returns:
        ASR配置字典
    """
    return config.get('asr', DEFAULT_CONFIG['asr'])

def get_server_config() -> Dict[str, Any]:
    """
    获取服务器配置
    
    Returns:
        服务器配置字典
    """
    return config.get('server', DEFAULT_CONFIG['server'])

def is_wake_word_enabled() -> bool:
    """
    检查是否启用唤醒词
    
    Returns:
        是否启用唤醒词
    """
    return config.get('source', {}).get('wake_word_enabled', False)

def get_wake_words() -> list:
    """
    获取唤醒词列表
    
    Returns:
        唤醒词列表
    """
    wake_word = config.get('source', {}).get('wake_word', '小助手,你好')
    return [word.strip() for word in wake_word.split(',') if word.strip()]

def get_wake_word_type() -> str:
    """
    获取唤醒词类型
    
    Returns:
        唤醒词类型 ('common' 或 'front')
    """
    return config.get('source', {}).get('wake_word_type', 'common')

def validate_config() -> list:
    """
    验证配置的有效性
    
    Returns:
        错误信息列表
    """
    errors = []
    
    # 验证音频配置
    audio_config = get_audio_config()
    if audio_config.get('sample_rate', 0) <= 0:
        errors.append("音频采样率必须大于0")
        
    if audio_config.get('channels', 0) <= 0:
        errors.append("音频声道数必须大于0")
        
    if audio_config.get('chunk_size', 0) <= 0:
        errors.append("音频块大小必须大于0")
    
    # 验证ASR配置
    asr_config = get_asr_config()
    if asr_config.get('timeout', 0) <= 0:
        errors.append("ASR超时时间必须大于0")
        
    # 验证服务器配置
    server_config = get_server_config()
    port = server_config.get('port', 0)
    if not (1 <= port <= 65535):
        errors.append("服务器端口必须在1-65535范围内")
    
    return errors

def print_config():
    """
    打印当前配置
    """
    print("当前配置:")
    print(json.dumps(config, ensure_ascii=False, indent=2))
    
    print("\n全局变量:")
    print(f"ASR_mode: {ASR_mode}")
    print(f"local_asr_ip: {local_asr_ip}")
    print(f"local_asr_port: {local_asr_port}")
    print(f"asr_timeout: {asr_timeout}")
    print(f"fay_url: {fay_url}")

# 初始化配置
load_config()

if __name__ == "__main__":
    # 测试配置功能
    print("配置模块测试")
    print_config()
    
    # 验证配置
    errors = validate_config()
    if errors:
        print("\n配置验证错误:")
        for error in errors:
            print(f"  - {error}")
    else:
        print("\n配置验证通过")