search_dispatch.py 12.8 KB
"""Runtime adapters for local search dispatch execution."""

from __future__ import annotations

import threading
from pathlib import Path
from typing import Any, Callable

from loguru import logger

from apps.web_api.runtime.engine_registry import ENGINE_RUNTIME_REGISTRY, EngineRuntimeRegistry
from apps.web_api.runtime.process_registry import PROCESS_RUNTIME_REGISTRY, ProcessRuntimeRegistry
from apps.web_api.runtime.task_runtime_store import TASK_RUNTIME_STORE, TaskRuntimeStore
from services.application.analysis import AnalysisExecutionContext, AnalysisService, SearchRequestSubmission
from services.shared.config.access import get_database_runtime_settings, get_settings
from services.shared.models import EngineContext, EngineExecutionError, EngineResult
from utils.runtime_paths import INSIGHT_REPORTS_DIR, MEDIA_REPORTS_DIR, QUERY_REPORTS_DIR

# Callback type for writing logs; this module only passes it through into the
# execution context and never calls it. It takes a Path plus two strings —
# NOTE(review): which string is the log text vs. e.g. a level/app name is not
# visible here; confirm against AnalysisService.
LogWriter = Callable[[Path, str, str], None]
# Injected dispatcher; build_search_request_submitter calls it with the keywords
# research_task_id, query, execution_context and research_task_service, and
# forwards the dict it returns.
SearchDispatcher = Callable[..., dict[str, Any]]
# Resolves a search payload into a pair of strings — presumably
# (research_task_id, query); TODO confirm against AnalysisService.resolve_search_query.
SearchQueryResolver = Callable[..., tuple[str, str]]
# Submits a search request and returns (response payload, HTTP status code).
SearchRequestSubmitter = Callable[..., tuple[dict[str, Any], int]]

# Maps a local runner name (the value returned by the engine registry's
# get_local_runner) to a zero-argument factory that builds the engine agent.
# Populated once the builder functions in this module have been defined.
LOCAL_ENGINE_BUILDERS: dict[str, Callable[[], Any]] = {}


def build_query_agent():
    """Create the Query engine agent configured from the application settings.

    The Query engine package is imported at call time rather than when this
    module is imported. An unset or empty ``SEARCH_TOOL_TYPE`` falls back to
    ``"AnspireAPI"``; reports are written under ``QUERY_REPORTS_DIR``.
    """
    from services.engines.query import (
        DeepSearchAgent as _QueryDeepSearchAgent,
        Settings as _QueryEngineSettings,
    )

    app_settings = get_settings()
    query_config = _QueryEngineSettings(
        QUERY_ENGINE_API_KEY=app_settings.QUERY_ENGINE_API_KEY,
        QUERY_ENGINE_BASE_URL=app_settings.QUERY_ENGINE_BASE_URL,
        QUERY_ENGINE_MODEL_NAME=app_settings.QUERY_ENGINE_MODEL_NAME,
        SEARCH_TOOL_TYPE=app_settings.SEARCH_TOOL_TYPE or "AnspireAPI",
        ANSPIRE_API_KEY=app_settings.ANSPIRE_API_KEY,
        BOCHA_WEB_SEARCH_API_KEY=app_settings.BOCHA_WEB_SEARCH_API_KEY,
        TAVILY_API_KEY=app_settings.TAVILY_API_KEY,
        OUTPUT_DIR=str(QUERY_REPORTS_DIR),
    )
    query_agent = _QueryDeepSearchAgent(query_config)
    return query_agent


def build_media_agent():
    """Create the Media engine agent configured from the application settings.

    The agent class depends on the search tool: ``"BochaAPI"`` selects the
    generic ``DeepSearchAgent``. Any other value selects the Anspire agent,
    including the ``"AnspireAPI"`` fallback used when the setting is unset or
    empty. Reports are written under ``MEDIA_REPORTS_DIR``.
    """
    from services.engines.media import (
        AnspireSearchAgent as _MediaAnspireSearchAgent,
        DeepSearchAgent as _MediaDeepSearchAgent,
        Settings as _MediaEngineSettings,
    )

    app_settings = get_settings()
    tool_type = app_settings.SEARCH_TOOL_TYPE or "AnspireAPI"
    media_config = _MediaEngineSettings(
        MEDIA_ENGINE_API_KEY=app_settings.MEDIA_ENGINE_API_KEY,
        MEDIA_ENGINE_BASE_URL=app_settings.MEDIA_ENGINE_BASE_URL,
        MEDIA_ENGINE_MODEL_NAME=app_settings.MEDIA_ENGINE_MODEL_NAME,
        SEARCH_TOOL_TYPE=tool_type,
        BOCHA_WEB_SEARCH_API_KEY=app_settings.BOCHA_WEB_SEARCH_API_KEY,
        ANSPIRE_API_KEY=app_settings.ANSPIRE_API_KEY,
        OUTPUT_DIR=str(MEDIA_REPORTS_DIR),
    )
    agent_class = (
        _MediaDeepSearchAgent if tool_type == "BochaAPI" else _MediaAnspireSearchAgent
    )
    return agent_class(media_config)


def build_insight_agent():
    """Create the Insight engine agent from application and database settings.

    Model credentials come from ``get_settings()``. Database connection fields
    (host, user, password, name, port, charset, dialect) come from
    ``get_database_runtime_settings()``. Reports are written under
    ``INSIGHT_REPORTS_DIR``.
    """
    from services.engines.insight import (
        DeepSearchAgent as _InsightDeepSearchAgent,
        Settings as _InsightEngineSettings,
    )

    app_settings = get_settings()
    db = get_database_runtime_settings()
    insight_config = _InsightEngineSettings(
        INSIGHT_ENGINE_API_KEY=app_settings.INSIGHT_ENGINE_API_KEY,
        INSIGHT_ENGINE_BASE_URL=app_settings.INSIGHT_ENGINE_BASE_URL,
        INSIGHT_ENGINE_MODEL_NAME=app_settings.INSIGHT_ENGINE_MODEL_NAME,
        DB_HOST=db.host,
        DB_USER=db.user,
        DB_PASSWORD=db.password,
        DB_NAME=db.name,
        DB_PORT=db.port,
        DB_CHARSET=db.charset,
        DB_DIALECT=db.dialect,
        OUTPUT_DIR=str(INSIGHT_REPORTS_DIR),
    )
    insight_agent = _InsightDeepSearchAgent(insight_config)
    return insight_agent


# Register the local agent factories. Keys are the runner names that
# run_local_engine_research looks up after asking the engine registry for the
# engine's local runner.
LOCAL_ENGINE_BUILDERS["query"] = build_query_agent
LOCAL_ENGINE_BUILDERS["media"] = build_media_agent
LOCAL_ENGINE_BUILDERS["insight"] = build_insight_agent


class _LocalEngineUnavailableError(ValueError):
    """An engine name cannot be mapped to a registered local agent builder.

    This is a configuration problem, so retrying the same request cannot succeed.
    """


def _resolve_local_agent_builder(
    app_name: str,
    engine_registry: EngineRuntimeRegistry | None,
) -> Callable[[], Any]:
    """Return the agent factory registered for ``app_name``'s local runner.

    Uses ``engine_registry`` when given, otherwise the process-wide
    ``ENGINE_RUNTIME_REGISTRY``.

    Raises:
        _LocalEngineUnavailableError: the engine is unknown to the registry, or
            its local runner has no entry in ``LOCAL_ENGINE_BUILDERS``.
    """
    registry = engine_registry if engine_registry is not None else ENGINE_RUNTIME_REGISTRY
    local_runner = (
        registry.get_local_runner(app_name) if registry.contains(app_name) else None
    )
    if not local_runner:
        raise _LocalEngineUnavailableError(f"Unknown engine: {app_name}")
    agent_builder = LOCAL_ENGINE_BUILDERS.get(local_runner)
    if agent_builder is None:
        raise _LocalEngineUnavailableError(
            f"No local runner is registered for engine: {app_name}"
        )
    return agent_builder


def _failed_engine_result(
    app_name: str,
    exc: Exception,
    *,
    retryable: bool,
) -> EngineResult:
    """Build the ``failed`` EngineResult reported for ``exc``."""
    # Some exceptions (e.g. TimeoutError()) stringify to "", which would leave
    # the summary and error message empty; fall back to the exception type name.
    message = str(exc) or type(exc).__name__
    return EngineResult(
        engine_name=app_name,
        status="failed",
        success=False,
        summary=message,
        error=EngineExecutionError(
            code="engine_execution_failed",
            message=message,
            retryable=retryable,
            details={"engine_name": app_name},
        ),
    )


def run_local_engine_research(
    context: EngineContext,
    *,
    engine_registry: EngineRuntimeRegistry | None = None,
) -> EngineResult:
    """Run one engine's research locally and wrap the outcome in an EngineResult.

    Reads ``context.engine_name`` and ``context.query``. Resolves the engine's
    agent factory through ``engine_registry`` (default:
    ``ENGINE_RUNTIME_REGISTRY``) and ``LOCAL_ENGINE_BUILDERS``. Builds the agent
    and calls ``agent.research(query, save_report=True)``.

    Returns:
        On success, a ``completed`` result whose artifacts hold the first 300
        characters of the report and whose metrics come from the agent's
        progress summary (best-effort).

        Exceptions raised while resolving, building or running the agent become a
        ``failed`` result instead of propagating:

        - Unknown or unregistered engines are marked ``retryable=False``.
        - Every other error is marked ``retryable=True``.
    """
    app_name = context.engine_name
    query = context.query
    try:
        agent_builder = _resolve_local_agent_builder(app_name, engine_registry)
        agent = agent_builder()

        report = agent.research(query, save_report=True)
        metrics = _collect_engine_metrics(agent)
        return EngineResult(
            engine_name=app_name,
            status="completed",
            success=True,
            summary="research completed",
            artifacts={
                "report_preview": str(report)[:300],
            },
            metrics=metrics,
        )
    except _LocalEngineUnavailableError as exc:
        # Configuration error: a traceback adds nothing, and retrying cannot help.
        logger.error(f"{app_name} local research failed: {exc}")
        return _failed_engine_result(app_name, exc, retryable=False)
    except Exception as exc:
        logger.exception(f"{app_name} local research failed: {exc}")
        return _failed_engine_result(app_name, exc, retryable=True)


def _collect_engine_metrics(agent: Any) -> dict[str, Any]:
    get_progress_summary = getattr(agent, "get_progress_summary", None)
    if not callable(get_progress_summary):
        return {}

    try:
        summary = get_progress_summary()
    except Exception:
        logger.exception("Failed to collect engine progress summary")
        return {}

    return dict(summary) if isinstance(summary, dict) else {}


def build_analysis_service(
    *,
    research_task_service,
    engine_registry: EngineRuntimeRegistry | None = None,
    task_runtime_store: TaskRuntimeStore | None = None,
    thread_factory: Callable[..., Any] = threading.Thread,
) -> AnalysisService:
    """Construct an AnalysisService wired to this module's runtime adapters.

    Unspecified dependencies fall back to the process-wide defaults:
    ``ENGINE_RUNTIME_REGISTRY`` for engines and ``TASK_RUNTIME_STORE`` for task
    state. The service's engine runner runs research locally via
    ``run_local_engine_research`` against the chosen registry.
    """
    registry = ENGINE_RUNTIME_REGISTRY if engine_registry is None else engine_registry
    store = TASK_RUNTIME_STORE if task_runtime_store is None else task_runtime_store

    def _engine_runner(context: EngineContext) -> EngineResult:
        # Bind the registry chosen above so the service passes only the context.
        return run_local_engine_research(context, engine_registry=registry)

    return AnalysisService(
        research_task_service,
        engine_runner=_engine_runner,
        task_runtime_store=store,
        thread_factory=thread_factory,
    )


def build_analysis_execution_context(
    *,
    process_registry: ProcessRuntimeRegistry | None = None,
    check_app_status: Callable[[], None],
    log_dir: Path,
    write_log: LogWriter,
) -> AnalysisExecutionContext:
    """Bundle the runtime collaborators used by analysis dispatch into one context.

    ``process_registry`` defaults to the process-wide
    ``PROCESS_RUNTIME_REGISTRY``; the other arguments are stored unchanged.
    """
    registry = PROCESS_RUNTIME_REGISTRY if process_registry is None else process_registry
    return AnalysisExecutionContext(
        process_registry=registry,
        check_app_status=check_app_status,
        log_dir=log_dir,
        write_log=write_log,
    )


def build_search_request_submitter(
    *,
    research_task_service,
    resolve_search_query: SearchQueryResolver,
    dispatch_search_request: SearchDispatcher,
    check_app_status: Callable[[], None],
    log_dir: Path,
    write_log: LogWriter,
    process_registry: ProcessRuntimeRegistry | None = None,
    analysis_service: AnalysisService | None = None,
) -> SearchRequestSubmitter:
    """Return a ``submit(payload=...)`` callable with every runtime dependency bound.

    The analysis service and a single execution context are created once, here:
    the service is built on demand when ``analysis_service`` is not supplied.
    Each submission goes through ``AnalysisService.submit_search_request``, and
    the resulting submission is mapped to a ``(payload, status_code)`` pair.
    """
    bound_service = _resolve_analysis_service(
        research_task_service=research_task_service,
        analysis_service=analysis_service,
    )
    bound_context = build_analysis_execution_context(
        process_registry=process_registry,
        check_app_status=check_app_status,
        log_dir=log_dir,
        write_log=write_log,
    )

    def _dispatch_with_bound_context(
        *,
        research_task_id: str,
        query: str,
    ) -> dict[str, Any]:
        # The service calls this with only the task id and query; the injected
        # dispatcher additionally receives the shared context and task service.
        return dispatch_search_request(
            research_task_id=research_task_id,
            query=query,
            execution_context=bound_context,
            research_task_service=research_task_service,
        )

    def _submit_search_request(
        *,
        payload: dict[str, Any],
    ) -> tuple[dict[str, Any], int]:
        outcome = bound_service.submit_search_request(
            payload=payload,
            execution_context=bound_context,
            research_task_service=research_task_service,
            resolve_search_query=resolve_search_query,
            dispatch_search_request=_dispatch_with_bound_context,
        )
        return _map_search_submission_to_http_response(outcome)

    return _submit_search_request


def _map_search_submission_to_http_response(
    submission: SearchRequestSubmission,
) -> tuple[dict[str, Any], int]:
    status_code = 400 if submission.kind == "rejected" else 200
    return submission.payload, status_code


def _resolve_analysis_service(
    *,
    research_task_service,
    analysis_service: AnalysisService | None,
) -> AnalysisService:
    if analysis_service is not None:
        return analysis_service
    return build_analysis_service(
        research_task_service=research_task_service,
        thread_factory=threading.Thread,
    )


class SearchDispatchRuntime:
    """Thin object wrapper over AnalysisService search operations.

    An AnalysisService may be bound at construction time. Every method also
    accepts a per-call ``analysis_service``, which takes precedence over the
    bound one. When neither is available, ``_resolve_analysis_service`` builds a
    new default service for that call.
    """

    def __init__(
        self,
        *,
        analysis_service: AnalysisService | None = None,
    ) -> None:
        # Fallback service for calls that do not supply their own.
        self._analysis_service = analysis_service

    def _get_analysis_service(
        self,
        *,
        research_task_service,
        analysis_service: AnalysisService | None = None,
    ) -> AnalysisService:
        """Prefer the per-call service, then the bound one, then a freshly built one."""
        preferred = self._analysis_service if analysis_service is None else analysis_service
        return _resolve_analysis_service(
            research_task_service=research_task_service,
            analysis_service=preferred,
        )

    def execute_search_dispatch_async(
        self,
        *,
        research_task_id: str,
        query: str,
        running_apps: list[str],
        research_task_service,
        log_dir: Path,
        write_log: LogWriter,
        analysis_service: AnalysisService | None = None,
    ) -> None:
        """Forward to ``AnalysisService.execute_search_dispatch_async``."""
        target_service = self._get_analysis_service(
            research_task_service=research_task_service,
            analysis_service=analysis_service,
        )
        target_service.execute_search_dispatch_async(
            research_task_id=research_task_id,
            query=query,
            running_apps=running_apps,
            research_task_service=research_task_service,
            log_dir=log_dir,
            write_log=write_log,
        )

    def dispatch_search_request(
        self,
        *,
        research_task_id: str,
        query: str,
        process_registry: ProcessRuntimeRegistry | None = None,
        check_app_status: Callable[[], None],
        research_task_service,
        log_dir: Path,
        write_log: LogWriter,
        analysis_service: AnalysisService | None = None,
    ) -> dict[str, Any]:
        """Build an execution context and forward to ``AnalysisService.dispatch_search_request``."""
        target_service = self._get_analysis_service(
            research_task_service=research_task_service,
            analysis_service=analysis_service,
        )
        dispatch_context = build_analysis_execution_context(
            process_registry=process_registry,
            check_app_status=check_app_status,
            log_dir=log_dir,
            write_log=write_log,
        )
        return target_service.dispatch_search_request(
            research_task_id=research_task_id,
            query=query,
            execution_context=dispatch_context,
            research_task_service=research_task_service,
        )

    def resolve_search_query(
        self,
        *,
        payload: dict[str, Any],
        research_task_service,
        analysis_service: AnalysisService | None = None,
    ) -> tuple[str, str]:
        """Forward to ``AnalysisService.resolve_search_query``."""
        target_service = self._get_analysis_service(
            research_task_service=research_task_service,
            analysis_service=analysis_service,
        )
        return target_service.resolve_search_query(
            payload=payload,
            research_task_service=research_task_service,
        )


def build_search_dispatch_runtime(
    *,
    analysis_service: AnalysisService | None = None,
) -> SearchDispatchRuntime:
    """Create a SearchDispatchRuntime, optionally pre-bound to ``analysis_service``."""
    runtime = SearchDispatchRuntime(analysis_service=analysis_service)
    return runtime


# Public surface of this module. AnalysisService is re-exported here; it is
# imported from services.application.analysis.
__all__ = [
    "AnalysisService",
    "SearchDispatchRuntime",
    "build_analysis_execution_context",
    "build_search_request_submitter",
    "build_analysis_service",
    "build_insight_agent",
    "build_media_agent",
    "build_query_agent",
    "build_search_dispatch_runtime",
    "run_local_engine_research",
]