task_runtime_store.py
4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Ephemeral task-scoped runtime snapshots for the web API."""
from __future__ import annotations

import copy
from datetime import datetime
from threading import Lock
from typing import Any

from pydantic import Field

from services.shared.models import EngineResult
from services.shared.models.common import ErrorInfo, ProgressInfo, SharedModel
# Sentinel default for ``TaskRuntimeStore.upsert`` arguments where an explicit
# ``None`` is meaningful (passing ``None`` clears the linked id / error), so
# "argument not passed" must be distinguishable from "passed None".
_UNSET = object()
class TaskRuntimeSnapshot(SharedModel):
    """Task-centered runtime projection for in-flight orchestration state.

    One snapshot exists per research task id; ``TaskRuntimeStore`` creates,
    updates and hands out (deep copies of) these objects.
    """

    # Identifier of the research task; also the key used by TaskRuntimeStore.
    research_task_id: str
    # Free-form status string; new snapshots start as "draft".
    status: str = "draft"
    # Generated query text; empty until one is recorded.
    generated_query: str = ""
    # Ids of linked downstream jobs/runs. None means "not linked"; the store
    # normalizes blank/whitespace-only ids to None.
    crawler_job_id: str | None = None
    analysis_run_id: str | None = None
    report_job_id: str | None = None
    # Free-form label for the most recent action (semantics set by callers — confirm).
    last_action: str = ""
    # Progress information; a default ProgressInfo is created per instance.
    progress: ProgressInfo = Field(default_factory=ProgressInfo)
    # Error details, or None when no error is recorded.
    error: ErrorInfo | None = None
    # Engine names associated with the task.
    engines: list[str] = Field(default_factory=list)
    # Per-engine result payloads keyed by engine name.
    partial_results: dict[str, Any] = Field(default_factory=dict)
    # Arbitrary metric values supplied by callers.
    metrics: dict[str, Any] = Field(default_factory=dict)
    # Time of the last store update (naive local time as written by the store);
    # None only if the snapshot was constructed outside the store.
    updated_at: datetime | None = None
class TaskRuntimeStore:
    """Thread-safe, in-memory store for task runtime snapshots.

    Stored snapshots never leak out: every read returns a deep copy, and every
    value passed into ``upsert`` is copied/normalized before it is stored, so
    neither callers nor other threads can mutate stored state without the lock.
    """

    def __init__(self) -> None:
        # Snapshots keyed by research task id; only accessed while holding _lock.
        self._entries: dict[str, TaskRuntimeSnapshot] = {}
        self._lock = Lock()

    def get_task(self, research_task_id: str) -> TaskRuntimeSnapshot | None:
        """Return a deep copy of the snapshot for *research_task_id*, or ``None``."""
        with self._lock:
            entry = self._entries.get(research_task_id)
            # Copy while still holding the lock so a concurrent ``upsert`` cannot
            # modify the entry part-way through the copy.
            return entry.model_copy(deep=True) if entry is not None else None

    def snapshot_all(self) -> dict[str, TaskRuntimeSnapshot]:
        """Return deep copies of every stored snapshot, keyed by research task id."""
        with self._lock:
            return {
                task_id: entry.model_copy(deep=True)
                for task_id, entry in self._entries.items()
            }

    def upsert(
        self,
        research_task_id: str,
        *,
        status: str | None = None,
        generated_query: str | None = None,
        crawler_job_id: str | None | object = _UNSET,
        analysis_run_id: str | None | object = _UNSET,
        report_job_id: str | None | object = _UNSET,
        last_action: str | None = None,
        progress: ProgressInfo | None = None,
        error: ErrorInfo | None | object = _UNSET,
        engines: list[str] | None = None,
        partial_results: dict[str, Any] | None = None,
        metrics: dict[str, Any] | None = None,
    ) -> TaskRuntimeSnapshot:
        """Create or partially update the snapshot for *research_task_id*.

        Only explicitly supplied fields change. For ``status``,
        ``generated_query``, ``last_action``, ``progress``, ``engines``,
        ``partial_results`` and ``metrics``, "not supplied" means ``None``.
        The linked ids and ``error`` default to a private sentinel instead,
        so passing ``None`` for them explicitly clears the stored value.

        Args:
            research_task_id: Key of the snapshot to create or update.
            status, generated_query, last_action: Stored as ``str(value)``.
            crawler_job_id, analysis_run_id, report_job_id: Stored stripped;
                blank values are stored as ``None``.
            progress, error: Stored as deep copies (a falsy ``error`` clears it).
            engines: Stored as strings; names that stringify to ``""`` are dropped.
            partial_results: Each payload is normalized through ``EngineResult``;
                non-dict payloads become a failed result carrying ``str(payload)``.
            metrics: Stored as a deep copy.

        Returns:
            A deep copy of the snapshot as stored after the update.
        """
        # Normalize every supplied value *before* touching stored state: if a
        # conversion raises (e.g. ``EngineResult.from_raw`` rejects a payload),
        # the stored snapshot stays exactly as it was instead of half-updated.
        # Insertion order mirrors the field order, so assignments happen in the
        # same sequence as before.
        updates: dict[str, Any] = {}
        if status is not None:
            updates["status"] = str(status)
        if generated_query is not None:
            updates["generated_query"] = str(generated_query)
        if crawler_job_id is not _UNSET:
            updates["crawler_job_id"] = self._normalize_linked_id(crawler_job_id)
        if analysis_run_id is not _UNSET:
            updates["analysis_run_id"] = self._normalize_linked_id(analysis_run_id)
        if report_job_id is not _UNSET:
            updates["report_job_id"] = self._normalize_linked_id(report_job_id)
        if last_action is not None:
            updates["last_action"] = str(last_action)
        if progress is not None:
            updates["progress"] = progress.model_copy(deep=True)
        if error is not _UNSET:
            updates["error"] = error.model_copy(deep=True) if error else None
        if engines is not None:
            updates["engines"] = self._normalize_engines(engines)
        if partial_results is not None:
            updates["partial_results"] = self._normalize_partial_results(partial_results)
        if metrics is not None:
            # Deep copy rather than ``dict(metrics)``: nested containers must not
            # remain shared with the caller, or later caller-side mutation would
            # change stored state without holding the lock.
            updates["metrics"] = copy.deepcopy(dict(metrics))

        with self._lock:
            entry = self._entries.get(research_task_id)
            if entry is None:
                entry = TaskRuntimeSnapshot(research_task_id=research_task_id)
            for field_name, value in updates.items():
                setattr(entry, field_name, value)
            # NOTE(review): naive local time; consider a timezone-aware UTC
            # timestamp if consumers compare values across hosts.
            entry.updated_at = datetime.now()
            self._entries[research_task_id] = entry
            return entry.model_copy(deep=True)

    def clear_task(self, research_task_id: str) -> bool:
        """Remove the snapshot for *research_task_id*; return True if one existed."""
        with self._lock:
            return self._entries.pop(research_task_id, None) is not None

    @staticmethod
    def _normalize_linked_id(value: str | None | object) -> str | None:
        """Return *value* as a stripped string, or ``None`` if it is None or blank."""
        if value is None:
            return None
        normalized = str(value).strip()
        return normalized or None

    @staticmethod
    def _normalize_engines(engines: list[str]) -> list[str]:
        """Stringify engine names, dropping any whose string form is empty."""
        names = (str(engine) for engine in engines)
        return [name for name in names if name]

    @staticmethod
    def _normalize_partial_results(partial_results: dict[str, Any]) -> dict[str, Any]:
        """Convert raw per-engine payloads into runtime payloads via ``EngineResult``.

        Keys are stringified engine names. A payload that is not a dict is wrapped
        as ``{"success": False, "message": str(payload)}`` before conversion.
        """
        normalized: dict[str, Any] = {}
        for engine, payload in partial_results.items():
            engine_name = str(engine)
            raw = (
                payload
                if isinstance(payload, dict)
                else {"success": False, "message": str(payload)}
            )
            normalized[engine_name] = EngineResult.from_raw(
                engine_name=engine_name,
                payload=raw,
            ).to_runtime_payload()
        return normalized
# Module-level store instance; every importer of this module shares this object.
TASK_RUNTIME_STORE = TaskRuntimeStore()
# Names exported by ``from <this module> import *``.
__all__ = ["TASK_RUNTIME_STORE", "TaskRuntimeSnapshot", "TaskRuntimeStore"]