task_runtime_store.py 4.54 KB
"""Ephemeral task-scoped runtime snapshots for the web API."""

from __future__ import annotations

import copy
from datetime import datetime
from threading import Lock
from typing import Any

from pydantic import Field

from services.shared.models import EngineResult
from services.shared.models.common import ErrorInfo, ProgressInfo, SharedModel

_UNSET = object()


class TaskRuntimeSnapshot(SharedModel):
    """Task-centered runtime projection for in-flight orchestration state."""

    # Id of the research task; TaskRuntimeStore keys its entries by this value.
    research_task_id: str
    # Free-form status string; freshly created snapshots start as "draft".
    status: str = "draft"
    generated_query: str = ""
    # Ids of linked jobs/runs. TaskRuntimeStore.upsert strips surrounding
    # whitespace and stores blank values as None.
    crawler_job_id: str | None = None
    analysis_run_id: str | None = None
    report_job_id: str | None = None
    # Label of the most recent action, as supplied to TaskRuntimeStore.upsert.
    last_action: str = ""
    progress: ProgressInfo = Field(default_factory=ProgressInfo)
    # None when no error is recorded; upsert(error=None) clears it.
    error: ErrorInfo | None = None
    # Engine names; upsert stringifies them and drops empty names.
    engines: list[str] = Field(default_factory=list)
    # Per-engine payloads keyed by engine name, produced by upsert through
    # EngineResult.from_raw(...).to_runtime_payload().
    partial_results: dict[str, Any] = Field(default_factory=dict)
    # Arbitrary caller-supplied metrics.
    metrics: dict[str, Any] = Field(default_factory=dict)
    # Stamped by TaskRuntimeStore.upsert with datetime.now() (naive local time).
    updated_at: datetime | None = None


class TaskRuntimeStore:
    """Thread-safe, in-memory store for task runtime snapshots.

    Every public method holds a single ``threading.Lock`` while touching the
    backing dict, and every snapshot handed back to callers is a deep copy, so
    callers cannot change stored state except through :meth:`upsert`.
    """

    def __init__(self) -> None:
        # research_task_id -> latest snapshot for that task.
        self._entries: dict[str, TaskRuntimeSnapshot] = {}
        self._lock = Lock()

    def get_task(self, research_task_id: str) -> TaskRuntimeSnapshot | None:
        """Return a deep copy of the snapshot for *research_task_id*.

        Returns ``None`` when no snapshot is stored for that id.
        """
        with self._lock:
            entry = self._entries.get(research_task_id)
            return entry.model_copy(deep=True) if entry is not None else None

    def snapshot_all(self) -> dict[str, TaskRuntimeSnapshot]:
        """Return deep copies of all stored snapshots, keyed by research task id."""
        with self._lock:
            return {
                task_id: entry.model_copy(deep=True)
                for task_id, entry in self._entries.items()
            }

    def upsert(
        self,
        research_task_id: str,
        *,
        status: str | None = None,
        generated_query: str | None = None,
        crawler_job_id: str | None | object = _UNSET,
        analysis_run_id: str | None | object = _UNSET,
        report_job_id: str | None | object = _UNSET,
        last_action: str | None = None,
        progress: ProgressInfo | None = None,
        error: ErrorInfo | None | object = _UNSET,
        engines: list[str] | None = None,
        partial_results: dict[str, Any] | None = None,
        metrics: dict[str, Any] | None = None,
    ) -> TaskRuntimeSnapshot:
        """Create or partially update the snapshot for *research_task_id*.

        Only the fields whose argument was supplied are changed:

        * ``status``, ``generated_query``, ``last_action``, ``progress``,
          ``engines``, ``partial_results`` and ``metrics`` treat ``None`` as
          "leave unchanged".
        * ``crawler_job_id``, ``analysis_run_id``, ``report_job_id`` and
          ``error`` default to the ``_UNSET`` sentinel, so an explicit
          ``None`` clears the stored value. Linked ids are stripped, and blank
          ids are stored as ``None``.

        The update is all-or-nothing. Changes are applied to a copy, and that
        copy replaces the stored snapshot only after every field has been
        assigned. An exception part-way (e.g. from ``EngineResult.from_raw``,
        or from model validation if the model validates assignments) leaves
        the previous snapshot untouched.

        ``updated_at`` is stamped with ``datetime.now()`` (naive local time) on
        every successful call.

        Returns:
            A deep copy of the stored snapshot after the update.
        """
        with self._lock:
            stored = self._entries.get(research_task_id)
            # Work on a shallow copy, not on the stored object. Every
            # assignment below replaces a whole field value (nothing is
            # mutated in place), so a shallow copy is enough to keep the
            # stored snapshot intact if any step raises.
            entry = (
                TaskRuntimeSnapshot(research_task_id=research_task_id)
                if stored is None
                else stored.model_copy()
            )

            if status is not None:
                entry.status = str(status)
            if generated_query is not None:
                entry.generated_query = str(generated_query)
            if crawler_job_id is not _UNSET:
                entry.crawler_job_id = self._normalize_linked_id(crawler_job_id)
            if analysis_run_id is not _UNSET:
                entry.analysis_run_id = self._normalize_linked_id(analysis_run_id)
            if report_job_id is not _UNSET:
                entry.report_job_id = self._normalize_linked_id(report_job_id)
            if last_action is not None:
                entry.last_action = str(last_action)
            if progress is not None:
                entry.progress = progress.model_copy(deep=True)
            if error is not _UNSET:
                entry.error = error.model_copy(deep=True) if error else None
            if engines is not None:
                # Stringify names once and drop those that end up empty.
                entry.engines = [name for name in map(str, engines) if name]
            if partial_results is not None:
                entry.partial_results = self._normalize_partial_results(partial_results)
            if metrics is not None:
                # Deep copy, like progress/error, so the stored snapshot never
                # shares nested mutable values with the caller's dict.
                entry.metrics = copy.deepcopy(dict(metrics))

            entry.updated_at = datetime.now()
            self._entries[research_task_id] = entry
            return entry.model_copy(deep=True)

    def clear_task(self, research_task_id: str) -> bool:
        """Remove the snapshot for *research_task_id*.

        Returns ``True`` if a snapshot was stored and removed, ``False``
        otherwise.
        """
        with self._lock:
            return self._entries.pop(research_task_id, None) is not None

    @staticmethod
    def _normalize_partial_results(partial_results: dict[str, Any]) -> dict[str, Any]:
        """Convert raw per-engine payloads into runtime payloads keyed by engine name."""
        normalized: dict[str, Any] = {}
        for engine, payload in partial_results.items():
            engine_name = str(engine)
            if not isinstance(payload, dict):
                # A non-dict payload is recorded as a failed result whose
                # message is the payload's string form.
                payload = {"success": False, "message": str(payload)}
            normalized[engine_name] = EngineResult.from_raw(
                engine_name=engine_name,
                payload=payload,
            ).to_runtime_payload()
        return normalized

    @staticmethod
    def _normalize_linked_id(value: str | None | object) -> str | None:
        """Stringify and strip a linked job/run id; map ``None`` or blank to ``None``."""
        if value is None:
            return None
        normalized = str(value).strip()
        return normalized or None


# Single module-level store instance; every importer of this module sees the
# same object.
TASK_RUNTIME_STORE = TaskRuntimeStore()


# Public names re-exported by ``from ... import *``.
__all__ = ["TASK_RUNTIME_STORE", "TaskRuntimeSnapshot", "TaskRuntimeStore"]