formatting_node.py
5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
报告格式化节点
负责将最终研究结果格式化为美观的Markdown报告
"""
import json
from typing import List, Dict, Any
from .base_node import BaseNode
from loguru import logger
from ..prompts import SYSTEM_PROMPT_REPORT_FORMATTING
from ..utils.text_processing import (
remove_reasoning_from_output,
clean_markdown_tags
)
class ReportFormattingNode(BaseNode):
"""格式化最终报告的节点"""
def __init__(self, llm_client):
"""
初始化报告格式化节点
Args:
llm_client: LLM客户端
"""
super().__init__(llm_client, "ReportFormattingNode")
def validate_input(self, input_data: Any) -> bool:
"""验证输入数据"""
if isinstance(input_data, str):
try:
data = json.loads(input_data)
return isinstance(data, list) and all(
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
for item in data
)
except:
return False
elif isinstance(input_data, list):
return all(
isinstance(item, dict) and "title" in item and "paragraph_latest_state" in item
for item in input_data
)
return False
def run(self, input_data: Any, **kwargs) -> str:
"""
调用LLM生成Markdown格式报告
Args:
input_data: 包含所有段落信息的列表
**kwargs: 额外参数
Returns:
格式化的Markdown报告
"""
try:
if not self.validate_input(input_data):
raise ValueError("输入数据格式错误,需要包含title和paragraph_latest_state的列表")
# 准备输入数据
if isinstance(input_data, str):
message = input_data
else:
message = json.dumps(input_data, ensure_ascii=False)
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式(流式,安全拼接UTF-8)
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
# 处理响应
processed_response = self.process_output(response)
logger.info("成功生成格式化报告")
return processed_response
except Exception as e:
logger.exception(f"报告格式化失败: {str(e)}")
raise e
def process_output(self, output: str) -> str:
"""
处理LLM输出,清理Markdown格式
Args:
output: LLM原始输出
Returns:
清理后的Markdown报告
"""
try:
# 清理响应文本
cleaned_output = remove_reasoning_from_output(output)
cleaned_output = clean_markdown_tags(cleaned_output)
# 确保报告有基本结构
if not cleaned_output.strip():
return "# 报告生成失败\n\n无法生成有效的报告内容。"
# 如果没有标题,添加一个默认标题
if not cleaned_output.strip().startswith('#'):
cleaned_output = "# 深度研究报告\n\n" + cleaned_output
return cleaned_output.strip()
except Exception as e:
logger.exception(f"处理输出失败: {str(e)}")
return "# 报告处理失败\n\n报告格式化过程中发生错误。"
def format_report_manually(self, paragraphs_data: List[Dict[str, str]],
report_title: str = "深度研究报告") -> str:
"""
手动格式化报告(备用方法)
Args:
paragraphs_data: 段落数据列表
report_title: 报告标题
Returns:
格式化的Markdown报告
"""
try:
logger.info("使用手动格式化方法")
# 构建报告
report_lines = [
f"# {report_title}",
"",
"---",
""
]
# 添加各个段落
for i, paragraph in enumerate(paragraphs_data, 1):
title = paragraph.get("title", f"段落 {i}")
content = paragraph.get("paragraph_latest_state", "")
if content:
report_lines.extend([
f"## {title}",
"",
content,
"",
"---",
""
])
# 添加结论
if len(paragraphs_data) > 1:
report_lines.extend([
"## 结论",
"",
"本报告通过深度搜索和研究,对相关主题进行了全面分析。"
"以上各个方面的内容为理解该主题提供了重要参考。",
""
])
return "\n".join(report_lines)
except Exception as e:
logger.exception(f"手动格式化失败: {str(e)}")
return "# 报告生成失败\n\n无法完成报告格式化。"