Showing 3 changed files with 402 additions and 39 deletions
| @@ -8,7 +8,7 @@ import os | @@ -8,7 +8,7 @@ import os | ||
| 8 | from pathlib import Path | 8 | from pathlib import Path |
| 9 | from uuid import uuid4 | 9 | from uuid import uuid4 |
| 10 | from datetime import datetime | 10 | from datetime import datetime |
| 11 | -from typing import Optional, Dict, Any, List | 11 | +from typing import Optional, Dict, Any, List, Callable |
| 12 | 12 | ||
| 13 | from loguru import logger | 13 | from loguru import logger |
| 14 | 14 | ||
| @@ -23,6 +23,7 @@ from .llms import LLMClient | @@ -23,6 +23,7 @@ from .llms import LLMClient | ||
| 23 | from .nodes import ( | 23 | from .nodes import ( |
| 24 | TemplateSelectionNode, | 24 | TemplateSelectionNode, |
| 25 | ChapterGenerationNode, | 25 | ChapterGenerationNode, |
| 26 | + ChapterJsonParseError, | ||
| 26 | DocumentLayoutNode, | 27 | DocumentLayoutNode, |
| 27 | WordBudgetNode, | 28 | WordBudgetNode, |
| 28 | ) | 29 | ) |
| @@ -205,10 +206,11 @@ class ReportAgent: | @@ -205,10 +206,11 @@ class ReportAgent: | ||
| 205 | ) | 206 | ) |
| 206 | 207 | ||
| 207 | def generate_report(self, query: str, reports: List[Any], forum_logs: str = "", | 208 | def generate_report(self, query: str, reports: List[Any], forum_logs: str = "", |
| 208 | - custom_template: str = "", save_report: bool = True) -> str: | 209 | + custom_template: str = "", save_report: bool = True, |
| 210 | + stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str: | ||
| 209 | """ | 211 | """ |
| 210 | 生成综合报告(章节JSON → IR → HTML) | 212 | 生成综合报告(章节JSON → IR → HTML) |
| 211 | - | 213 | + |
| 212 | Returns: | 214 | Returns: |
| 213 | dict: HTML内容以及保存的文件路径信息 | 215 | dict: HTML内容以及保存的文件路径信息 |
| 214 | """ | 216 | """ |
| @@ -220,15 +222,32 @@ class ReportAgent: | @@ -220,15 +222,32 @@ class ReportAgent: | ||
| 220 | self.state.mark_processing() | 222 | self.state.mark_processing() |
| 221 | 223 | ||
| 222 | normalized_reports = self._normalize_reports(reports) | 224 | normalized_reports = self._normalize_reports(reports) |
| 225 | + | ||
| 226 | + def emit(event_type: str, payload: Dict[str, Any]): | ||
| 227 | + """面向Report Engine流通道的事件分发器,保证错误不外泄。""" | ||
| 228 | + if not stream_handler: | ||
| 229 | + return | ||
| 230 | + try: | ||
| 231 | + stream_handler(event_type, payload) | ||
| 232 | + except Exception as callback_error: # pragma: no cover - 仅记录 | ||
| 233 | + logger.warning(f"流式事件回调失败: {callback_error}") | ||
| 234 | + | ||
| 223 | logger.info(f"开始生成报告 {report_id}: {query}") | 235 | logger.info(f"开始生成报告 {report_id}: {query}") |
| 224 | logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}") | 236 | logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}") |
| 237 | + emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query}) | ||
| 225 | 238 | ||
| 226 | try: | 239 | try: |
| 227 | template_result = self._select_template(query, reports, forum_logs, custom_template) | 240 | template_result = self._select_template(query, reports, forum_logs, custom_template) |
| 228 | self.state.metadata.template_used = template_result.get('template_name', '') | 241 | self.state.metadata.template_used = template_result.get('template_name', '') |
| 242 | + emit('stage', { | ||
| 243 | + 'stage': 'template_selected', | ||
| 244 | + 'template': template_result.get('template_name'), | ||
| 245 | + 'reason': template_result.get('selection_reason') | ||
| 246 | + }) | ||
| 229 | sections = self._slice_template(template_result.get('template_content', '')) | 247 | sections = self._slice_template(template_result.get('template_content', '')) |
| 230 | if not sections: | 248 | if not sections: |
| 231 | raise ValueError("模板无法解析出章节,请检查模板内容。") | 249 | raise ValueError("模板无法解析出章节,请检查模板内容。") |
| 250 | + emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)}) | ||
| 232 | 251 | ||
| 233 | template_text = template_result.get('template_content', '') | 252 | template_text = template_result.get('template_content', '') |
| 234 | template_overview = self._build_template_overview(template_text, sections) | 253 | template_overview = self._build_template_overview(template_text, sections) |
| @@ -241,6 +260,11 @@ class ReportAgent: | @@ -241,6 +260,11 @@ class ReportAgent: | ||
| 241 | query, | 260 | query, |
| 242 | template_overview, | 261 | template_overview, |
| 243 | ) | 262 | ) |
| 263 | + emit('stage', { | ||
| 264 | + 'stage': 'layout_designed', | ||
| 265 | + 'title': layout_design.get('title'), | ||
| 266 | + 'toc': layout_design.get('tocTitle') | ||
| 267 | + }) | ||
| 244 | # 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点 | 268 | # 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点 |
| 245 | word_plan = self.word_budget_node.run( | 269 | word_plan = self.word_budget_node.run( |
| 246 | sections, | 270 | sections, |
| @@ -250,6 +274,10 @@ class ReportAgent: | @@ -250,6 +274,10 @@ class ReportAgent: | ||
| 250 | query, | 274 | query, |
| 251 | template_overview, | 275 | template_overview, |
| 252 | ) | 276 | ) |
| 277 | + emit('stage', { | ||
| 278 | + 'stage': 'word_plan_ready', | ||
| 279 | + 'chapter_targets': len(word_plan.get('chapters', [])) | ||
| 280 | + }) | ||
| 253 | # 记录每个章节的目标字数/强调点,后续传给章节LLM | 281 | # 记录每个章节的目标字数/强调点,后续传给章节LLM |
| 254 | chapter_targets = { | 282 | chapter_targets = { |
| 255 | entry.get("chapterId"): entry | 283 | entry.get("chapterId"): entry |
| @@ -296,23 +324,97 @@ class ReportAgent: | @@ -296,23 +324,97 @@ class ReportAgent: | ||
| 296 | # 初始化章节输出目录并写入manifest,方便流式存盘 | 324 | # 初始化章节输出目录并写入manifest,方便流式存盘 |
| 297 | run_dir = self.chapter_storage.start_session(report_id, manifest_meta) | 325 | run_dir = self.chapter_storage.start_session(report_id, manifest_meta) |
| 298 | self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview) | 326 | self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview) |
| 327 | + emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)}) | ||
| 299 | 328 | ||
| 300 | chapters = [] | 329 | chapters = [] |
| 330 | + chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS) | ||
| 301 | for section in sections: | 331 | for section in sections: |
| 302 | logger.info(f"生成章节: {section.title}") | 332 | logger.info(f"生成章节: {section.title}") |
| 303 | - chapter = self.chapter_generation_node.run( | ||
| 304 | - section, | ||
| 305 | - generation_context, | ||
| 306 | - run_dir | ||
| 307 | - ) | ||
| 308 | - chapters.append(chapter) | 333 | + emit('chapter_status', { |
| 334 | + 'chapterId': section.chapter_id, | ||
| 335 | + 'title': section.title, | ||
| 336 | + 'status': 'running' | ||
| 337 | + }) | ||
| 338 | + # 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染 | ||
| 339 | + def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section): | ||
| 340 | + emit('chapter_chunk', { | ||
| 341 | + 'chapterId': meta.get('chapterId') or section_ref.chapter_id, | ||
| 342 | + 'title': meta.get('title') or section_ref.title, | ||
| 343 | + 'delta': delta | ||
| 344 | + }) | ||
| 345 | + | ||
| 346 | + chapter_payload: Dict[str, Any] | None = None | ||
| 347 | + attempt = 1 | ||
| 348 | + while attempt <= chapter_max_attempts: | ||
| 349 | + try: | ||
| 350 | + chapter_payload = self.chapter_generation_node.run( | ||
| 351 | + section, | ||
| 352 | + generation_context, | ||
| 353 | + run_dir, | ||
| 354 | + stream_callback=chunk_callback | ||
| 355 | + ) | ||
| 356 | + break | ||
| 357 | + except ChapterJsonParseError as parse_error: | ||
| 358 | + logger.warning( | ||
| 359 | + "章节 %s JSON解析失败(第 %s/%s 次尝试): %s", | ||
| 360 | + section.title, | ||
| 361 | + attempt, | ||
| 362 | + chapter_max_attempts, | ||
| 363 | + parse_error, | ||
| 364 | + ) | ||
| 365 | + emit('chapter_status', { | ||
| 366 | + 'chapterId': section.chapter_id, | ||
| 367 | + 'title': section.title, | ||
| 368 | + 'status': 'retrying' if attempt < chapter_max_attempts else 'error', | ||
| 369 | + 'attempt': attempt, | ||
| 370 | + 'error': str(parse_error), | ||
| 371 | + }) | ||
| 372 | + if attempt >= chapter_max_attempts: | ||
| 373 | + raise | ||
| 374 | + attempt += 1 | ||
| 375 | + continue | ||
| 376 | + except Exception as chapter_error: | ||
| 377 | + if not self._should_retry_inappropriate_content_error(chapter_error): | ||
| 378 | + raise | ||
| 379 | + logger.warning( | ||
| 380 | + "章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s", | ||
| 381 | + section.title, | ||
| 382 | + attempt, | ||
| 383 | + chapter_max_attempts, | ||
| 384 | + chapter_error, | ||
| 385 | + ) | ||
| 386 | + emit('chapter_status', { | ||
| 387 | + 'chapterId': section.chapter_id, | ||
| 388 | + 'title': section.title, | ||
| 389 | + 'status': 'retrying' if attempt < chapter_max_attempts else 'error', | ||
| 390 | + 'attempt': attempt, | ||
| 391 | + 'error': str(chapter_error), | ||
| 392 | + 'reason': 'content_filter' | ||
| 393 | + }) | ||
| 394 | + if attempt >= chapter_max_attempts: | ||
| 395 | + raise | ||
| 396 | + attempt += 1 | ||
| 397 | + continue | ||
| 398 | + if chapter_payload is None: | ||
| 399 | + raise ChapterJsonParseError( | ||
| 400 | + f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析" | ||
| 401 | + ) | ||
| 402 | + chapters.append(chapter_payload) | ||
| 403 | + emit('chapter_status', { | ||
| 404 | + 'chapterId': section.chapter_id, | ||
| 405 | + 'title': section.title, | ||
| 406 | + 'status': 'completed', | ||
| 407 | + 'attempt': attempt, | ||
| 408 | + }) | ||
| 309 | 409 | ||
| 310 | document_ir = self.document_composer.build_document( | 410 | document_ir = self.document_composer.build_document( |
| 311 | report_id, | 411 | report_id, |
| 312 | manifest_meta, | 412 | manifest_meta, |
| 313 | chapters | 413 | chapters |
| 314 | ) | 414 | ) |
| 415 | + emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)}) | ||
| 315 | html_report = self.renderer.render(document_ir) | 416 | html_report = self.renderer.render(document_ir) |
| 417 | + emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)}) | ||
| 316 | 418 | ||
| 317 | self.state.html_content = html_report | 419 | self.state.html_content = html_report |
| 318 | self.state.mark_completed() | 420 | self.state.mark_completed() |
| @@ -320,10 +422,12 @@ class ReportAgent: | @@ -320,10 +422,12 @@ class ReportAgent: | ||
| 320 | saved_files = {} | 422 | saved_files = {} |
| 321 | if save_report: | 423 | if save_report: |
| 322 | saved_files = self._save_report(html_report, document_ir, report_id) | 424 | saved_files = self._save_report(html_report, document_ir, report_id) |
| 425 | + emit('stage', {'stage': 'report_saved', 'files': saved_files}) | ||
| 323 | 426 | ||
| 324 | generation_time = (datetime.now() - start_time).total_seconds() | 427 | generation_time = (datetime.now() - start_time).total_seconds() |
| 325 | self.state.metadata.generation_time = generation_time | 428 | self.state.metadata.generation_time = generation_time |
| 326 | logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒") | 429 | logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒") |
| 430 | + emit('metrics', {'generation_seconds': generation_time}) | ||
| 327 | return { | 431 | return { |
| 328 | 'html_content': html_report, | 432 | 'html_content': html_report, |
| 329 | 'report_id': report_id, | 433 | 'report_id': report_id, |
| @@ -333,6 +437,7 @@ class ReportAgent: | @@ -333,6 +437,7 @@ class ReportAgent: | ||
| 333 | except Exception as e: | 437 | except Exception as e: |
| 334 | self.state.mark_failed(str(e)) | 438 | self.state.mark_failed(str(e)) |
| 335 | logger.exception(f"报告生成过程中发生错误: {str(e)}") | 439 | logger.exception(f"报告生成过程中发生错误: {str(e)}") |
| 440 | + emit('error', {'stage': 'agent_failed', 'message': str(e)}) | ||
| 336 | raise | 441 | raise |
| 337 | 442 | ||
| 338 | def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str): | 443 | def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str): |
| @@ -444,6 +549,22 @@ class ReportAgent: | @@ -444,6 +549,22 @@ class ReportAgent: | ||
| 444 | normalized[key] = self._stringify(value) | 549 | normalized[key] = self._stringify(value) |
| 445 | return normalized | 550 | return normalized |
| 446 | 551 | ||
| 552 | + def _should_retry_inappropriate_content_error(self, error: Exception) -> bool: | ||
| 553 | + """ | ||
| 554 | + 判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。 | ||
| 555 | + """ | ||
| 556 | + message = str(error) if error else "" | ||
| 557 | + if not message: | ||
| 558 | + return False | ||
| 559 | + normalized = message.lower() | ||
| 560 | + keywords = [ | ||
| 561 | + "inappropriate content", | ||
| 562 | + "content violation", | ||
| 563 | + "content moderation", | ||
| 564 | + "model-studio/error-code", | ||
| 565 | + ] | ||
| 566 | + return any(keyword in normalized for keyword in keywords) | ||
| 567 | + | ||
| 447 | def _stringify(self, value: Any) -> str: | 568 | def _stringify(self, value: Any) -> str: |
| 448 | """安全地将对象转成字符串""" | 569 | """安全地将对象转成字符串""" |
| 449 | if value is None: | 570 | if value is None: |
| @@ -7,9 +7,11 @@ import os | @@ -7,9 +7,11 @@ import os | ||
| 7 | import json | 7 | import json |
| 8 | import threading | 8 | import threading |
| 9 | import time | 9 | import time |
| 10 | +from collections import deque, defaultdict | ||
| 10 | from datetime import datetime | 11 | from datetime import datetime |
| 11 | -from flask import Blueprint, request, jsonify, Response, send_file | ||
| 12 | -from typing import Dict, Any | 12 | +from queue import Queue, Empty |
| 13 | +from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context | ||
| 14 | +from typing import Dict, Any, List, Optional | ||
| 13 | from loguru import logger | 15 | from loguru import logger |
| 14 | from .agent import ReportAgent, create_agent | 16 | from .agent import ReportAgent, create_agent |
| 15 | from .utils.config import settings | 17 | from .utils.config import settings |
| @@ -23,6 +25,69 @@ report_agent = None | @@ -23,6 +25,69 @@ report_agent = None | ||
| 23 | current_task = None | 25 | current_task = None |
| 24 | task_lock = threading.Lock() | 26 | task_lock = threading.Lock() |
| 25 | 27 | ||
| 28 | +# ====== 流式推送与任务历史管理 ====== | ||
| 29 | +# 通过有界deque缓存最近的事件,方便SSE断线后快速补发 | ||
| 30 | +MAX_TASK_HISTORY = 5 | ||
| 31 | +STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒 | ||
| 32 | +stream_lock = threading.Lock() | ||
| 33 | +stream_subscribers = defaultdict(list) | ||
| 34 | +tasks_registry: Dict[str, 'ReportTask'] = {} | ||
| 35 | + | ||
| 36 | + | ||
| 37 | +def _register_stream(task_id: str) -> Queue: | ||
| 38 | + """为指定任务注册一个事件队列,供SSE监听器消费。""" | ||
| 39 | + queue = Queue() | ||
| 40 | + with stream_lock: | ||
| 41 | + stream_subscribers[task_id].append(queue) | ||
| 42 | + return queue | ||
| 43 | + | ||
| 44 | + | ||
| 45 | +def _unregister_stream(task_id: str, queue: Queue): | ||
| 46 | + """安全移除事件队列,避免内存泄漏。""" | ||
| 47 | + with stream_lock: | ||
| 48 | + listeners = stream_subscribers.get(task_id, []) | ||
| 49 | + if queue in listeners: | ||
| 50 | + listeners.remove(queue) | ||
| 51 | + if not listeners and task_id in stream_subscribers: | ||
| 52 | + stream_subscribers.pop(task_id, None) | ||
| 53 | + | ||
| 54 | + | ||
| 55 | +def _broadcast_event(task_id: str, event: Dict[str, Any]): | ||
| 56 | + """将事件推送给所有监听者,失败时做好异常捕获。""" | ||
| 57 | + with stream_lock: | ||
| 58 | + listeners = list(stream_subscribers.get(task_id, [])) | ||
| 59 | + for queue in listeners: | ||
| 60 | + try: | ||
| 61 | + queue.put(event, timeout=0.1) | ||
| 62 | + except Exception: | ||
| 63 | + logger.exception("推送流式事件失败,跳过当前监听队列") | ||
| 64 | + | ||
| 65 | + | ||
| 66 | +def _prune_task_history_locked(): | ||
| 67 | + """在task_lock持有期间调用,清理过多的历史任务以控制内存。""" | ||
| 68 | + if len(tasks_registry) <= MAX_TASK_HISTORY: | ||
| 69 | + return | ||
| 70 | + # 按创建时间排序,移除最旧的任务 | ||
| 71 | + sorted_tasks = sorted(tasks_registry.values(), key=lambda t: t.created_at) | ||
| 72 | + for task in sorted_tasks[:-MAX_TASK_HISTORY]: | ||
| 73 | + tasks_registry.pop(task.task_id, None) | ||
| 74 | + | ||
| 75 | + | ||
| 76 | +def _get_task(task_id: str) -> Optional['ReportTask']: | ||
| 77 | + """统一的任务查找方法,优先返回当前任务。""" | ||
| 78 | + with task_lock: | ||
| 79 | + if current_task and current_task.task_id == task_id: | ||
| 80 | + return current_task | ||
| 81 | + return tasks_registry.get(task_id) | ||
| 82 | + | ||
| 83 | + | ||
| 84 | +def _format_sse(event: Dict[str, Any]) -> str: | ||
| 85 | + """按SSE协议格式化消息。""" | ||
| 86 | + payload = json.dumps(event, ensure_ascii=False) | ||
| 87 | + event_id = event.get('id', 0) | ||
| 88 | + event_type = event.get('type', 'message') | ||
| 89 | + return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n" | ||
| 90 | + | ||
| 26 | 91 | ||
| 27 | def initialize_report_engine(): | 92 | def initialize_report_engine(): |
| 28 | """初始化Report Engine""" | 93 | """初始化Report Engine""" |
| @@ -63,6 +128,11 @@ class ReportTask: | @@ -63,6 +128,11 @@ class ReportTask: | ||
| 63 | self.report_file_name = "" | 128 | self.report_file_name = "" |
| 64 | self.state_file_path = "" | 129 | self.state_file_path = "" |
| 65 | self.state_file_relative_path = "" | 130 | self.state_file_relative_path = "" |
| 131 | + # ====== 流式事件缓存与并发保护 ====== | ||
| 132 | + # 使用deque保存最近的事件,结合锁保证多线程下的安全访问 | ||
| 133 | + self.event_history: deque = deque(maxlen=1000) | ||
| 134 | + self._event_lock = threading.Lock() | ||
| 135 | + self.last_event_id = 0 | ||
| 66 | 136 | ||
| 67 | def update_status(self, status: str, progress: int = None, error_message: str = ""): | 137 | def update_status(self, status: str, progress: int = None, error_message: str = ""): |
| 68 | """更新任务状态""" | 138 | """更新任务状态""" |
| @@ -72,6 +142,17 @@ class ReportTask: | @@ -72,6 +142,17 @@ class ReportTask: | ||
| 72 | if error_message: | 142 | if error_message: |
| 73 | self.error_message = error_message | 143 | self.error_message = error_message |
| 74 | self.updated_at = datetime.now() | 144 | self.updated_at = datetime.now() |
| 145 | + # 推送状态变更事件,方便前端实时刷新 | ||
| 146 | + self.publish_event( | ||
| 147 | + 'status', | ||
| 148 | + { | ||
| 149 | + 'status': self.status, | ||
| 150 | + 'progress': self.progress, | ||
| 151 | + 'error_message': self.error_message, | ||
| 152 | + 'hint': error_message or '', | ||
| 153 | + 'task': self.to_dict(), | ||
| 154 | + } | ||
| 155 | + ) | ||
| 75 | 156 | ||
| 76 | def to_dict(self) -> Dict[str, Any]: | 157 | def to_dict(self) -> Dict[str, Any]: |
| 77 | """转换为字典格式""" | 158 | """转换为字典格式""" |
| @@ -91,6 +172,29 @@ class ReportTask: | @@ -91,6 +172,29 @@ class ReportTask: | ||
| 91 | 'state_file_path': self.state_file_relative_path or self.state_file_path | 172 | 'state_file_path': self.state_file_relative_path or self.state_file_path |
| 92 | } | 173 | } |
| 93 | 174 | ||
| 175 | + def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None: | ||
| 176 | + """将任意事件放入缓存并广播,所有新增逻辑均配套中文说明。""" | ||
| 177 | + timestamp = datetime.utcnow().isoformat() + 'Z' | ||
| 178 | + event: Dict[str, Any] = { | ||
| 179 | + 'id': 0, | ||
| 180 | + 'type': event_type, | ||
| 181 | + 'task_id': self.task_id, | ||
| 182 | + 'timestamp': timestamp, | ||
| 183 | + 'payload': payload, | ||
| 184 | + } | ||
| 185 | + with self._event_lock: | ||
| 186 | + self.last_event_id += 1 | ||
| 187 | + event['id'] = self.last_event_id | ||
| 188 | + self.event_history.append(event) | ||
| 189 | + _broadcast_event(self.task_id, event) | ||
| 190 | + | ||
| 191 | + def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]: | ||
| 192 | + """根据Last-Event-ID补发历史事件,确保断线重连无遗漏。""" | ||
| 193 | + with self._event_lock: | ||
| 194 | + if last_event_id is None: | ||
| 195 | + return list(self.event_history) | ||
| 196 | + return [evt for evt in self.event_history if evt['id'] > last_event_id] | ||
| 197 | + | ||
| 94 | 198 | ||
| 95 | def check_engines_ready() -> Dict[str, Any]: | 199 | def check_engines_ready() -> Dict[str, Any]: |
| 96 | """检查三个子引擎是否都有新文件""" | 200 | """检查三个子引擎是否都有新文件""" |
| @@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | @@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | ||
| 121 | global current_task | 225 | global current_task |
| 122 | 226 | ||
| 123 | try: | 227 | try: |
| 228 | + # 在局部闭包内封装推送逻辑,便于传递给ReportAgent | ||
| 229 | + def stream_handler(event_type: str, payload: Dict[str, Any]): | ||
| 230 | + """所有阶段事件都通过同一个接口分发,保证日志一致。""" | ||
| 231 | + task.publish_event(event_type, payload) | ||
| 232 | + | ||
| 124 | task.update_status("running", 10) | 233 | task.update_status("running", 10) |
| 234 | + task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'}) | ||
| 125 | 235 | ||
| 126 | # 检查输入文件 | 236 | # 检查输入文件 |
| 127 | check_result = check_engines_ready() | 237 | check_result = check_engines_ready() |
| @@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | @@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | ||
| 129 | task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}") | 239 | task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}") |
| 130 | return | 240 | return |
| 131 | 241 | ||
| 242 | + task.publish_event('stage', { | ||
| 243 | + 'message': '输入文件检查通过,准备载入内容', | ||
| 244 | + 'stage': 'io_ready', | ||
| 245 | + 'files': check_result.get('latest_files', {}) | ||
| 246 | + }) | ||
| 247 | + | ||
| 132 | task.update_status("running", 30) | 248 | task.update_status("running", 30) |
| 133 | 249 | ||
| 134 | # 加载输入文件 | 250 | # 加载输入文件 |
| 135 | content = report_agent.load_input_files(check_result['latest_files']) | 251 | content = report_agent.load_input_files(check_result['latest_files']) |
| 252 | + task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'}) | ||
| 136 | 253 | ||
| 137 | task.update_status("running", 50) | 254 | task.update_status("running", 50) |
| 138 | 255 | ||
| 139 | - # 生成报告 | ||
| 140 | - generation_result = report_agent.generate_report( | ||
| 141 | - query=query, | ||
| 142 | - reports=content['reports'], | ||
| 143 | - forum_logs=content['forum_logs'], | ||
| 144 | - custom_template=custom_template, | ||
| 145 | - save_report=True | ||
| 146 | - ) | 256 | + # 生成报告(附带兜底重试,缓解瞬时网络抖动) |
| 257 | + for attempt in range(1, 3): | ||
| 258 | + try: | ||
| 259 | + task.publish_event('stage', { | ||
| 260 | + 'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)', | ||
| 261 | + 'stage': 'agent_running', | ||
| 262 | + 'attempt': attempt | ||
| 263 | + }) | ||
| 264 | + generation_result = report_agent.generate_report( | ||
| 265 | + query=query, | ||
| 266 | + reports=content['reports'], | ||
| 267 | + forum_logs=content['forum_logs'], | ||
| 268 | + custom_template=custom_template, | ||
| 269 | + save_report=True, | ||
| 270 | + stream_handler=stream_handler | ||
| 271 | + ) | ||
| 272 | + break | ||
| 273 | + except Exception as err: | ||
| 274 | + # 将错误即时推送至前端,方便观察重试策略 | ||
| 275 | + task.publish_event('warning', { | ||
| 276 | + 'message': f'ReportAgent执行失败: {str(err)}', | ||
| 277 | + 'stage': 'agent_running', | ||
| 278 | + 'attempt': attempt | ||
| 279 | + }) | ||
| 280 | + if attempt == 2: | ||
| 281 | + raise | ||
| 282 | + # 简单的指数退避,防止频繁触发限流(单位秒) | ||
| 283 | + backoff = min(5 * attempt, 15) | ||
| 284 | + task.publish_event('stage', { | ||
| 285 | + 'message': f'{backoff} 秒后重试生成任务', | ||
| 286 | + 'stage': 'retry_wait', | ||
| 287 | + 'wait_seconds': backoff | ||
| 288 | + }) | ||
| 289 | + time.sleep(backoff) | ||
| 147 | 290 | ||
| 148 | if isinstance(generation_result, dict): | 291 | if isinstance(generation_result, dict): |
| 149 | html_report = generation_result.get('html_content', '') | 292 | html_report = generation_result.get('html_content', '') |
| @@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | @@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | ||
| 151 | html_report = generation_result | 294 | html_report = generation_result |
| 152 | 295 | ||
| 153 | task.update_status("running", 90) | 296 | task.update_status("running", 90) |
| 297 | + task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'}) | ||
| 154 | 298 | ||
| 155 | # 保存结果 | 299 | # 保存结果 |
| 156 | task.html_content = html_report | 300 | task.html_content = html_report |
| @@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | @@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = " | ||
| 160 | task.report_file_name = generation_result.get('report_filename', '') | 304 | task.report_file_name = generation_result.get('report_filename', '') |
| 161 | task.state_file_path = generation_result.get('state_filepath', '') | 305 | task.state_file_path = generation_result.get('state_filepath', '') |
| 162 | task.state_file_relative_path = generation_result.get('state_relative_path', '') | 306 | task.state_file_relative_path = generation_result.get('state_relative_path', '') |
| 307 | + task.publish_event('html_ready', { | ||
| 308 | + 'message': 'HTML渲染完成,可刷新预览', | ||
| 309 | + 'report_file': task.report_file_relative_path or task.report_file_path, | ||
| 310 | + 'state_file': task.state_file_relative_path or task.state_file_path, | ||
| 311 | + 'task': task.to_dict(), | ||
| 312 | + }) | ||
| 163 | task.update_status("completed", 100) | 313 | task.update_status("completed", 100) |
| 314 | + task.publish_event('completed', { | ||
| 315 | + 'message': '任务完成', | ||
| 316 | + 'duration_seconds': (task.updated_at - task.created_at).total_seconds(), | ||
| 317 | + 'report_file': task.report_file_relative_path or task.report_file_path, | ||
| 318 | + 'task': task.to_dict(), | ||
| 319 | + }) | ||
| 164 | 320 | ||
| 165 | except Exception as e: | 321 | except Exception as e: |
| 166 | logger.exception(f"报告生成过程中发生错误: {str(e)}") | 322 | logger.exception(f"报告生成过程中发生错误: {str(e)}") |
| 167 | task.update_status("error", 0, str(e)) | 323 | task.update_status("error", 0, str(e)) |
| 324 | + task.publish_event('error', { | ||
| 325 | + 'message': str(e), | ||
| 326 | + 'stage': 'failed', | ||
| 327 | + 'task': task.to_dict(), | ||
| 328 | + }) | ||
| 168 | # 只在出错时清理任务 | 329 | # 只在出错时清理任务 |
| 169 | with task_lock: | 330 | with task_lock: |
| 170 | if current_task and current_task.task_id == task.task_id: | 331 | if current_task and current_task.task_id == task.task_id: |
| @@ -242,6 +403,19 @@ def generate_report(): | @@ -242,6 +403,19 @@ def generate_report(): | ||
| 242 | 403 | ||
| 243 | with task_lock: | 404 | with task_lock: |
| 244 | current_task = task | 405 | current_task = task |
| 406 | + tasks_registry[task_id] = task | ||
| 407 | + _prune_task_history_locked() | ||
| 408 | + | ||
| 409 | + # 通过主动推送pending事件告知前端任务已经排队 | ||
| 410 | + task.publish_event( | ||
| 411 | + 'status', | ||
| 412 | + { | ||
| 413 | + 'status': task.status, | ||
| 414 | + 'progress': task.progress, | ||
| 415 | + 'message': '任务已排队,等待资源空闲', | ||
| 416 | + 'task': task.to_dict(), | ||
| 417 | + } | ||
| 418 | + ) | ||
| 245 | 419 | ||
| 246 | # 在后台线程中运行报告生成 | 420 | # 在后台线程中运行报告生成 |
| 247 | thread = threading.Thread( | 421 | thread = threading.Thread( |
| @@ -255,7 +429,8 @@ def generate_report(): | @@ -255,7 +429,8 @@ def generate_report(): | ||
| 255 | 'success': True, | 429 | 'success': True, |
| 256 | 'task_id': task_id, | 430 | 'task_id': task_id, |
| 257 | 'message': '报告生成已启动', | 431 | 'message': '报告生成已启动', |
| 258 | - 'task': task.to_dict() | 432 | + 'task': task.to_dict(), |
| 433 | + 'stream_url': f"/api/report/stream/{task_id}" | ||
| 259 | }) | 434 | }) |
| 260 | 435 | ||
| 261 | except Exception as e: | 436 | except Exception as e: |
| @@ -270,9 +445,9 @@ def generate_report(): | @@ -270,9 +445,9 @@ def generate_report(): | ||
| 270 | def get_progress(task_id: str): | 445 | def get_progress(task_id: str): |
| 271 | """获取报告生成进度""" | 446 | """获取报告生成进度""" |
| 272 | try: | 447 | try: |
| 273 | - if not current_task or current_task.task_id != task_id: | ||
| 274 | - # 如果任务不存在,可能是已经完成并被清理了 | ||
| 275 | - # 返回一个默认的完成状态而不是404 | 448 | + task = _get_task(task_id) |
| 449 | + if not task: | ||
| 450 | + # 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底 | ||
| 276 | return jsonify({ | 451 | return jsonify({ |
| 277 | 'success': True, | 452 | 'success': True, |
| 278 | 'task': { | 453 | 'task': { |
| @@ -291,7 +466,7 @@ def get_progress(task_id: str): | @@ -291,7 +466,7 @@ def get_progress(task_id: str): | ||
| 291 | 466 | ||
| 292 | return jsonify({ | 467 | return jsonify({ |
| 293 | 'success': True, | 468 | 'success': True, |
| 294 | - 'task': current_task.to_dict() | 469 | + 'task': task.to_dict() |
| 295 | }) | 470 | }) |
| 296 | 471 | ||
| 297 | except Exception as e: | 472 | except Exception as e: |
| @@ -302,25 +477,78 @@ def get_progress(task_id: str): | @@ -302,25 +477,78 @@ def get_progress(task_id: str): | ||
| 302 | }), 500 | 477 | }), 500 |
| 303 | 478 | ||
| 304 | 479 | ||
| 480 | +@report_bp.route('/stream/<task_id>', methods=['GET']) | ||
| 481 | +def stream_task(task_id: str): | ||
| 482 | + """基于SSE的实时推送接口,向前端持续广播阶段事件。""" | ||
| 483 | + task = _get_task(task_id) | ||
| 484 | + if not task: | ||
| 485 | + return jsonify({'success': False, 'error': '任务不存在'}), 404 | ||
| 486 | + | ||
| 487 | + last_event_header = request.headers.get('Last-Event-ID') | ||
| 488 | + try: | ||
| 489 | + last_event_id = int(last_event_header) if last_event_header else None | ||
| 490 | + except ValueError: | ||
| 491 | + last_event_id = None | ||
| 492 | + | ||
| 493 | + def event_generator(): | ||
| 494 | + queue = _register_stream(task_id) | ||
| 495 | + try: | ||
| 496 | + # 断线重连场景下,先补发历史事件,保证界面状态一致 | ||
| 497 | + history = task.history_since(last_event_id) | ||
| 498 | + for event in history: | ||
| 499 | + yield _format_sse(event) | ||
| 500 | + | ||
| 501 | + finished = task.status in ("completed", "error", "cancelled") | ||
| 502 | + while True: | ||
| 503 | + if finished: | ||
| 504 | + break | ||
| 505 | + try: | ||
| 506 | + event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL) | ||
| 507 | + yield _format_sse(event) | ||
| 508 | + if event.get('type') in ("completed", "error"): | ||
| 509 | + finished = True | ||
| 510 | + except Empty: | ||
| 511 | + heartbeat = { | ||
| 512 | + 'id': f"hb-{int(time.time() * 1000)}", | ||
| 513 | + 'type': 'heartbeat', | ||
| 514 | + 'task_id': task_id, | ||
| 515 | + 'timestamp': datetime.utcnow().isoformat() + 'Z', | ||
| 516 | + 'payload': {'status': task.status} | ||
| 517 | + } | ||
| 518 | + yield _format_sse(heartbeat) | ||
| 519 | + finished = task.status in ("completed", "error", "cancelled") | ||
| 520 | + finally: | ||
| 521 | + _unregister_stream(task_id, queue) | ||
| 522 | + | ||
| 523 | + response = Response( | ||
| 524 | + stream_with_context(event_generator()), | ||
| 525 | + mimetype='text/event-stream' | ||
| 526 | + ) | ||
| 527 | + response.headers['Cache-Control'] = 'no-cache' | ||
| 528 | + response.headers['X-Accel-Buffering'] = 'no' | ||
| 529 | + return response | ||
| 530 | + | ||
| 531 | + | ||
| 305 | @report_bp.route('/result/<task_id>', methods=['GET']) | 532 | @report_bp.route('/result/<task_id>', methods=['GET']) |
| 306 | def get_result(task_id: str): | 533 | def get_result(task_id: str): |
| 307 | """获取报告生成结果""" | 534 | """获取报告生成结果""" |
| 308 | try: | 535 | try: |
| 309 | - if not current_task or current_task.task_id != task_id: | 536 | + task = _get_task(task_id) |
| 537 | + if not task: | ||
| 310 | return jsonify({ | 538 | return jsonify({ |
| 311 | 'success': False, | 539 | 'success': False, |
| 312 | 'error': '任务不存在' | 540 | 'error': '任务不存在' |
| 313 | }), 404 | 541 | }), 404 |
| 314 | 542 | ||
| 315 | - if current_task.status != "completed": | 543 | + if task.status != "completed": |
| 316 | return jsonify({ | 544 | return jsonify({ |
| 317 | 'success': False, | 545 | 'success': False, |
| 318 | 'error': '报告尚未完成', | 546 | 'error': '报告尚未完成', |
| 319 | - 'task': current_task.to_dict() | 547 | + 'task': task.to_dict() |
| 320 | }), 400 | 548 | }), 400 |
| 321 | 549 | ||
| 322 | return Response( | 550 | return Response( |
| 323 | - current_task.html_content, | 551 | + task.html_content, |
| 324 | mimetype='text/html' | 552 | mimetype='text/html' |
| 325 | ) | 553 | ) |
| 326 | 554 | ||
| @@ -336,23 +564,24 @@ def get_result(task_id: str): | @@ -336,23 +564,24 @@ def get_result(task_id: str): | ||
| 336 | def get_result_json(task_id: str): | 564 | def get_result_json(task_id: str): |
| 337 | """获取报告生成结果(JSON格式)""" | 565 | """获取报告生成结果(JSON格式)""" |
| 338 | try: | 566 | try: |
| 339 | - if not current_task or current_task.task_id != task_id: | 567 | + task = _get_task(task_id) |
| 568 | + if not task: | ||
| 340 | return jsonify({ | 569 | return jsonify({ |
| 341 | 'success': False, | 570 | 'success': False, |
| 342 | 'error': '任务不存在' | 571 | 'error': '任务不存在' |
| 343 | }), 404 | 572 | }), 404 |
| 344 | 573 | ||
| 345 | - if current_task.status != "completed": | 574 | + if task.status != "completed": |
| 346 | return jsonify({ | 575 | return jsonify({ |
| 347 | 'success': False, | 576 | 'success': False, |
| 348 | 'error': '报告尚未完成', | 577 | 'error': '报告尚未完成', |
| 349 | - 'task': current_task.to_dict() | 578 | + 'task': task.to_dict() |
| 350 | }), 400 | 579 | }), 400 |
| 351 | 580 | ||
| 352 | return jsonify({ | 581 | return jsonify({ |
| 353 | 'success': True, | 582 | 'success': True, |
| 354 | - 'task': current_task.to_dict(), | ||
| 355 | - 'html_content': current_task.html_content | 583 | + 'task': task.to_dict(), |
| 584 | + 'html_content': task.html_content | ||
| 356 | }) | 585 | }) |
| 357 | 586 | ||
| 358 | except Exception as e: | 587 | except Exception as e: |
| @@ -367,27 +596,28 @@ def get_result_json(task_id: str): | @@ -367,27 +596,28 @@ def get_result_json(task_id: str): | ||
| 367 | def download_report(task_id: str): | 596 | def download_report(task_id: str): |
| 368 | """下载已生成的报告HTML文件""" | 597 | """下载已生成的报告HTML文件""" |
| 369 | try: | 598 | try: |
| 370 | - if not current_task or current_task.task_id != task_id: | 599 | + task = _get_task(task_id) |
| 600 | + if not task: | ||
| 371 | return jsonify({ | 601 | return jsonify({ |
| 372 | 'success': False, | 602 | 'success': False, |
| 373 | 'error': '任务不存在' | 603 | 'error': '任务不存在' |
| 374 | }), 404 | 604 | }), 404 |
| 375 | 605 | ||
| 376 | - if current_task.status != "completed" or not current_task.report_file_path: | 606 | + if task.status != "completed" or not task.report_file_path: |
| 377 | return jsonify({ | 607 | return jsonify({ |
| 378 | 'success': False, | 608 | 'success': False, |
| 379 | 'error': '报告尚未完成或尚未保存' | 609 | 'error': '报告尚未完成或尚未保存' |
| 380 | }), 400 | 610 | }), 400 |
| 381 | 611 | ||
| 382 | - if not os.path.exists(current_task.report_file_path): | 612 | + if not os.path.exists(task.report_file_path): |
| 383 | return jsonify({ | 613 | return jsonify({ |
| 384 | 'success': False, | 614 | 'success': False, |
| 385 | 'error': '报告文件不存在或已被删除' | 615 | 'error': '报告文件不存在或已被删除' |
| 386 | }), 404 | 616 | }), 404 |
| 387 | 617 | ||
| 388 | - download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path) | 618 | + download_name = task.report_file_name or os.path.basename(task.report_file_path) |
| 389 | return send_file( | 619 | return send_file( |
| 390 | - current_task.report_file_path, | 620 | + task.report_file_path, |
| 391 | mimetype='text/html', | 621 | mimetype='text/html', |
| 392 | as_attachment=True, | 622 | as_attachment=True, |
| 393 | download_name=download_name | 623 | download_name=download_name |
| @@ -411,7 +641,18 @@ def cancel_task(task_id: str): | @@ -411,7 +641,18 @@ def cancel_task(task_id: str): | ||
| 411 | if current_task and current_task.task_id == task_id: | 641 | if current_task and current_task.task_id == task_id: |
| 412 | if current_task.status == "running": | 642 | if current_task.status == "running": |
| 413 | current_task.update_status("cancelled", 0, "用户取消任务") | 643 | current_task.update_status("cancelled", 0, "用户取消任务") |
| 644 | + current_task.publish_event('cancelled', { | ||
| 645 | + 'message': '任务被用户主动终止', | ||
| 646 | + 'task': current_task.to_dict(), | ||
| 647 | + }) | ||
| 414 | current_task = None | 648 | current_task = None |
| 649 | + task = tasks_registry.get(task_id) | ||
| 650 | + if task and task.status == 'running': | ||
| 651 | + task.update_status("cancelled", task.progress, "用户取消任务") | ||
| 652 | + task.publish_event('cancelled', { | ||
| 653 | + 'message': '任务被用户主动终止', | ||
| 654 | + 'task': task.to_dict(), | ||
| 655 | + }) | ||
| 415 | 656 | ||
| 416 | return jsonify({ | 657 | return jsonify({ |
| 417 | 'success': True, | 658 | 'success': True, |
| @@ -5,7 +5,7 @@ Report Engine节点处理模块 | @@ -5,7 +5,7 @@ Report Engine节点处理模块 | ||
| 5 | 5 | ||
| 6 | from .base_node import BaseNode, StateMutationNode | 6 | from .base_node import BaseNode, StateMutationNode |
| 7 | from .template_selection_node import TemplateSelectionNode | 7 | from .template_selection_node import TemplateSelectionNode |
| 8 | -from .chapter_generation_node import ChapterGenerationNode | 8 | +from .chapter_generation_node import ChapterGenerationNode, ChapterJsonParseError |
| 9 | from .document_layout_node import DocumentLayoutNode | 9 | from .document_layout_node import DocumentLayoutNode |
| 10 | from .word_budget_node import WordBudgetNode | 10 | from .word_budget_node import WordBudgetNode |
| 11 | 11 | ||
| @@ -14,6 +14,7 @@ __all__ = [ | @@ -14,6 +14,7 @@ __all__ = [ | ||
| 14 | "StateMutationNode", | 14 | "StateMutationNode", |
| 15 | "TemplateSelectionNode", | 15 | "TemplateSelectionNode", |
| 16 | "ChapterGenerationNode", | 16 | "ChapterGenerationNode", |
| 17 | + "ChapterJsonParseError", | ||
| 17 | "DocumentLayoutNode", | 18 | "DocumentLayoutNode", |
| 18 | "WordBudgetNode", | 19 | "WordBudgetNode", |
| 19 | ] | 20 | ] |
-
Please register or login to post a comment