# test_report_service.py 8.96 KB
from __future__ import annotations

from services.application.report import (
    ReportEngineUnavailableError,
    ReportOperationConflictError,
    ReportService,
    ReportTaskNotFoundError,
    ReportValidationError,
)


class _FakeReportTask:
    """In-memory stand-in for a report task that records how it is used.

    The fake keeps the constructor arguments as attributes, starts in the
    ``"pending"`` status with zero progress, counts snapshot persists and
    collects every published event so tests can inspect them afterwards.
    """

    def __init__(self, query: str, task_id: str, custom_template: str = "", research_task_id: str = "") -> None:
        # Identity / input fields, stored exactly as given.
        self.task_id = task_id
        self.query = query
        self.research_task_id = research_task_id
        self.custom_template = custom_template
        # Mutable runtime state.
        self.progress = 0
        self.status = "pending"
        # Bookkeeping that tests assert against.
        self.published_events: list[dict[str, object]] = []
        self.persist_calls = 0

    def persist_snapshot(self) -> None:
        """Count one persist request; nothing is actually stored."""
        self.persist_calls = self.persist_calls + 1

    def publish_event(self, event_type: str, payload: dict[str, object]) -> None:
        """Append the event (type plus payload) to ``published_events``."""
        self.published_events.append(dict(event_type=event_type, payload=payload))

    def update_status(self, status: str, progress: int | None = None, error_message: str = "") -> None:
        """Set the status, optionally the progress, then persist a snapshot.

        ``progress`` is left untouched when ``None`` is passed.
        ``error_message`` is accepted but not recorded by this fake.
        """
        self.status = status
        self.progress = self.progress if progress is None else progress
        self.persist_snapshot()

    def to_dict(self) -> dict[str, object]:
        """Return the task's id, query, research task id, status and progress."""
        exported = ("task_id", "query", "research_task_id", "status", "progress")
        return {name: getattr(self, name) for name in exported}


def _build_service(*, current_task: _FakeReportTask | None = None, engine_ready: bool = True, engines_status=None):
    """Create a ``ReportService`` wired entirely to recording fakes.

    Each collaborator passed to the service is a small closure that notes its
    invocation in a shared ``state`` dict, so tests can assert on what the
    service did without touching real infrastructure.

    Args:
        current_task: Task initially exposed as the current task, if any.
        engine_ready: Value returned by the ``report_engine_ready_getter`` fake.
        engines_status: Mapping returned (as a fresh ``dict`` copy) by the
            ``check_engines_ready`` fake. When falsy, a status with
            ``ready=True`` and empty file lists is used instead.

    Returns:
        A ``(service, state)`` tuple; ``state`` holds the current task plus
        the per-collaborator call records and counters.
    """
    state = dict(
        current_task=current_task,
        set_calls=[],
        register_calls=[],
        clear_log_calls=0,
        engine_checks=0,
        task_factory_calls=[],
        starter_calls=[],
        sync_calls=[],
    )

    def _read_current():
        return state["current_task"]

    def _write_current(task):
        # Record the id of the task being set, or None when it is cleared.
        state["current_task"] = task
        recorded_id = task.task_id if task else None
        state["set_calls"].append(recorded_id)

    def _lookup(task_id):
        # Only the current task is known to this fake registry.
        candidate = state["current_task"]
        if candidate and candidate.task_id == task_id:
            return candidate
        return None

    def _register(task):
        state["current_task"] = task
        state["register_calls"].append(task)

    def _clear_log():
        state["clear_log_calls"] += 1

    def _engines_status():
        state["engine_checks"] += 1
        status = engines_status or {"ready": True, "missing_files": [], "files_found": []}
        return dict(status)

    def _make_task(query: str, task_id: str, custom_template: str, research_task_id: str):
        call_record = dict(
            query=query,
            task_id=task_id,
            custom_template=custom_template,
            research_task_id=research_task_id,
        )
        state["task_factory_calls"].append(call_record)
        return _FakeReportTask(query, task_id, custom_template, research_task_id)

    def _start_generation(task, query: str, custom_template: str):
        call_record = dict(task_id=task.task_id, query=query, custom_template=custom_template)
        state["starter_calls"].append(call_record)

    def _sync_runtime_status(research_task_id: str, **kwargs):
        call_record = {"research_task_id": research_task_id}
        call_record.update(kwargs)
        state["sync_calls"].append(call_record)

    wiring = dict(
        current_task_getter=_read_current,
        current_task_setter=_write_current,
        report_task_getter=_lookup,
        register_task=_register,
        clear_report_log=_clear_log,
        report_engine_ready_getter=lambda: engine_ready,
        check_engines_ready=_engines_status,
        task_factory=_make_task,
        generation_starter=_start_generation,
        sync_report_runtime_status=_sync_runtime_status,
        task_id_factory=lambda: "report-fixed-id",
    )
    return ReportService(**wiring), state


def test_generate_report_rejects_running_task_conflict():
    """Requesting a report while the current task is running raises a conflict."""
    running = _FakeReportTask("existing", "report-running")
    running.status = "running"
    service, state = _build_service(current_task=running)

    conflict = None
    try:
        service.generate_report(query="new report")
    except ReportOperationConflictError as exc:
        conflict = exc

    assert conflict is not None, "expected ReportOperationConflictError"
    assert str(conflict) == "已有报告生成任务正在运行中"
    assert conflict.details["current_task"]["task_id"] == "report-running"

    # The log, the registry and the generation starter must be left untouched.
    assert state["clear_log_calls"] == 0
    assert state["register_calls"] == []
    assert state["starter_calls"] == []


def test_generate_report_rejects_when_engine_is_unavailable():
    """With the engine reported as not ready, generation fails before a task is registered."""
    finished = _FakeReportTask("existing", "report-completed")
    finished.status = "completed"
    service, state = _build_service(current_task=finished, engine_ready=False)

    failure = None
    try:
        service.generate_report(query="new report")
    except ReportEngineUnavailableError as exc:
        failure = exc

    assert failure is not None, "expected ReportEngineUnavailableError"
    assert str(failure) == "Report Engine未初始化"

    # The current task was cleared once and the log reset once, but no task
    # was registered or started.
    assert state["set_calls"] == [None]
    assert state["clear_log_calls"] == 1
    assert state["register_calls"] == []
    assert state["starter_calls"] == []


def test_generate_report_rejects_when_input_files_are_missing():
    """A not-ready engines status surfaces as a validation error listing the missing files."""
    not_ready_status = {"ready": False, "missing_files": ["query.md"], "files_found": []}
    service, state = _build_service(engines_status=not_ready_status)

    failure = None
    try:
        service.generate_report(query="new report")
    except ReportValidationError as exc:
        failure = exc

    assert failure is not None, "expected ReportValidationError"
    assert str(failure) == "输入文件未准备就绪"
    assert failure.details["missing_files"] == ["query.md"]

    # The log was cleared and the engines were checked exactly once each;
    # no task was registered or started.
    assert state["clear_log_calls"] == 1
    assert state["engine_checks"] == 1
    assert state["register_calls"] == []
    assert state["starter_calls"] == []


def test_generate_report_starts_background_task_and_syncs_research_task():
    """A successful request creates, registers, announces, syncs and starts a task."""
    report_id = "report-fixed-id"
    query = "Shanghai museum report"
    template = "custom section"
    research_id = "task-123"
    service, state = _build_service()

    payload = service.generate_report(
        query=query,
        custom_template=template,
        research_task_id=research_id,
    )

    # Response payload returned to the caller.
    assert payload["task_id"] == report_id
    assert payload["message"] == "报告生成已启动"
    assert payload["stream_url"] == "/api/report/stream/report-fixed-id"
    assert payload["task"]["research_task_id"] == research_id

    # Pre-flight steps each ran exactly once.
    assert state["clear_log_calls"] == 1
    assert state["engine_checks"] == 1

    # The task factory received the request arguments and the fixed id.
    expected_factory_call = {
        "query": query,
        "task_id": report_id,
        "custom_template": template,
        "research_task_id": research_id,
    }
    assert state["task_factory_calls"] == [expected_factory_call]

    # The registered task was persisted once and announced as pending.
    created_task = state["register_calls"][0]
    assert created_task.persist_calls == 1
    first_event = created_task.published_events[0]
    assert first_event["event_type"] == "status"
    assert first_event["payload"]["status"] == "pending"

    # The linked research task was told the report is queued.
    expected_sync_call = {
        "research_task_id": research_id,
        "report_status": "pending",
        "report_job_id": report_id,
        "last_action": "Report task queued",
    }
    assert state["sync_calls"] == [expected_sync_call]

    # Background generation was kicked off for the new task.
    expected_starter_call = {
        "task_id": report_id,
        "query": query,
        "custom_template": template,
    }
    assert state["starter_calls"] == [expected_starter_call]


def test_cancel_report_marks_running_task_cancelled_and_clears_current_task():
    """Cancelling the running task flips it to cancelled, keeps progress and clears the current task."""
    running = _FakeReportTask("existing", "report-running")
    running.status = "running"
    running.progress = 48
    service, state = _build_service(current_task=running)

    payload = service.cancel_report("report-running")

    assert payload["message"] == "任务已取消"
    cancelled_view = payload["task"]
    assert cancelled_view["task_id"] == "report-running"
    assert cancelled_view["status"] == "cancelled"
    assert cancelled_view["progress"] == 48

    # The current task slot was cleared through the setter exactly once.
    assert state["current_task"] is None
    assert state["set_calls"] == [None]

    # The task itself was persisted once and emitted a cancellation event first.
    assert running.persist_calls == 1
    first_event = running.published_events[0]
    assert first_event["event_type"] == "cancelled"


def test_cancel_report_rejects_missing_or_non_running_task():
    """Cancelling a task that is already completed raises a not-found error."""
    finished = _FakeReportTask("existing", "report-done")
    finished.status = "completed"
    service, state = _build_service(current_task=finished)

    failure = None
    try:
        service.cancel_report("report-done")
    except ReportTaskNotFoundError as exc:
        failure = exc

    assert failure is not None, "expected ReportTaskNotFoundError"
    assert str(failure) == "任务不存在或无法取消"
    assert failure.details["task_id"] == "report-done"

    # The current-task setter must not have been invoked.
    assert state["set_calls"] == []


def test_cancel_report_allows_pending_task():
    """A task that is still pending can be cancelled."""
    pending = _FakeReportTask("existing", "report-pending")
    pending.status = "pending"
    service, _ = _build_service(current_task=pending)

    result = service.cancel_report("report-pending")

    cancelled_view = result["task"]
    assert cancelled_view["status"] == "cancelled"


def test_clear_report_log_returns_message_payload():
    """Clearing the log calls the injected clearer once and returns a message payload."""
    service, state = _build_service()
    expected_payload = {"message": "日志已清空"}

    result = service.clear_report_log()

    assert result == expected_payload
    assert state["clear_log_calls"] == 1