config_admin.py 3.78 KB
from __future__ import annotations

import importlib
import sys
from pathlib import Path
from typing import Any, Dict, Iterable

from loguru import logger

from services.shared.config import read_settings_values, resolve_settings_write_env_file

# Dotted import path of the shared settings module; _load_config_module()
# imports or reloads it after the .env file has been rewritten.
CONFIG_MODULE_NAME = "services.shared.config"
# Parent of the directory containing this file (two `.parent` hops up from
# this file's resolved path). Passed as `project_root` when resolving which
# .env file write_config_values() should write to.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Whitelist of configuration keys exposed to the config admin UI.
# Used as the default `keys` for read_config_values() and the default
# `allowed_keys` for filter_config_updates(). It is a shared mutable list used
# as a default argument, so treat it as read-only.
CONFIG_KEYS = [
    # Server bind address.
    "HOST",
    "PORT",
    # Database connection settings.
    "DB_DIALECT",
    "DB_HOST",
    "DB_PORT",
    "DB_USER",
    "DB_PASSWORD",
    "DB_NAME",
    "DB_CHARSET",
    # Per-component model endpoint credentials: each component has an
    # API key, base URL and model name.
    "INSIGHT_ENGINE_API_KEY",
    "INSIGHT_ENGINE_BASE_URL",
    "INSIGHT_ENGINE_MODEL_NAME",
    "MEDIA_ENGINE_API_KEY",
    "MEDIA_ENGINE_BASE_URL",
    "MEDIA_ENGINE_MODEL_NAME",
    "QUERY_ENGINE_API_KEY",
    "QUERY_ENGINE_BASE_URL",
    "QUERY_ENGINE_MODEL_NAME",
    "REPORT_ENGINE_API_KEY",
    "REPORT_ENGINE_BASE_URL",
    "REPORT_ENGINE_MODEL_NAME",
    "FORUM_HOST_API_KEY",
    "FORUM_HOST_BASE_URL",
    "FORUM_HOST_MODEL_NAME",
    "KEYWORD_OPTIMIZER_API_KEY",
    "KEYWORD_OPTIMIZER_BASE_URL",
    "KEYWORD_OPTIMIZER_MODEL_NAME",
    # Web-search provider selection and provider credentials.
    "TAVILY_API_KEY",
    "SEARCH_TOOL_TYPE",
    "BOCHA_WEB_SEARCH_API_KEY",
    "ANSPIRE_API_KEY",
]


def _load_config_module():
    """Import the shared config module, or reload it if already imported.

    Intended to run after the .env file changes. Reloading re-executes the
    module body, which presumably re-reads the settings from .env.

    Returns:
        The (re)loaded module object, or ``None`` if the module cannot be
        found (``ModuleNotFoundError``).
    """
    # Drop finder caches so a freshly created/changed module file is seen.
    importlib.invalidate_caches()
    cached = sys.modules.get(CONFIG_MODULE_NAME)
    try:
        return (
            importlib.import_module(CONFIG_MODULE_NAME)
            if cached is None
            else importlib.reload(cached)
        )
    except ModuleNotFoundError:
        return None


def read_config_values(keys: Iterable[str] = CONFIG_KEYS) -> Dict[str, str]:
    """Fetch the requested configuration values, as strings, for the UI.

    Delegates to ``read_settings_values`` with ``reload=True``. Any exception
    is logged with its traceback and an empty dict is returned instead of
    propagating the error.
    """
    values: Dict[str, str] = {}
    try:
        values = read_settings_values(keys, reload=True)
    except Exception as exc:  # pragma: no cover - defensive fallback
        logger.exception(f"Failed to read config values: {exc}")
    return values


def filter_config_updates(payload: Dict[str, Any], allowed_keys: Iterable[str] = CONFIG_KEYS) -> Dict[str, Any]:
    """Return a copy of *payload* restricted to the supported config keys.

    Entries whose key is not in *allowed_keys* are dropped. ``None`` values
    are replaced with ``""``; all other values pass through unchanged. The
    payload's key order is preserved.
    """
    permitted = frozenset(allowed_keys)
    filtered: Dict[str, Any] = {}
    for name, value in payload.items():
        if name not in permitted:
            continue
        filtered[name] = "" if value is None else value
    return filtered


def _format_env_value(raw_value: Any) -> str:
    """Render *raw_value* as the right-hand side of a ``KEY=value`` .env line.

    ``None`` and ``""`` become an empty value. Bools become ``True``/``False``.
    Ints and floats use ``str()``. Any other value is stringified and wrapped
    in double quotes when it contains whitespace, ``#`` or a quote character.
    Inside the quotes, backslashes, double quotes, CR and LF are
    backslash-escaped, so the result always fits on one physical line.
    """
    if raw_value is None or raw_value == "":
        return ""
    # bool is a subclass of int, so it must be tested before the numeric
    # branch; otherwise this case is unreachable.
    if isinstance(raw_value, bool):
        return "True" if raw_value else "False"
    if isinstance(raw_value, (int, float)):
        return str(raw_value)

    value_str = str(raw_value)
    # Quote anything a dotenv-style parser could otherwise truncate (inline
    # "#" comments, surrounding whitespace) or misread as an already-quoted
    # value (a leading ' or ").
    needs_quotes = (
        any(ch.isspace() for ch in value_str)
        or "#" in value_str
        or '"' in value_str
        or "'" in value_str
    )
    if not needs_quotes:
        return value_str
    # Escape CR/LF instead of writing them raw. write_config_values re-reads
    # the file line by line, so a raw newline would split one value across
    # lines and could smuggle an extra KEY=... line into the file.
    # NOTE(review): assumes the settings loader decodes dotenv-style escapes
    # in double-quoted values (python-dotenv does); the escaping of \\ and \"
    # below already relied on that.
    escaped = (
        value_str.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\r", "\\r")
        .replace("\n", "\\n")
    )
    return f'"{escaped}"'


def write_config_values(updates: Dict[str, Any]) -> None:
    """Persist config updates to the shared .env file used by Settings.

    The target file comes from ``resolve_settings_write_env_file``. An existing
    ``KEY=...`` line for an updated key is rewritten in place. Keys not yet
    present are appended. Comments, blank lines and all other lines are kept.
    Line endings are normalised to ``\\n``. The parent directory is created if
    needed, and the shared config module is reloaded afterwards.

    Args:
        updates: Mapping of env-var name to new value. Values are rendered by
            ``_format_env_value``.
    """
    env_file_path = resolve_settings_write_env_file(
        cwd=Path.cwd(),
        project_root=PROJECT_ROOT,
    )

    env_lines = []
    env_key_indices: Dict[str, int] = {}
    if env_file_path.exists():
        env_lines = env_file_path.read_text(encoding="utf-8").splitlines()
        for index, line in enumerate(env_lines):
            stripped = line.strip()
            if not stripped or stripped.startswith("#") or "=" not in stripped:
                continue
            # Later duplicates overwrite earlier ones, so the last definition
            # of a key is the one that gets rewritten (the one a
            # last-definition-wins loader would honour).
            env_key_indices[stripped.partition("=")[0].strip()] = index

    for key, raw_value in updates.items():
        new_line = f"{key}={_format_env_value(raw_value)}"
        if key in env_key_indices:
            env_lines[env_key_indices[key]] = new_line
        else:
            env_lines.append(new_line)

    env_file_path.parent.mkdir(parents=True, exist_ok=True)
    env_file_path.write_text("\n".join(env_lines) + "\n", encoding="utf-8")
    # Re-import the settings module so later reads reflect the new file.
    _load_config_module()