马一丁

Streaming

... ... @@ -8,7 +8,7 @@ import os
from pathlib import Path
from uuid import uuid4
from datetime import datetime
from typing import Optional, Dict, Any, List
from typing import Optional, Dict, Any, List, Callable
from loguru import logger
... ... @@ -23,6 +23,7 @@ from .llms import LLMClient
from .nodes import (
TemplateSelectionNode,
ChapterGenerationNode,
ChapterJsonParseError,
DocumentLayoutNode,
WordBudgetNode,
)
... ... @@ -205,10 +206,11 @@ class ReportAgent:
)
def generate_report(self, query: str, reports: List[Any], forum_logs: str = "",
custom_template: str = "", save_report: bool = True) -> str:
custom_template: str = "", save_report: bool = True,
stream_handler: Optional[Callable[[str, Dict[str, Any]], None]] = None) -> str:
"""
生成综合报告(章节JSON → IR → HTML)
Returns:
dict: HTML内容以及保存的文件路径信息
"""
... ... @@ -220,15 +222,32 @@ class ReportAgent:
self.state.mark_processing()
normalized_reports = self._normalize_reports(reports)
def emit(event_type: str, payload: Dict[str, Any]):
    """Forward one event to the optional ``stream_handler`` callback.

    Does nothing when no handler was supplied. Any exception raised by the
    handler is logged as a warning and swallowed, so a faulty callback cannot
    abort report generation.
    """
    if stream_handler:
        try:
            stream_handler(event_type, payload)
        except Exception as handler_failure:  # pragma: no cover - logged only
            logger.warning(f"流式事件回调失败: {handler_failure}")
logger.info(f"开始生成报告 {report_id}: {query}")
logger.info(f"输入数据 - 报告数量: {len(reports)}, 论坛日志长度: {len(str(forum_logs))}")
emit('stage', {'stage': 'agent_start', 'report_id': report_id, 'query': query})
try:
template_result = self._select_template(query, reports, forum_logs, custom_template)
self.state.metadata.template_used = template_result.get('template_name', '')
emit('stage', {
'stage': 'template_selected',
'template': template_result.get('template_name'),
'reason': template_result.get('selection_reason')
})
sections = self._slice_template(template_result.get('template_content', ''))
if not sections:
raise ValueError("模板无法解析出章节,请检查模板内容。")
emit('stage', {'stage': 'template_sliced', 'section_count': len(sections)})
template_text = template_result.get('template_content', '')
template_overview = self._build_template_overview(template_text, sections)
... ... @@ -241,6 +260,11 @@ class ReportAgent:
query,
template_overview,
)
emit('stage', {
'stage': 'layout_designed',
'title': layout_design.get('title'),
'toc': layout_design.get('tocTitle')
})
# 使用刚生成的设计稿对全书进行篇幅规划,约束各章字数与重点
word_plan = self.word_budget_node.run(
sections,
... ... @@ -250,6 +274,10 @@ class ReportAgent:
query,
template_overview,
)
emit('stage', {
'stage': 'word_plan_ready',
'chapter_targets': len(word_plan.get('chapters', []))
})
# 记录每个章节的目标字数/强调点,后续传给章节LLM
chapter_targets = {
entry.get("chapterId"): entry
... ... @@ -296,23 +324,97 @@ class ReportAgent:
# 初始化章节输出目录并写入manifest,方便流式存盘
run_dir = self.chapter_storage.start_session(report_id, manifest_meta)
self._persist_planning_artifacts(run_dir, layout_design, word_plan, template_overview)
emit('stage', {'stage': 'storage_ready', 'run_dir': str(run_dir)})
chapters = []
chapter_max_attempts = max(1, self.config.CHAPTER_JSON_MAX_ATTEMPTS)
for section in sections:
logger.info(f"生成章节: {section.title}")
chapter = self.chapter_generation_node.run(
section,
generation_context,
run_dir
)
chapters.append(chapter)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'running'
})
# 章节流式回调:把LLM返回的delta透传给SSE,便于前端实时渲染
def chunk_callback(delta: str, meta: Dict[str, Any], section_ref: TemplateSection = section):
    """Relay one streamed LLM text delta as a ``chapter_chunk`` event.

    ``section_ref`` is bound as a default argument so the callback keeps the
    section of the loop iteration it was created in; its id and title are
    used whenever ``meta`` has no (or an empty) ``chapterId`` / ``title``.
    """
    chapter_id = meta.get('chapterId') or section_ref.chapter_id
    chapter_title = meta.get('title') or section_ref.title
    emit('chapter_chunk', {
        'chapterId': chapter_id,
        'title': chapter_title,
        'delta': delta,
    })
chapter_payload: Dict[str, Any] | None = None
attempt = 1
while attempt <= chapter_max_attempts:
try:
chapter_payload = self.chapter_generation_node.run(
section,
generation_context,
run_dir,
stream_callback=chunk_callback
)
break
except ChapterJsonParseError as parse_error:
logger.warning(
"章节 %s JSON解析失败(第 %s/%s 次尝试): %s",
section.title,
attempt,
chapter_max_attempts,
parse_error,
)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
'attempt': attempt,
'error': str(parse_error),
})
if attempt >= chapter_max_attempts:
raise
attempt += 1
continue
except Exception as chapter_error:
if not self._should_retry_inappropriate_content_error(chapter_error):
raise
logger.warning(
"章节 %s 触发内容安全限制(第 %s/%s 次尝试),准备重新生成: %s",
section.title,
attempt,
chapter_max_attempts,
chapter_error,
)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'retrying' if attempt < chapter_max_attempts else 'error',
'attempt': attempt,
'error': str(chapter_error),
'reason': 'content_filter'
})
if attempt >= chapter_max_attempts:
raise
attempt += 1
continue
if chapter_payload is None:
raise ChapterJsonParseError(
f"{section.title} 章节JSON在 {chapter_max_attempts} 次尝试后仍无法解析"
)
chapters.append(chapter_payload)
emit('chapter_status', {
'chapterId': section.chapter_id,
'title': section.title,
'status': 'completed',
'attempt': attempt,
})
document_ir = self.document_composer.build_document(
report_id,
manifest_meta,
chapters
)
emit('stage', {'stage': 'chapters_compiled', 'chapter_count': len(chapters)})
html_report = self.renderer.render(document_ir)
emit('stage', {'stage': 'html_rendered', 'html_length': len(html_report)})
self.state.html_content = html_report
self.state.mark_completed()
... ... @@ -320,10 +422,12 @@ class ReportAgent:
saved_files = {}
if save_report:
saved_files = self._save_report(html_report, document_ir, report_id)
emit('stage', {'stage': 'report_saved', 'files': saved_files})
generation_time = (datetime.now() - start_time).total_seconds()
self.state.metadata.generation_time = generation_time
logger.info(f"报告生成完成,耗时: {generation_time:.2f} 秒")
emit('metrics', {'generation_seconds': generation_time})
return {
'html_content': html_report,
'report_id': report_id,
... ... @@ -333,6 +437,7 @@ class ReportAgent:
except Exception as e:
self.state.mark_failed(str(e))
logger.exception(f"报告生成过程中发生错误: {str(e)}")
emit('error', {'stage': 'agent_failed', 'message': str(e)})
raise
def _select_template(self, query: str, reports: List[Any], forum_logs: str, custom_template: str):
... ... @@ -444,6 +549,22 @@ class ReportAgent:
normalized[key] = self._stringify(value)
return normalized
def _should_retry_inappropriate_content_error(self, error: Exception) -> bool:
"""
判断LLM异常是否由内容安全/不当内容导致,满足时允许重新生成整章。
"""
message = str(error) if error else ""
if not message:
return False
normalized = message.lower()
keywords = [
"inappropriate content",
"content violation",
"content moderation",
"model-studio/error-code",
]
return any(keyword in normalized for keyword in keywords)
def _stringify(self, value: Any) -> str:
"""安全地将对象转成字符串"""
if value is None:
... ...
... ... @@ -7,9 +7,11 @@ import os
import json
import threading
import time
from collections import deque, defaultdict
from datetime import datetime
from flask import Blueprint, request, jsonify, Response, send_file
from typing import Dict, Any
from queue import Queue, Empty
from flask import Blueprint, request, jsonify, Response, send_file, stream_with_context
from typing import Dict, Any, List, Optional
from loguru import logger
from .agent import ReportAgent, create_agent
from .utils.config import settings
... ... @@ -23,6 +25,69 @@ report_agent = None
current_task = None
task_lock = threading.Lock()
# ====== 流式推送与任务历史管理 ======
# 通过有界deque缓存最近的事件,方便SSE断线后快速补发
MAX_TASK_HISTORY = 5
STREAM_HEARTBEAT_INTERVAL = 15 # 心跳间隔秒
stream_lock = threading.Lock()
stream_subscribers = defaultdict(list)
tasks_registry: Dict[str, 'ReportTask'] = {}
def _register_stream(task_id: str) -> Queue:
    """Create an event queue for ``task_id`` and add it to the subscriber table.

    Returns:
        The new queue; events broadcast for the task are put on it for an
        SSE listener to consume.
    """
    subscriber_queue = Queue()
    with stream_lock:
        task_queues = stream_subscribers[task_id]
        task_queues.append(subscriber_queue)
    return subscriber_queue
def _unregister_stream(task_id: str, queue: Queue):
    """Remove ``queue`` from the task's subscribers.

    Drops the task's entry from the subscriber table once its last queue is
    gone, so finished streams do not accumulate. Unknown tasks or queues are
    ignored.
    """
    with stream_lock:
        registered = stream_subscribers.get(task_id)
        if registered is None:
            return
        if queue in registered:
            registered.remove(queue)
        if not registered:
            stream_subscribers.pop(task_id, None)
def _broadcast_event(task_id: str, event: Dict[str, Any]):
    """Put ``event`` on every queue currently subscribed to ``task_id``.

    The subscriber list is copied while holding ``stream_lock`` and the puts
    happen outside the lock. A failure on one queue is logged and does not
    stop delivery to the remaining queues.
    """
    with stream_lock:
        targets = tuple(stream_subscribers.get(task_id, ()))
    for subscriber in targets:
        try:
            subscriber.put(event, timeout=0.1)
        except Exception:
            logger.exception("推送流式事件失败,跳过当前监听队列")
def _prune_task_history_locked():
    """Trim ``tasks_registry`` down to the newest ``MAX_TASK_HISTORY`` tasks.

    Intended to be called while ``task_lock`` is held. Tasks are ordered by
    ``created_at`` and the oldest surplus entries are removed.
    """
    if len(tasks_registry) > MAX_TASK_HISTORY:
        oldest_first = sorted(
            tasks_registry.values(),
            key=lambda entry: entry.created_at,
        )
        expired = oldest_first[:-MAX_TASK_HISTORY]
        for stale in expired:
            tasks_registry.pop(stale.task_id, None)
def _get_task(task_id: str) -> Optional['ReportTask']:
    """Look a task up by id, preferring the currently active task.

    Returns:
        ``current_task`` when its id matches, otherwise the entry from
        ``tasks_registry`` (``None`` if the id is unknown).
    """
    with task_lock:
        active = current_task
        is_active_match = bool(active) and active.task_id == task_id
        return active if is_active_match else tasks_registry.get(task_id)
def _format_sse(event: Dict[str, Any]) -> str:
"""按SSE协议格式化消息。"""
payload = json.dumps(event, ensure_ascii=False)
event_id = event.get('id', 0)
event_type = event.get('type', 'message')
return f"id: {event_id}\nevent: {event_type}\ndata: {payload}\n\n"
def initialize_report_engine():
"""初始化Report Engine"""
... ... @@ -63,6 +128,11 @@ class ReportTask:
self.report_file_name = ""
self.state_file_path = ""
self.state_file_relative_path = ""
# ====== 流式事件缓存与并发保护 ======
# 使用deque保存最近的事件,结合锁保证多线程下的安全访问
self.event_history: deque = deque(maxlen=1000)
self._event_lock = threading.Lock()
self.last_event_id = 0
def update_status(self, status: str, progress: int = None, error_message: str = ""):
"""更新任务状态"""
... ... @@ -72,6 +142,17 @@ class ReportTask:
if error_message:
self.error_message = error_message
self.updated_at = datetime.now()
# 推送状态变更事件,方便前端实时刷新
self.publish_event(
'status',
{
'status': self.status,
'progress': self.progress,
'error_message': self.error_message,
'hint': error_message or '',
'task': self.to_dict(),
}
)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
... ... @@ -91,6 +172,29 @@ class ReportTask:
'state_file_path': self.state_file_relative_path or self.state_file_path
}
def publish_event(self, event_type: str, payload: Dict[str, Any]) -> None:
    """Record an event in this task's history and broadcast it to subscribers.

    Each event gets the next sequential id, assigned under ``_event_lock``
    together with the append to ``event_history``. The broadcast happens
    after the lock is released.

    Args:
        event_type: Value stored under the event's ``type`` key.
        payload: Event-specific data stored under ``payload``.
    """
    stamped_at = datetime.utcnow().isoformat() + 'Z'
    with self._event_lock:
        self.last_event_id += 1
        event: Dict[str, Any] = {
            'id': self.last_event_id,
            'type': event_type,
            'task_id': self.task_id,
            'timestamp': stamped_at,
            'payload': payload,
        }
        self.event_history.append(event)
    _broadcast_event(self.task_id, event)
def history_since(self, last_event_id: Optional[int]) -> List[Dict[str, Any]]:
    """Return cached events to replay after a reconnect.

    Args:
        last_event_id: Id of the last event the client received, or
            ``None`` to request the whole cached history.

    Returns:
        The cached events whose ``id`` is greater than ``last_event_id``,
        in their stored order (all cached events when it is ``None``).
    """
    with self._event_lock:
        snapshot = list(self.event_history)
    if last_event_id is None:
        return snapshot
    return [entry for entry in snapshot if entry['id'] > last_event_id]
def check_engines_ready() -> Dict[str, Any]:
"""检查三个子引擎是否都有新文件"""
... ... @@ -121,7 +225,13 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
global current_task
try:
# 在局部闭包内封装推送逻辑,便于传递给ReportAgent
def stream_handler(event_type: str, payload: Dict[str, Any]):
"""所有阶段事件都通过同一个接口分发,保证日志一致。"""
task.publish_event(event_type, payload)
task.update_status("running", 10)
task.publish_event('stage', {'message': '任务已启动,正在检查输入文件', 'stage': 'prepare'})
# 检查输入文件
check_result = check_engines_ready()
... ... @@ -129,21 +239,54 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
task.update_status("error", 0, f"输入文件未准备就绪: {check_result.get('missing_files', [])}")
return
task.publish_event('stage', {
'message': '输入文件检查通过,准备载入内容',
'stage': 'io_ready',
'files': check_result.get('latest_files', {})
})
task.update_status("running", 30)
# 加载输入文件
content = report_agent.load_input_files(check_result['latest_files'])
task.publish_event('stage', {'message': '源数据加载完成,启动生成流程', 'stage': 'data_loaded'})
task.update_status("running", 50)
# 生成报告
generation_result = report_agent.generate_report(
query=query,
reports=content['reports'],
forum_logs=content['forum_logs'],
custom_template=custom_template,
save_report=True
)
# 生成报告(附带兜底重试,缓解瞬时网络抖动)
for attempt in range(1, 3):
try:
task.publish_event('stage', {
'message': f'正在调用ReportAgent生成报告(第{attempt}次尝试)',
'stage': 'agent_running',
'attempt': attempt
})
generation_result = report_agent.generate_report(
query=query,
reports=content['reports'],
forum_logs=content['forum_logs'],
custom_template=custom_template,
save_report=True,
stream_handler=stream_handler
)
break
except Exception as err:
# 将错误即时推送至前端,方便观察重试策略
task.publish_event('warning', {
'message': f'ReportAgent执行失败: {str(err)}',
'stage': 'agent_running',
'attempt': attempt
})
if attempt == 2:
raise
# 简单的指数退避,防止频繁触发限流(单位秒)
backoff = min(5 * attempt, 15)
task.publish_event('stage', {
'message': f'{backoff} 秒后重试生成任务',
'stage': 'retry_wait',
'wait_seconds': backoff
})
time.sleep(backoff)
if isinstance(generation_result, dict):
html_report = generation_result.get('html_content', '')
... ... @@ -151,6 +294,7 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
html_report = generation_result
task.update_status("running", 90)
task.publish_event('stage', {'message': '报告生成完毕,准备持久化', 'stage': 'persist'})
# 保存结果
task.html_content = html_report
... ... @@ -160,11 +304,28 @@ def run_report_generation(task: ReportTask, query: str, custom_template: str = "
task.report_file_name = generation_result.get('report_filename', '')
task.state_file_path = generation_result.get('state_filepath', '')
task.state_file_relative_path = generation_result.get('state_relative_path', '')
task.publish_event('html_ready', {
'message': 'HTML渲染完成,可刷新预览',
'report_file': task.report_file_relative_path or task.report_file_path,
'state_file': task.state_file_relative_path or task.state_file_path,
'task': task.to_dict(),
})
task.update_status("completed", 100)
task.publish_event('completed', {
'message': '任务完成',
'duration_seconds': (task.updated_at - task.created_at).total_seconds(),
'report_file': task.report_file_relative_path or task.report_file_path,
'task': task.to_dict(),
})
except Exception as e:
logger.exception(f"报告生成过程中发生错误: {str(e)}")
task.update_status("error", 0, str(e))
task.publish_event('error', {
'message': str(e),
'stage': 'failed',
'task': task.to_dict(),
})
# 只在出错时清理任务
with task_lock:
if current_task and current_task.task_id == task.task_id:
... ... @@ -242,6 +403,19 @@ def generate_report():
with task_lock:
current_task = task
tasks_registry[task_id] = task
_prune_task_history_locked()
# 通过主动推送pending事件告知前端任务已经排队
task.publish_event(
'status',
{
'status': task.status,
'progress': task.progress,
'message': '任务已排队,等待资源空闲',
'task': task.to_dict(),
}
)
# 在后台线程中运行报告生成
thread = threading.Thread(
... ... @@ -255,7 +429,8 @@ def generate_report():
'success': True,
'task_id': task_id,
'message': '报告生成已启动',
'task': task.to_dict()
'task': task.to_dict(),
'stream_url': f"/api/report/stream/{task_id}"
})
except Exception as e:
... ... @@ -270,9 +445,9 @@ def generate_report():
def get_progress(task_id: str):
"""获取报告生成进度"""
try:
if not current_task or current_task.task_id != task_id:
# 如果任务不存在,可能是已经完成并被清理了
# 返回一个默认的完成状态而不是404
task = _get_task(task_id)
if not task:
# 如果任务不存在,可能是历史记录已被清理,回传一个完成态兜底
return jsonify({
'success': True,
'task': {
... ... @@ -291,7 +466,7 @@ def get_progress(task_id: str):
return jsonify({
'success': True,
'task': current_task.to_dict()
'task': task.to_dict()
})
except Exception as e:
... ... @@ -302,25 +477,78 @@ def get_progress(task_id: str):
}), 500
@report_bp.route('/stream/<task_id>', methods=['GET'])
def stream_task(task_id: str):
    """Stream a task's events to the client over Server-Sent Events (SSE).

    Responds 404 with a JSON error when the task is unknown. Otherwise the
    response first replays cached events newer than the ``Last-Event-ID``
    request header (the whole cached history when the header is absent or
    not an integer), then forwards live events until the task finishes,
    emitting a heartbeat whenever no event arrives within
    ``STREAM_HEARTBEAT_INTERVAL`` seconds.
    """
    task = _get_task(task_id)
    if not task:
        return jsonify({'success': False, 'error': '任务不存在'}), 404

    last_event_header = request.headers.get('Last-Event-ID')
    try:
        last_event_id = int(last_event_header) if last_event_header else None
    except ValueError:
        last_event_id = None

    terminal_statuses = ("completed", "error", "cancelled")
    # 'cancelled' is included so a cancelled task closes the stream at once
    # instead of waiting for the next heartbeat timeout.
    terminal_event_types = ("completed", "error", "cancelled")

    def event_generator():
        # Subscribe before reading history so no event can fall in between;
        # the overlap this creates is removed via last_sent_id below.
        queue = _register_stream(task_id)
        last_sent_id = 0
        try:
            # On reconnect, replay missed events first to resync the client.
            for event in task.history_since(last_event_id):
                yield _format_sse(event)
                replayed_id = event.get('id')
                if isinstance(replayed_id, int) and replayed_id > last_sent_id:
                    last_sent_id = replayed_id

            finished = task.status in terminal_statuses
            while not finished:
                try:
                    event = queue.get(timeout=STREAM_HEARTBEAT_INTERVAL)
                except Empty:
                    heartbeat = {
                        'id': f"hb-{int(time.time() * 1000)}",
                        'type': 'heartbeat',
                        'task_id': task_id,
                        'timestamp': datetime.utcnow().isoformat() + 'Z',
                        'payload': {'status': task.status}
                    }
                    # No SSE "id:" line here: an id would replace the
                    # client's last event ID with a non-numeric value and
                    # the next reconnect would replay the full history.
                    heartbeat_json = json.dumps(heartbeat, ensure_ascii=False)
                    yield f"event: heartbeat\ndata: {heartbeat_json}\n\n"
                    finished = task.status in terminal_statuses
                    continue

                event_id = event.get('id')
                if isinstance(event_id, int):
                    if event_id <= last_sent_id:
                        # Already delivered during the history replay.
                        continue
                    last_sent_id = event_id
                yield _format_sse(event)
                if event.get('type') in terminal_event_types:
                    finished = True
        finally:
            _unregister_stream(task_id, queue)

    response = Response(
        stream_with_context(event_generator()),
        mimetype='text/event-stream'
    )
    response.headers['Cache-Control'] = 'no-cache'
    response.headers['X-Accel-Buffering'] = 'no'
    return response
@report_bp.route('/result/<task_id>', methods=['GET'])
def get_result(task_id: str):
"""获取报告生成结果"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
if task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
'task': task.to_dict()
}), 400
return Response(
current_task.html_content,
task.html_content,
mimetype='text/html'
)
... ... @@ -336,23 +564,24 @@ def get_result(task_id: str):
def get_result_json(task_id: str):
"""获取报告生成结果(JSON格式)"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed":
if task.status != "completed":
return jsonify({
'success': False,
'error': '报告尚未完成',
'task': current_task.to_dict()
'task': task.to_dict()
}), 400
return jsonify({
'success': True,
'task': current_task.to_dict(),
'html_content': current_task.html_content
'task': task.to_dict(),
'html_content': task.html_content
})
except Exception as e:
... ... @@ -367,27 +596,28 @@ def get_result_json(task_id: str):
def download_report(task_id: str):
"""下载已生成的报告HTML文件"""
try:
if not current_task or current_task.task_id != task_id:
task = _get_task(task_id)
if not task:
return jsonify({
'success': False,
'error': '任务不存在'
}), 404
if current_task.status != "completed" or not current_task.report_file_path:
if task.status != "completed" or not task.report_file_path:
return jsonify({
'success': False,
'error': '报告尚未完成或尚未保存'
}), 400
if not os.path.exists(current_task.report_file_path):
if not os.path.exists(task.report_file_path):
return jsonify({
'success': False,
'error': '报告文件不存在或已被删除'
}), 404
download_name = current_task.report_file_name or os.path.basename(current_task.report_file_path)
download_name = task.report_file_name or os.path.basename(task.report_file_path)
return send_file(
current_task.report_file_path,
task.report_file_path,
mimetype='text/html',
as_attachment=True,
download_name=download_name
... ... @@ -411,7 +641,18 @@ def cancel_task(task_id: str):
if current_task and current_task.task_id == task_id:
if current_task.status == "running":
current_task.update_status("cancelled", 0, "用户取消任务")
current_task.publish_event('cancelled', {
'message': '任务被用户主动终止',
'task': current_task.to_dict(),
})
current_task = None
task = tasks_registry.get(task_id)
if task and task.status == 'running':
task.update_status("cancelled", task.progress, "用户取消任务")
task.publish_event('cancelled', {
'message': '任务被用户主动终止',
'task': task.to_dict(),
})
return jsonify({
'success': True,
... ...
... ... @@ -5,7 +5,7 @@ Report Engine节点处理模块
from .base_node import BaseNode, StateMutationNode
from .template_selection_node import TemplateSelectionNode
from .chapter_generation_node import ChapterGenerationNode
from .chapter_generation_node import ChapterGenerationNode, ChapterJsonParseError
from .document_layout_node import DocumentLayoutNode
from .word_budget_node import WordBudgetNode
... ... @@ -14,6 +14,7 @@ __all__ = [
"StateMutationNode",
"TemplateSelectionNode",
"ChapterGenerationNode",
"ChapterJsonParseError",
"DocumentLayoutNode",
"WordBudgetNode",
]
... ...