# test_report_stream_service.py
from __future__ import annotations

from queue import Empty

import pytest

from services.application.report.stream_service import (
    ReportStreamService,
    ReportStreamTaskNotFoundError,
)


class _FakeReportTask:
    def __init__(
        self,
        *,
        task_id: str = "report-1",
        status: str = "running",
        history: list[dict[str, object]] | None = None,
    ) -> None:
        self.task_id = task_id
        self.status = status
        self._history = list(history or [])
        self.history_calls: list[int | None] = []

    def history_since(self, last_event_id: int | None) -> list[dict[str, object]]:
        self.history_calls.append(last_event_id)
        if last_event_id is None:
            return list(self._history)
        return [event for event in self._history if int(event["id"]) > last_event_id]


class _ScriptedSubscriber:
    def __init__(self, steps: list[object] | None = None) -> None:
        self._steps = list(steps or [])
        self.timeout_calls: list[float | None] = []

    def get(self, timeout: float | None = None):
        self.timeout_calls.append(timeout)
        if not self._steps:
            raise Empty()
        step = self._steps.pop(0)
        if isinstance(step, BaseException):
            raise step
        return step


class _TimeSequence:
    """Scripted clock: returns each given timestamp once, then repeats the final one.

    An empty script behaves as a clock frozen at ``0.0``.
    """

    def __init__(self, values: list[float]) -> None:
        self._readings = tuple(values)
        self._cursor = 0
        # Value returned once the script is exhausted.
        self._last = self._readings[-1] if self._readings else 0.0

    def __call__(self) -> float:
        if self._cursor < len(self._readings):
            self._last = self._readings[self._cursor]
            self._cursor += 1
        return self._last


def _build_service(
    *,
    task: _FakeReportTask | None = None,
    subscriber: _ScriptedSubscriber | None = None,
    time_values: list[float] | None = None,
    event_formatter=None,
) -> tuple[ReportStreamService, dict[str, list], _ScriptedSubscriber]:
    """Wire a ``ReportStreamService`` to in-memory fakes and recording callbacks.

    Args:
        task: The only task the service can resolve (matched on ``task_id``);
            ``None`` makes every lookup miss.
        subscriber: Subscriber returned by ``stream_register``; a fresh, empty
            ``_ScriptedSubscriber`` is used when omitted.
        time_values: Scripted clock readings; ``[0.0, 0.0, 0.0]`` is used when
            omitted or empty.
        event_formatter: Optional replacement for the default formatter, which
            renders an event as ``"<type>:<id>"``. Lets tests inject a failing
            formatter without patching private service attributes.

    Returns:
        ``(service, state, subscriber)``, where ``state`` records the arguments
        of every register / unregister / heartbeat callback invocation.
    """
    state: dict[str, list] = {
        "register_calls": [],
        "unregister_calls": [],
        "heartbeat_calls": [],
    }
    # Identity check rather than truthiness: a subscriber object must never be
    # replaced merely because it happens to be falsy.
    if subscriber is None:
        subscriber = _ScriptedSubscriber()

    def report_task_getter(task_id: str):
        if task is not None and task.task_id == task_id:
            return task
        return None

    def stream_register(task_id: str):
        state["register_calls"].append(task_id)
        return subscriber

    def stream_unregister(task_id: str, registered_subscriber):
        state["unregister_calls"].append((task_id, registered_subscriber))

    def default_event_formatter(event: dict[str, object]) -> str:
        return f"{event['type']}:{event['id']}"

    def heartbeat_builder(task_id: str, *, status: str = "") -> dict[str, object]:
        state["heartbeat_calls"].append((task_id, status))
        return {"id": "hb-1", "type": "heartbeat", "task_id": task_id, "payload": {"status": status}}

    service = ReportStreamService(
        report_task_getter=report_task_getter,
        stream_register=stream_register,
        stream_unregister=stream_unregister,
        event_formatter=(
            event_formatter if event_formatter is not None else default_event_formatter
        ),
        heartbeat_builder=heartbeat_builder,
        time_getter=_TimeSequence(time_values or [0.0, 0.0, 0.0]),
        heartbeat_interval=3,
        idle_timeout=5,
        terminal_statuses={"completed", "error", "cancelled"},
    )
    return service, state, subscriber


@pytest.mark.parametrize(
    ("header_value", "expected"),
    [
        pytest.param(None, None, id="missing-header"),
        pytest.param("", None, id="empty-header"),
        pytest.param("  42  ", 42, id="padded-integer"),
        pytest.param("abc", None, id="non-numeric"),
        pytest.param("hb-1", None, id="heartbeat-style-id"),
    ],
)
def test_parse_last_event_id(header_value, expected):
    """Whitespace-padded integers parse; missing, empty, or non-numeric headers yield None."""
    parsed = ReportStreamService.parse_last_event_id(header_value)
    assert parsed == expected


def test_get_required_report_task_raises_when_missing():
    """An unknown task id raises ``ReportStreamTaskNotFoundError`` carrying that id."""
    service = _build_service(task=None)[0]

    with pytest.raises(ReportStreamTaskNotFoundError) as raised:
        service.get_required_report_task("missing-task")

    error = raised.value
    assert error.task_id == "missing-task"


def test_get_report_task_returns_none_for_blank_id():
    """Empty and whitespace-only ids resolve to no task."""
    service = _build_service(task=None)[0]

    for blank_id in ("", "   "):
        assert service.get_report_task(blank_id) is None


def test_iter_encoded_events_replays_history_and_then_streams_new_events():
    """History newer than ``last_event_id`` is replayed, then live events follow until ``None``."""
    stored_events = [
        {"id": 1, "type": "status"},
        {"id": 2, "type": "progress"},
    ]
    task = _FakeReportTask(history=stored_events)
    live_feed = _ScriptedSubscriber([{"id": 3, "type": "progress"}, None])
    service, state, _ = _build_service(
        task=task,
        subscriber=live_feed,
        time_values=[0.0, 1.0, 2.0, 3.0],
    )

    encoded = [*service.iter_encoded_events(task_id="report-1", last_event_id=1)]

    assert encoded == ["progress:2", "progress:3"]
    assert task.history_calls == [1]
    assert state["register_calls"] == ["report-1"]
    assert state["unregister_calls"] == [("report-1", live_feed)]


def test_iter_encoded_events_emits_heartbeat_while_task_is_active():
    """An empty poll on a running task yields one heartbeat built with the task's status."""
    task = _FakeReportTask(status="running")
    # First poll times out (Empty); the second returns None, ending the stream.
    subscriber = _ScriptedSubscriber([Empty(), None])
    service, state, _subscriber = _build_service(
        task=task,
        subscriber=subscriber,
        time_values=[0.0, 1.0, 2.0, 3.0],
    )

    events = list(service.iter_encoded_events(task_id="report-1"))

    assert events == ["heartbeat:hb-1"]
    assert state["heartbeat_calls"] == [("report-1", "running")]
    # Each poll waits heartbeat_interval (configured as 3 in _build_service).
    assert subscriber.timeout_calls == [3, 3]
    # Like every other exit path exercised in this module, finishing the
    # stream must release the subscriber.
    assert state["unregister_calls"] == [("report-1", subscriber)]


def test_iter_encoded_events_stops_after_terminal_idle_timeout():
    """A completed task with no traffic ends the stream silently once the clock passes the idle window."""
    completed_task = _FakeReportTask(status="completed")
    quiet_subscriber = _ScriptedSubscriber([Empty()])
    service, state, _ = _build_service(
        task=completed_task,
        subscriber=quiet_subscriber,
        time_values=[0.0, 2.0, 2.0, 7.0],
    )

    emitted = [*service.iter_encoded_events(task_id="report-1")]

    assert emitted == []
    assert state["heartbeat_calls"] == []
    assert state["unregister_calls"] == [("report-1", quiet_subscriber)]


def test_iter_encoded_events_stops_when_client_disconnects():
    """With the client already gone, the subscriber is never polled but is still unregistered."""
    running_task = _FakeReportTask(status="running")
    idle_subscriber = _ScriptedSubscriber()
    service, state, _ = _build_service(
        task=running_task,
        subscriber=idle_subscriber,
        time_values=[0.0, 0.0],
    )

    def always_disconnected() -> bool:
        return True

    emitted = [
        *service.iter_encoded_events(
            task_id="report-1",
            client_disconnected=always_disconnected,
        )
    ]

    assert emitted == []
    assert idle_subscriber.timeout_calls == []
    assert state["unregister_calls"] == [("report-1", idle_subscriber)]


def test_iter_encoded_events_unregisters_when_formatter_raises():
    """The subscriber is unregistered even when formatting a replayed event fails."""
    task = _FakeReportTask(history=[{"id": 1, "type": "status"}])
    subscriber = _ScriptedSubscriber()
    service, state, _subscriber = _build_service(
        task=task,
        subscriber=subscriber,
        time_values=[0.0, 1.0],
    )

    def _failing_formatter(_event: dict[str, object]) -> str:
        raise RuntimeError("format failed")

    # NOTE(review): patches a private attribute; this assumes ReportStreamService
    # stores the injected formatter as ``_event_formatter`` (as the original
    # test relied on) — confirm against the service implementation.
    service._event_formatter = _failing_formatter

    with pytest.raises(RuntimeError, match="format failed"):
        list(service.iter_encoded_events(task_id="report-1"))

    assert state["unregister_calls"] == [("report-1", subscriber)]