You need to sign in or sign up before continuing.
config.py 9.6 KB
"""
Configuration management module for the Insight Engine.
Handles environment variables and config file parameters.
"""

import os
from dataclasses import dataclass
from typing import Optional


def _get_value(source, key: str, default=None):
    """
    Helper to fetch a configuration value with environment fallback.
    """
    value = None
    if isinstance(source, dict):
        value = source.get(key)
    else:
        value = getattr(source, key, None)

    if value is None:
        value = os.getenv(key, default)
    return value if value not in ("", None) else default


@dataclass
class Config:
    """Insight Engine configuration.

    Usually built with :meth:`from_file`. That method reads either a Python module
    (``*.py``) or a ``.env``-style ``KEY=VALUE`` file. For each key, a value
    missing from the file is looked up in the environment and then falls back
    to the defaults below.
    """

    # LLM configuration
    llm_api_key: Optional[str] = None
    llm_base_url: Optional[str] = None
    llm_model_name: Optional[str] = None
    llm_provider: Optional[str] = None  # kept for backward compatibility

    # Database configuration
    db_host: Optional[str] = None
    db_user: Optional[str] = None
    db_password: Optional[str] = None
    db_name: Optional[str] = None
    db_port: int = 3306
    db_charset: str = "utf8mb4"

    # Model behaviour configuration
    max_reflections: int = 3
    max_paragraphs: int = 6
    search_timeout: int = 240
    max_content_length: int = 500000

    # Search result limits
    default_search_hot_content_limit: int = 100
    default_search_topic_globally_limit_per_table: int = 50
    default_search_topic_by_date_limit_per_table: int = 100
    default_get_comments_for_topic_limit: int = 500
    default_search_topic_on_platform_limit: int = 200
    # NOTE(review): 0 presumably means "no limit" for the two settings below -- confirm with consumers.
    max_search_results_for_llm: int = 0
    max_high_confidence_sentiment_results: int = 0

    # Output configuration
    output_dir: str = "reports"
    save_intermediate_states: bool = True

    def __post_init__(self):
        if not self.llm_provider and self.llm_model_name:
            # Provider is no longer used, but keep the attribute for compatibility.
            self.llm_provider = self.llm_model_name

    def validate(self) -> bool:
        """Check that required settings are present.

        Prints an error message describing the first missing setting.

        Returns:
            True when the LLM API key, model name and all four database
            connection fields (host, user, password, name) are set; else False.
        """
        if not self.llm_api_key:
            print("错误: Insight Engine LLM API Key 未设置 (INSIGHT_ENGINE_API_KEY)。")
            return False

        if not self.llm_model_name:
            print("错误: Insight Engine 模型名称未设置 (INSIGHT_ENGINE_MODEL_NAME)。")
            return False

        if not all([self.db_host, self.db_user, self.db_password, self.db_name]):
            print("错误: 数据库连接信息不完整,请检查 config.py 中的 DB_* 配置。")
            return False

        return True

    @staticmethod
    def _parse_bool(value) -> bool:
        """Interpret a setting as a boolean: "true", "1" or "yes" (any case) is True."""
        return str(value).lower() in ("true", "1", "yes")

    @staticmethod
    def _strip_quotes(value: str) -> str:
        """Remove one pair of matching surrounding quotes (``"v"`` / ``'v'``), as in .env files."""
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            return value[1:-1]
        return value

    @classmethod
    def _read_env_file(cls, config_file: str) -> dict:
        """Parse a ``KEY=VALUE`` file into a dict; a missing file yields ``{}``.

        Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
        An optional leading ``export`` keyword is accepted. Values keep any
        inner ``=`` characters and have one pair of surrounding quotes removed.
        """
        config_dict = {}
        if not os.path.exists(config_file):
            return config_dict
        # utf-8-sig drops a leading BOM (common in files saved by Windows
        # editors) that would otherwise be glued onto the first key.
        with open(config_file, "r", encoding="utf-8-sig") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                key = key.strip()
                if key.startswith("export "):
                    key = key[len("export "):].strip()
                config_dict[key] = cls._strip_quotes(value.strip())
        return config_dict

    @classmethod
    def _from_source(cls, source) -> "Config":
        """Build a Config from a loaded module or dict of raw setting values."""

        def text(key, default=None):
            return _get_value(source, key, default)

        def integer(key, default):
            return int(_get_value(source, key, default))

        return cls(
            llm_api_key=text("INSIGHT_ENGINE_API_KEY"),
            llm_base_url=text("INSIGHT_ENGINE_BASE_URL"),
            llm_model_name=text("INSIGHT_ENGINE_MODEL_NAME"),
            db_host=text("DB_HOST"),
            db_user=text("DB_USER"),
            db_password=text("DB_PASSWORD"),
            db_name=text("DB_NAME"),
            db_port=integer("DB_PORT", 3306),
            db_charset=text("DB_CHARSET", "utf8mb4"),
            max_reflections=integer("MAX_REFLECTIONS", 3),
            max_paragraphs=integer("MAX_PARAGRAPHS", 6),
            search_timeout=integer("SEARCH_TIMEOUT", 240),
            max_content_length=integer("SEARCH_CONTENT_MAX_LENGTH", 500000),
            default_search_hot_content_limit=integer("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100),
            default_search_topic_globally_limit_per_table=integer(
                "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50
            ),
            default_search_topic_by_date_limit_per_table=integer(
                "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100
            ),
            default_get_comments_for_topic_limit=integer("DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500),
            default_search_topic_on_platform_limit=integer(
                "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200
            ),
            max_search_results_for_llm=integer("MAX_SEARCH_RESULTS_FOR_LLM", 0),
            max_high_confidence_sentiment_results=integer("MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0),
            output_dir=text("OUTPUT_DIR", "reports"),
            save_intermediate_states=cls._parse_bool(text("SAVE_INTERMEDIATE_STATES", "true")),
        )

    @classmethod
    def from_file(cls, config_file: str) -> "Config":
        """Create configuration from a ``*.py`` module or a ``.env``-style file.

        Raises:
            ValueError: an integer setting holds a non-integer value.
        """
        if config_file.endswith(".py"):
            import importlib.util

            # NOTE: this executes the file as Python code -- only load trusted configs.
            spec = importlib.util.spec_from_file_location("config", config_file)
            config_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(config_module)
            return cls._from_source(config_module)

        return cls._from_source(cls._read_env_file(config_file))


def load_config(config_file: Optional[str] = None) -> Config:
    """
    Locate, load and validate the Insight Engine configuration.

    When ``config_file`` is falsy, the current working directory is searched
    for ``config.py``, ``config.env`` and ``.env`` (in that order). The first
    one that exists is used, and its name is printed.

    Raises:
        FileNotFoundError: the given file does not exist, or no candidate file
            was found during discovery.
        ValueError: the loaded configuration fails ``Config.validate()``.
    """
    if not config_file:
        discovered = next(
            (name for name in ("config.py", "config.env", ".env") if os.path.exists(name)),
            None,
        )
        if discovered is None:
            raise FileNotFoundError("未找到配置文件,请创建 config.py。")
        print(f"已找到配置文件: {discovered}")
        config_file = discovered
    elif not os.path.exists(config_file):
        raise FileNotFoundError(f"配置文件不存在: {config_file}")

    loaded = Config.from_file(config_file)
    if loaded.validate():
        return loaded
    raise ValueError("配置校验失败,请检查 config.py 中的相关配置。")


def print_config(config: Config):
    """
    Print a human-readable summary of ``config`` to stdout.

    Secrets are never echoed. The API key and the database credentials are
    only reported as configured or not configured.
    """
    print("\n=== Insight Engine 配置 ===")

    api_key_state = '已配置' if config.llm_api_key else '未配置'
    db_parts = (config.db_host, config.db_user, config.db_password, config.db_name)
    db_state = '已配置' if all(db_parts) else '未配置'

    rows = (
        ("LLM 模型", config.llm_model_name),
        ("LLM Base URL", config.llm_base_url or '(默认)'),
        ("搜索超时", f"{config.search_timeout} 秒"),
        ("最长内容长度", config.max_content_length),
        ("最大反思次数", config.max_reflections),
        ("最大段落数", config.max_paragraphs),
        ("输出目录", config.output_dir),
        ("保存中间状态", config.save_intermediate_states),
        ("LLM API Key", api_key_state),
        ("数据库连接", db_state),
    )
    for label, shown in rows:
        print(f"{label}: {shown}")

    print("========================\n")