template_selection_node.py
10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
模板选择节点。
综合用户查询、三引擎报告、论坛日志与本地模板库,
调用LLM挑选最合适的报告骨架。
"""
import os
import json
from typing import Dict, Any, List, Optional
from loguru import logger
from .base_node import BaseNode
from ..prompts import SYSTEM_PROMPT_TEMPLATE_SELECTION
from ..utils.json_parser import RobustJSONParser, JSONParseError
class TemplateSelectionNode(BaseNode):
"""
模板选择处理节点。
负责准备模板候选列表、构建提示词、解析LLM返回结果,
并在失败时回退到内置模板。
"""
def __init__(self, llm_client, template_dir: str = "ReportEngine/report_template"):
"""
初始化模板选择节点
Args:
llm_client: LLM客户端
template_dir: 模板目录路径
"""
super().__init__(llm_client, "TemplateSelectionNode")
self.template_dir = template_dir
# 初始化鲁棒JSON解析器,启用所有修复策略
self.json_parser = RobustJSONParser(
enable_json_repair=True,
enable_llm_repair=False,
max_repair_attempts=3,
)
def run(self, input_data: Dict[str, Any], **kwargs) -> Dict[str, Any]:
"""
执行模板选择。
Args:
input_data: 包含查询和报告内容的字典
- query: 原始查询
- reports: 三个子agent的报告列表
- forum_logs: 论坛日志内容
Returns:
选择的模板信息,包含名称、内容与选择理由
"""
logger.info("开始模板选择...")
query = input_data.get('query', '')
reports = input_data.get('reports', [])
forum_logs = input_data.get('forum_logs', '')
# 获取可用模板
available_templates = self._get_available_templates()
if not available_templates:
logger.info("未找到预设模板,使用内置默认模板")
return self._get_fallback_template()
# 使用LLM进行模板选择
try:
llm_result = self._llm_template_selection(query, reports, forum_logs, available_templates)
if llm_result:
return llm_result
except Exception as e:
logger.exception(f"LLM模板选择失败: {str(e)}")
# 如果LLM选择失败,使用备选方案
return self._get_fallback_template()
def _llm_template_selection(self, query: str, reports: List[Any], forum_logs: str,
available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
使用LLM进行模板选择。
构造模板列表与报告摘要 → 调用LLM → 解析JSON →
验证模板是否存在并返回标准结构。
参数:
query: 用户输入的主题词。
reports: 多个分析引擎的报告内容。
forum_logs: 论坛日志,可能为空。
available_templates: 本地可用模板清单。
返回:
dict | None: 若LLM成功返回合法结果则包含模板信息,否则为None。
"""
logger.info("尝试使用LLM进行模板选择...")
# 构建模板列表
template_list = "\n".join([f"- {t['name']}: {t['description']}" for t in available_templates])
# 构建报告内容摘要
reports_summary = ""
if reports:
reports_summary = "\n\n=== 分析引擎报告内容 ===\n"
for i, report in enumerate(reports, 1):
# 获取报告内容,支持不同的数据格式
if isinstance(report, dict):
content = report.get('content', str(report))
elif hasattr(report, 'content'):
content = report.content
else:
content = str(report)
# 截断过长的内容,保留前1000个字符
if len(content) > 1000:
content = content[:1000] + "...(内容已截断)"
reports_summary += f"\n报告{i}内容:\n{content}\n"
# 构建论坛日志摘要
forum_summary = ""
if forum_logs and forum_logs.strip():
forum_summary = "\n\n=== 三个引擎的讨论内容 ===\n"
# 截断过长的日志内容,保留前800个字符
if len(forum_logs) > 800:
forum_content = forum_logs[:800] + "...(讨论内容已截断)"
else:
forum_content = forum_logs
forum_summary += forum_content
user_message = f"""查询内容: {query}
报告数量: {len(reports)} 个分析引擎报告
论坛日志: {'有' if forum_logs else '无'}
{reports_summary}{forum_summary}
可用模板:
{template_list}
请根据查询内容、报告内容和论坛日志的具体情况,选择最合适的模板。"""
# 调用LLM
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
# 检查响应是否为空
if not response or not response.strip():
logger.error("LLM返回空响应")
return None
logger.info(f"LLM原始响应: {response}")
# 尝试解析JSON响应,使用鲁棒解析器
try:
result = self.json_parser.parse(
response,
context_name="模板选择",
expected_keys=["template_name", "selection_reason"],
)
# 验证选择的模板是否存在
selected_template_name = result.get('template_name', '')
for template in available_templates:
if template['name'] == selected_template_name or selected_template_name in template['name']:
logger.info(f"LLM选择模板: {selected_template_name}")
return {
'template_name': template['name'],
'template_content': template['content'],
'selection_reason': result.get('selection_reason', 'LLM智能选择')
}
logger.error(f"LLM选择的模板不存在: {selected_template_name}")
return None
except JSONParseError as e:
logger.error(f"JSON解析失败: {str(e)}")
# 尝试从文本响应中提取模板信息
return self._extract_template_from_text(response, available_templates)
def _extract_template_from_text(self, response: str, available_templates: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
从文本响应中提取模板信息。
当LLM未输出合法JSON时,尝试匹配模板名称关键字做降级。
参数:
response: 非结构化的LLM文本。
available_templates: 可选模板列表。
返回:
dict | None: 匹配成功时返回模板详情,否则为None。
"""
logger.info("尝试从文本响应中提取模板信息")
# 查找响应中是否包含模板名称
for template in available_templates:
template_name_variants = [
template['name'],
template['name'].replace('.md', ''),
template['name'].replace('模板', ''),
]
for variant in template_name_variants:
if variant in response:
logger.info(f"在响应中找到模板: {template['name']}")
return {
'template_name': template['name'],
'template_content': template['content'],
'selection_reason': '从文本响应中提取'
}
return None
def _get_available_templates(self) -> List[Dict[str, Any]]:
"""
获取可用的模板列表。
枚举模板目录下的 `.md` 文件并读取内容与描述字段。
返回:
list[dict]: 每项包含 name/path/content/description。
"""
templates = []
if not os.path.exists(self.template_dir):
logger.error(f"模板目录不存在: {self.template_dir}")
return templates
# 查找所有markdown模板文件
for filename in os.listdir(self.template_dir):
if filename.endswith('.md'):
template_path = os.path.join(self.template_dir, filename)
try:
with open(template_path, 'r', encoding='utf-8') as f:
content = f.read()
template_name = filename.replace('.md', '')
description = self._extract_template_description(template_name)
templates.append({
'name': template_name,
'path': template_path,
'content': content,
'description': description
})
except Exception as e:
logger.exception(f"读取模板文件失败 {filename}: {str(e)}")
return templates
def _extract_template_description(self, template_name: str) -> str:
"""根据模板名称生成描述,方便LLM理解模板定位。"""
if '企业品牌' in template_name:
return "适用于企业品牌声誉和形象分析"
elif '市场竞争' in template_name:
return "适用于市场竞争格局和对手分析"
elif '日常' in template_name or '定期' in template_name:
return "适用于日常监测和定期汇报"
elif '政策' in template_name or '行业' in template_name:
return "适用于政策影响和行业动态分析"
elif '热点' in template_name or '社会' in template_name:
return "适用于社会热点和公共事件分析"
elif '突发' in template_name or '危机' in template_name:
return "适用于突发事件和危机公关"
return "通用报告模板"
def _get_fallback_template(self) -> Dict[str, Any]:
"""
获取备用默认模板(空模板,让LLM自行发挥)。
返回:
dict: 结构体字段与LLM返回一致,方便直接替换。
"""
logger.info("未找到合适模板,使用空模板让LLM自行发挥")
return {
'template_name': '自由发挥模板',
'template_content': '',
'selection_reason': '未找到合适的预设模板,让LLM根据内容自行设计报告结构'
}