ghmark675

feat(insight_agent): search results cluster

@@ -7,60 +7,75 @@ import json @@ -7,60 +7,75 @@ import json
7 import os 7 import os
8 import re 8 import re
9 from datetime import datetime 9 from datetime import datetime
10 -from typing import Optional, Dict, Any, List, Union 10 +from typing import Any, Dict, List, Optional, Union
  11 +
  12 +import numpy as np
11 from loguru import logger 13 from loguru import logger
  14 +from sentence_transformers import SentenceTransformer
  15 +from sklearn.cluster import KMeans
12 16
13 from .llms import LLMClient 17 from .llms import LLMClient
14 from .nodes import ( 18 from .nodes import (
15 - ReportStructureNode,  
16 - FirstSearchNode,  
17 - ReflectionNode, 19 + FirstSearchNode,
18 FirstSummaryNode, 20 FirstSummaryNode,
  21 + ReflectionNode,
19 ReflectionSummaryNode, 22 ReflectionSummaryNode,
20 - ReportFormattingNode 23 + ReportFormattingNode,
  24 + ReportStructureNode,
21 ) 25 )
22 from .state import State 26 from .state import State
23 -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer  
24 -from .utils.config import settings, Settings 27 +from .tools import (
  28 + DBResponse,
  29 + MediaCrawlerDB,
  30 + keyword_optimizer,
  31 + multilingual_sentiment_analyzer,
  32 +)
25 from .utils import format_search_results_for_prompt 33 from .utils import format_search_results_for_prompt
  34 +from .utils.config import Settings, settings
  35 +
  36 +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样
  37 +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数
  38 +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数
26 39
27 40
28 class DeepSearchAgent: 41 class DeepSearchAgent:
29 """Deep Search Agent主类""" 42 """Deep Search Agent主类"""
30 - 43 +
31 def __init__(self, config: Optional[Settings] = None): 44 def __init__(self, config: Optional[Settings] = None):
32 """ 45 """
33 初始化Deep Search Agent 46 初始化Deep Search Agent
34 - 47 +
35 Args: 48 Args:
36 config: 可选配置对象(不填则用全局settings) 49 config: 可选配置对象(不填则用全局settings)
37 """ 50 """
38 self.config = config or settings 51 self.config = config or settings
39 - 52 +
40 # 初始化LLM客户端 53 # 初始化LLM客户端
41 self.llm_client = self._initialize_llm() 54 self.llm_client = self._initialize_llm()
42 -  
43 - 55 +
44 # 初始化搜索工具集 56 # 初始化搜索工具集
45 self.search_agency = MediaCrawlerDB() 57 self.search_agency = MediaCrawlerDB()
46 - 58 +
  59 + # 初始化聚类小模型(懒加载)
  60 + self._clustering_model = None
  61 +
47 # 初始化情感分析器 62 # 初始化情感分析器
48 self.sentiment_analyzer = multilingual_sentiment_analyzer 63 self.sentiment_analyzer = multilingual_sentiment_analyzer
49 - 64 +
50 # 初始化节点 65 # 初始化节点
51 self._initialize_nodes() 66 self._initialize_nodes()
52 - 67 +
53 # 状态 68 # 状态
54 self.state = State() 69 self.state = State()
55 - 70 +
56 # 确保输出目录存在 71 # 确保输出目录存在
57 os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) 72 os.makedirs(self.config.OUTPUT_DIR, exist_ok=True)
58 - 73 +
59 logger.info(f"Insight Agent已初始化") 74 logger.info(f"Insight Agent已初始化")
60 logger.info(f"使用LLM: {self.llm_client.get_model_info()}") 75 logger.info(f"使用LLM: {self.llm_client.get_model_info()}")
61 logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") 76 logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)")
62 logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") 77 logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)")
63 - 78 +
64 def _initialize_llm(self) -> LLMClient: 79 def _initialize_llm(self) -> LLMClient:
65 """初始化LLM客户端""" 80 """初始化LLM客户端"""
66 return LLMClient( 81 return LLMClient(
@@ -68,7 +83,7 @@ class DeepSearchAgent: @@ -68,7 +83,7 @@ class DeepSearchAgent:
68 model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, 83 model_name=self.config.INSIGHT_ENGINE_MODEL_NAME,
69 base_url=self.config.INSIGHT_ENGINE_BASE_URL, 84 base_url=self.config.INSIGHT_ENGINE_BASE_URL,
70 ) 85 )
71 - 86 +
72 def _initialize_nodes(self): 87 def _initialize_nodes(self):
73 """初始化处理节点""" 88 """初始化处理节点"""
74 self.first_search_node = FirstSearchNode(self.llm_client) 89 self.first_search_node = FirstSearchNode(self.llm_client)
@@ -76,36 +91,103 @@ class DeepSearchAgent: @@ -76,36 +91,103 @@ class DeepSearchAgent:
76 self.first_summary_node = FirstSummaryNode(self.llm_client) 91 self.first_summary_node = FirstSummaryNode(self.llm_client)
77 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) 92 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
78 self.report_formatting_node = ReportFormattingNode(self.llm_client) 93 self.report_formatting_node = ReportFormattingNode(self.llm_client)
79 - 94 +
  95 + def _get_clustering_model(self):
  96 + """懒加载聚类模型"""
  97 + if self._clustering_model is None:
  98 + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...")
  99 + self._clustering_model = SentenceTransformer(
  100 + "paraphrase-multilingual-MiniLM-L12-v2"
  101 + )
  102 + return self._clustering_model
  103 +
80 def _validate_date_format(self, date_str: str) -> bool: 104 def _validate_date_format(self, date_str: str) -> bool:
81 """ 105 """
82 验证日期格式是否为YYYY-MM-DD 106 验证日期格式是否为YYYY-MM-DD
83 - 107 +
84 Args: 108 Args:
85 date_str: 日期字符串 109 date_str: 日期字符串
86 - 110 +
87 Returns: 111 Returns:
88 是否为有效格式 112 是否为有效格式
89 """ 113 """
90 if not date_str: 114 if not date_str:
91 return False 115 return False
92 - 116 +
93 # 检查格式 117 # 检查格式
94 - pattern = r'^\d{4}-\d{2}-\d{2}$' 118 + pattern = r"^\d{4}-\d{2}-\d{2}$"
95 if not re.match(pattern, date_str): 119 if not re.match(pattern, date_str):
96 return False 120 return False
97 - 121 +
98 # 检查日期是否有效 122 # 检查日期是否有效
99 try: 123 try:
100 - datetime.strptime(date_str, '%Y-%m-%d') 124 + datetime.strptime(date_str, "%Y-%m-%d")
101 return True 125 return True
102 except ValueError: 126 except ValueError:
103 return False 127 return False
104 - 128 +
  129 + def _cluster_and_sample_results(
  130 + self, results: List, max_results: int = 50, results_per_cluster: int = 5
  131 + ) -> List:
  132 + """
  133 + 对搜索结果进行聚类并采样
  134 +
  135 + Args:
  136 + results: 搜索结果列表
  137 + max_results: 最大返回结果数
  138 + results_per_cluster: 每个聚类返回的结果数
  139 +
  140 + Returns:
  141 + 采样后的结果列表
  142 + """
  143 + if len(results) <= max_results:
  144 + return results
  145 +
  146 + try:
  147 + # 提取文本
  148 + texts = [r.title_or_content[:500] for r in results]
  149 +
  150 + # 获取模型并编码
  151 + model = self._get_clustering_model()
  152 + embeddings = model.encode(texts, show_progress_bar=False)
  153 +
  154 + # 计算聚类数
  155 + n_clusters = min(max(2, max_results // results_per_cluster), len(results))
  156 +
  157 + # KMeans聚类
  158 + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
  159 + labels = kmeans.fit_predict(embeddings)
  160 +
  161 + # 从每个聚类采样
  162 + sampled_results = []
  163 + for cluster_id in range(n_clusters):
  164 + cluster_indices = np.where(labels == cluster_id)[0]
  165 + cluster_results = [(results[i], i) for i in cluster_indices]
  166 + cluster_results.sort(
  167 + key=lambda x: x[0].hotness_score or 0, reverse=True
  168 + )
  169 +
  170 + for result, _ in cluster_results[:results_per_cluster]:
  171 + sampled_results.append(result)
  172 + if len(sampled_results) >= max_results:
  173 + break
  174 +
  175 + if len(sampled_results) >= max_results:
  176 + break
  177 +
  178 + logger.info(
  179 + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果"
  180 + )
  181 + return sampled_results
  182 +
  183 + except Exception as e:
  184 + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}")
  185 + return results[:max_results]
  186 +
105 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: 187 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
106 """ 188 """
107 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) 189 执行指定的数据库查询工具(集成关键词优化中间件和情感分析)
108 - 190 +
109 Args: 191 Args:
110 tool_name: 工具名称,可选值: 192 tool_name: 工具名称,可选值:
111 - "search_hot_content": 查找热点内容 193 - "search_hot_content": 查找热点内容
@@ -117,18 +199,20 @@ class DeepSearchAgent: @@ -117,18 +199,20 @@ class DeepSearchAgent:
117 query: 搜索关键词/话题 199 query: 搜索关键词/话题
118 **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) 200 **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等)
119 enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) 201 enable_sentiment: 是否自动对搜索结果进行情感分析(默认True)
120 - 202 +
121 Returns: 203 Returns:
122 DBResponse对象(可能包含情感分析结果) 204 DBResponse对象(可能包含情感分析结果)
123 """ 205 """
124 logger.info(f" → 执行数据库查询工具: {tool_name}") 206 logger.info(f" → 执行数据库查询工具: {tool_name}")
125 - 207 +
126 # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) 208 # 对于热点内容搜索,不需要关键词优化(因为不需要query参数)
127 if tool_name == "search_hot_content": 209 if tool_name == "search_hot_content":
128 time_period = kwargs.get("time_period", "week") 210 time_period = kwargs.get("time_period", "week")
129 limit = kwargs.get("limit", 100) 211 limit = kwargs.get("limit", 100)
130 - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit)  
131 - 212 + response = self.search_agency.search_hot_content(
  213 + time_period=time_period, limit=limit
  214 + )
  215 +
132 # 检查是否需要进行情感分析 216 # 检查是否需要进行情感分析
133 enable_sentiment = kwargs.get("enable_sentiment", True) 217 enable_sentiment = kwargs.get("enable_sentiment", True)
134 if enable_sentiment and response.results and len(response.results) > 0: 218 if enable_sentiment and response.results and len(response.results) > 0:
@@ -138,74 +222,101 @@ class DeepSearchAgent: @@ -138,74 +222,101 @@ class DeepSearchAgent:
138 # 将情感分析结果添加到响应的parameters中 222 # 将情感分析结果添加到响应的parameters中
139 response.parameters["sentiment_analysis"] = sentiment_analysis 223 response.parameters["sentiment_analysis"] = sentiment_analysis
140 logger.info(f" ✅ 情感分析完成") 224 logger.info(f" ✅ 情感分析完成")
141 - 225 +
142 return response 226 return response
143 - 227 +
144 # 独立情感分析工具 228 # 独立情感分析工具
145 if tool_name == "analyze_sentiment": 229 if tool_name == "analyze_sentiment":
146 texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query 230 texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query
147 sentiment_result = self.analyze_sentiment_only(texts) 231 sentiment_result = self.analyze_sentiment_only(texts)
148 - 232 +
149 # 构建DBResponse格式的响应 233 # 构建DBResponse格式的响应
150 return DBResponse( 234 return DBResponse(
151 tool_name="analyze_sentiment", 235 tool_name="analyze_sentiment",
152 parameters={ 236 parameters={
153 "texts": texts if isinstance(texts, list) else [texts], 237 "texts": texts if isinstance(texts, list) else [texts],
154 - **kwargs 238 + **kwargs,
155 }, 239 },
156 results=[], # 情感分析不返回搜索结果 240 results=[], # 情感分析不返回搜索结果
157 results_count=0, 241 results_count=0,
158 - metadata=sentiment_result 242 + metadata=sentiment_result,
159 ) 243 )
160 - 244 +
161 # 对于需要搜索词的工具,使用关键词优化中间件 245 # 对于需要搜索词的工具,使用关键词优化中间件
162 optimized_response = keyword_optimizer.optimize_keywords( 246 optimized_response = keyword_optimizer.optimize_keywords(
163 - original_query=query,  
164 - context=f"使用{tool_name}工具进行查询" 247 + original_query=query, context=f"使用{tool_name}工具进行查询"
165 ) 248 )
166 - 249 +
167 logger.info(f" 🔍 原始查询: '{query}'") 250 logger.info(f" 🔍 原始查询: '{query}'")
168 logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") 251 logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}")
169 - 252 +
170 # 使用优化后的关键词进行多次查询并整合结果 253 # 使用优化后的关键词进行多次查询并整合结果
171 all_results = [] 254 all_results = []
172 total_count = 0 255 total_count = 0
173 - 256 +
174 for keyword in optimized_response.optimized_keywords: 257 for keyword in optimized_response.optimized_keywords:
175 logger.info(f" 查询关键词: '{keyword}'") 258 logger.info(f" 查询关键词: '{keyword}'")
176 - 259 +
177 try: 260 try:
178 if tool_name == "search_topic_globally": 261 if tool_name == "search_topic_globally":
179 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 262 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
180 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE  
181 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) 263 + limit_per_table = (
  264 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  265 + )
  266 + response = self.search_agency.search_topic_globally(
  267 + topic=keyword, limit_per_table=limit_per_table
  268 + )
182 elif tool_name == "search_topic_by_date": 269 elif tool_name == "search_topic_by_date":
183 start_date = kwargs.get("start_date") 270 start_date = kwargs.get("start_date")
184 end_date = kwargs.get("end_date") 271 end_date = kwargs.get("end_date")
185 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 272 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
186 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 273 + limit_per_table = (
  274 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  275 + )
187 if not start_date or not end_date: 276 if not start_date or not end_date:
188 - raise ValueError("search_topic_by_date工具需要start_date和end_date参数")  
189 - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) 277 + raise ValueError(
  278 + "search_topic_by_date工具需要start_date和end_date参数"
  279 + )
  280 + response = self.search_agency.search_topic_by_date(
  281 + topic=keyword,
  282 + start_date=start_date,
  283 + end_date=end_date,
  284 + limit_per_table=limit_per_table,
  285 + )
190 elif tool_name == "get_comments_for_topic": 286 elif tool_name == "get_comments_for_topic":
191 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 287 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
192 - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) 288 + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(
  289 + optimized_response.optimized_keywords
  290 + )
193 limit = max(limit, 50) 291 limit = max(limit, 50)
194 - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) 292 + response = self.search_agency.get_comments_for_topic(
  293 + topic=keyword, limit=limit
  294 + )
195 elif tool_name == "search_topic_on_platform": 295 elif tool_name == "search_topic_on_platform":
196 platform = kwargs.get("platform") 296 platform = kwargs.get("platform")
197 start_date = kwargs.get("start_date") 297 start_date = kwargs.get("start_date")
198 end_date = kwargs.get("end_date") 298 end_date = kwargs.get("end_date")
199 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 299 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
200 - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) 300 + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(
  301 + optimized_response.optimized_keywords
  302 + )
201 limit = max(limit, 30) 303 limit = max(limit, 30)
202 if not platform: 304 if not platform:
203 raise ValueError("search_topic_on_platform工具需要platform参数") 305 raise ValueError("search_topic_on_platform工具需要platform参数")
204 - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) 306 + response = self.search_agency.search_topic_on_platform(
  307 + platform=platform,
  308 + topic=keyword,
  309 + start_date=start_date,
  310 + end_date=end_date,
  311 + limit=limit,
  312 + )
205 else: 313 else:
206 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") 314 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
207 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE)  
208 - 315 + response = self.search_agency.search_topic_globally(
  316 + topic=keyword,
  317 + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE,
  318 + )
  319 +
209 # 收集结果 320 # 收集结果
210 if response.results: 321 if response.results:
211 logger.info(f" 找到 {len(response.results)} 条结果") 322 logger.info(f" 找到 {len(response.results)} 条结果")
@@ -213,15 +324,22 @@ class DeepSearchAgent: @@ -213,15 +324,22 @@ class DeepSearchAgent:
213 total_count += len(response.results) 324 total_count += len(response.results)
214 else: 325 else:
215 logger.info(f" 未找到结果") 326 logger.info(f" 未找到结果")
216 - 327 +
217 except Exception as e: 328 except Exception as e:
218 logger.error(f" 查询'{keyword}'时出错: {str(e)}") 329 logger.error(f" 查询'{keyword}'时出错: {str(e)}")
219 continue 330 continue
220 - 331 +
221 # 去重和整合结果 332 # 去重和整合结果
222 unique_results = self._deduplicate_results(all_results) 333 unique_results = self._deduplicate_results(all_results)
223 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") 334 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
224 - 335 +
  336 + if ENABLE_CLUSTERING:
  337 + unique_results = self._cluster_and_sample_results(
  338 + unique_results,
  339 + max_results=MAX_CLUSTERED_RESULTS,
  340 + results_per_cluster=RESULTS_PER_CLUSTER,
  341 + )
  342 +
225 # 构建整合后的响应 343 # 构建整合后的响应
226 integrated_response = DBResponse( 344 integrated_response = DBResponse(
227 tool_name=f"{tool_name}_optimized", 345 tool_name=f"{tool_name}_optimized",
@@ -229,12 +347,12 @@ class DeepSearchAgent: @@ -229,12 +347,12 @@ class DeepSearchAgent:
229 "original_query": query, 347 "original_query": query,
230 "optimized_keywords": optimized_response.optimized_keywords, 348 "optimized_keywords": optimized_response.optimized_keywords,
231 "optimization_reasoning": optimized_response.reasoning, 349 "optimization_reasoning": optimized_response.reasoning,
232 - **kwargs 350 + **kwargs,
233 }, 351 },
234 results=unique_results, 352 results=unique_results,
235 - results_count=len(unique_results) 353 + results_count=len(unique_results),
236 ) 354 )
237 - 355 +
238 # 检查是否需要进行情感分析 356 # 检查是否需要进行情感分析
239 enable_sentiment = kwargs.get("enable_sentiment", True) 357 enable_sentiment = kwargs.get("enable_sentiment", True)
240 if enable_sentiment and unique_results and len(unique_results) > 0: 358 if enable_sentiment and unique_results and len(unique_results) > 0:
@@ -242,40 +360,45 @@ class DeepSearchAgent: @@ -242,40 +360,45 @@ class DeepSearchAgent:
242 sentiment_analysis = self._perform_sentiment_analysis(unique_results) 360 sentiment_analysis = self._perform_sentiment_analysis(unique_results)
243 if sentiment_analysis: 361 if sentiment_analysis:
244 # 将情感分析结果添加到响应的parameters中 362 # 将情感分析结果添加到响应的parameters中
245 - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis 363 + integrated_response.parameters["sentiment_analysis"] = (
  364 + sentiment_analysis
  365 + )
246 logger.info(f" ✅ 情感分析完成") 366 logger.info(f" ✅ 情感分析完成")
247 - 367 +
248 return integrated_response 368 return integrated_response
249 - 369 +
250 def _deduplicate_results(self, results: List) -> List: 370 def _deduplicate_results(self, results: List) -> List:
251 """ 371 """
252 去重搜索结果 372 去重搜索结果
253 """ 373 """
254 seen = set() 374 seen = set()
255 unique_results = [] 375 unique_results = []
256 - 376 +
257 for result in results: 377 for result in results:
258 # 使用URL或内容作为去重标识 378 # 使用URL或内容作为去重标识
259 identifier = result.url if result.url else result.title_or_content[:100] 379 identifier = result.url if result.url else result.title_or_content[:100]
260 if identifier not in seen: 380 if identifier not in seen:
261 seen.add(identifier) 381 seen.add(identifier)
262 unique_results.append(result) 382 unique_results.append(result)
263 - 383 +
264 return unique_results 384 return unique_results
265 - 385 +
266 def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: 386 def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]:
267 """ 387 """
268 对搜索结果执行情感分析 388 对搜索结果执行情感分析
269 - 389 +
270 Args: 390 Args:
271 results: 搜索结果列表 391 results: 搜索结果列表
272 - 392 +
273 Returns: 393 Returns:
274 情感分析结果字典,如果失败则返回None 394 情感分析结果字典,如果失败则返回None
275 """ 395 """
276 try: 396 try:
277 # 初始化情感分析器(如果尚未初始化且未被禁用) 397 # 初始化情感分析器(如果尚未初始化且未被禁用)
278 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 398 + if (
  399 + not self.sentiment_analyzer.is_initialized
  400 + and not self.sentiment_analyzer.is_disabled
  401 + ):
279 logger.info(" 初始化情感分析模型...") 402 logger.info(" 初始化情感分析模型...")
280 if not self.sentiment_analyzer.initialize(): 403 if not self.sentiment_analyzer.initialize():
281 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 404 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
@@ -290,203 +413,222 @@ class DeepSearchAgent: @@ -290,203 +413,222 @@ class DeepSearchAgent:
290 "platform": result.platform, 413 "platform": result.platform,
291 "author": result.author_nickname, 414 "author": result.author_nickname,
292 "url": result.url, 415 "url": result.url,
293 - "publish_time": str(result.publish_time) if result.publish_time else None 416 + "publish_time": str(result.publish_time)
  417 + if result.publish_time
  418 + else None,
294 } 419 }
295 results_dict.append(result_dict) 420 results_dict.append(result_dict)
296 - 421 +
297 # 执行情感分析 422 # 执行情感分析
298 sentiment_analysis = self.sentiment_analyzer.analyze_query_results( 423 sentiment_analysis = self.sentiment_analyzer.analyze_query_results(
299 - query_results=results_dict,  
300 - text_field="content",  
301 - min_confidence=0.5 424 + query_results=results_dict, text_field="content", min_confidence=0.5
302 ) 425 )
303 - 426 +
304 return sentiment_analysis.get("sentiment_analysis") 427 return sentiment_analysis.get("sentiment_analysis")
305 - 428 +
306 except Exception as e: 429 except Exception as e:
307 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") 430 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
308 return None 431 return None
309 - 432 +
310 def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: 433 def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]:
311 """ 434 """
312 独立的情感分析工具 435 独立的情感分析工具
313 - 436 +
314 Args: 437 Args:
315 texts: 单个文本或文本列表 438 texts: 单个文本或文本列表
316 - 439 +
317 Returns: 440 Returns:
318 情感分析结果 441 情感分析结果
319 """ 442 """
320 logger.info(f" → 执行独立情感分析") 443 logger.info(f" → 执行独立情感分析")
321 - 444 +
322 try: 445 try:
323 # 初始化情感分析器(如果尚未初始化且未被禁用) 446 # 初始化情感分析器(如果尚未初始化且未被禁用)
324 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 447 + if (
  448 + not self.sentiment_analyzer.is_initialized
  449 + and not self.sentiment_analyzer.is_disabled
  450 + ):
325 logger.info(" 初始化情感分析模型...") 451 logger.info(" 初始化情感分析模型...")
326 if not self.sentiment_analyzer.initialize(): 452 if not self.sentiment_analyzer.initialize():
327 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 453 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
328 elif self.sentiment_analyzer.is_disabled: 454 elif self.sentiment_analyzer.is_disabled:
329 logger.warning(" 情感分析功能已禁用,直接透传原始文本") 455 logger.warning(" 情感分析功能已禁用,直接透传原始文本")
330 - 456 +
331 # 执行分析 457 # 执行分析
332 if isinstance(texts, str): 458 if isinstance(texts, str):
333 result = self.sentiment_analyzer.analyze_single_text(texts) 459 result = self.sentiment_analyzer.analyze_single_text(texts)
334 result_dict = result.__dict__ 460 result_dict = result.__dict__
335 response = { 461 response = {
336 "success": result.success and result.analysis_performed, 462 "success": result.success and result.analysis_performed,
337 - "total_analyzed": 1 if result.analysis_performed and result.success else 0,  
338 - "results": [result_dict] 463 + "total_analyzed": 1
  464 + if result.analysis_performed and result.success
  465 + else 0,
  466 + "results": [result_dict],
339 } 467 }
340 if not result.analysis_performed: 468 if not result.analysis_performed:
341 response["success"] = False 469 response["success"] = False
342 - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" 470 + response["warning"] = (
  471 + result.error_message or "情感分析功能不可用,已直接返回原始文本"
  472 + )
343 return response 473 return response
344 else: 474 else:
345 texts_list = list(texts) 475 texts_list = list(texts)
346 - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) 476 + batch_result = self.sentiment_analyzer.analyze_batch(
  477 + texts_list, show_progress=True
  478 + )
347 response = { 479 response = {
348 - "success": batch_result.analysis_performed and batch_result.success_count > 0,  
349 - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, 480 + "success": batch_result.analysis_performed
  481 + and batch_result.success_count > 0,
  482 + "total_analyzed": batch_result.total_processed
  483 + if batch_result.analysis_performed
  484 + else 0,
350 "success_count": batch_result.success_count, 485 "success_count": batch_result.success_count,
351 "failed_count": batch_result.failed_count, 486 "failed_count": batch_result.failed_count,
352 - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0,  
353 - "results": [result.__dict__ for result in batch_result.results] 487 + "average_confidence": batch_result.average_confidence
  488 + if batch_result.analysis_performed
  489 + else 0.0,
  490 + "results": [result.__dict__ for result in batch_result.results],
354 } 491 }
355 if not batch_result.analysis_performed: 492 if not batch_result.analysis_performed:
356 warning = next( 493 warning = next(
357 - (r.error_message for r in batch_result.results if r.error_message),  
358 - "情感分析功能不可用,已直接返回原始文本" 494 + (
  495 + r.error_message
  496 + for r in batch_result.results
  497 + if r.error_message
  498 + ),
  499 + "情感分析功能不可用,已直接返回原始文本",
359 ) 500 )
360 response["success"] = False 501 response["success"] = False
361 response["warning"] = warning 502 response["warning"] = warning
362 return response 503 return response
363 - 504 +
364 except Exception as e: 505 except Exception as e:
365 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") 506 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
366 - return {  
367 - "success": False,  
368 - "error": str(e),  
369 - "results": []  
370 - }  
371 - 507 + return {"success": False, "error": str(e), "results": []}
  508 +
372 def research(self, query: str, save_report: bool = True) -> str: 509 def research(self, query: str, save_report: bool = True) -> str:
373 """ 510 """
374 执行深度研究 511 执行深度研究
375 - 512 +
376 Args: 513 Args:
377 query: 研究查询 514 query: 研究查询
378 save_report: 是否保存报告到文件 515 save_report: 是否保存报告到文件
379 - 516 +
380 Returns: 517 Returns:
381 最终报告内容 518 最终报告内容
382 """ 519 """
383 - logger.info(f"\n{'='*60}") 520 + logger.info(f"\n{'=' * 60}")
384 logger.info(f"开始深度研究: {query}") 521 logger.info(f"开始深度研究: {query}")
385 - logger.info(f"{'='*60}")  
386 - 522 + logger.info(f"{'=' * 60}")
  523 +
387 try: 524 try:
388 # Step 1: 生成报告结构 525 # Step 1: 生成报告结构
389 self._generate_report_structure(query) 526 self._generate_report_structure(query)
390 - 527 +
391 # Step 2: 处理每个段落 528 # Step 2: 处理每个段落
392 self._process_paragraphs() 529 self._process_paragraphs()
393 - 530 +
394 # Step 3: 生成最终报告 531 # Step 3: 生成最终报告
395 final_report = self._generate_final_report() 532 final_report = self._generate_final_report()
396 - 533 +
397 # Step 4: 保存报告 534 # Step 4: 保存报告
398 if save_report: 535 if save_report:
399 self._save_report(final_report) 536 self._save_report(final_report)
400 537
401 logger.info("深度研究完成!") 538 logger.info("深度研究完成!")
402 - 539 +
403 return final_report 540 return final_report
404 - 541 +
405 except Exception as e: 542 except Exception as e:
406 logger.exception(f"研究过程中发生错误: {str(e)}") 543 logger.exception(f"研究过程中发生错误: {str(e)}")
407 raise e 544 raise e
408 - 545 +
409 def _generate_report_structure(self, query: str): 546 def _generate_report_structure(self, query: str):
410 """生成报告结构""" 547 """生成报告结构"""
411 logger.info(f"\n[步骤 1] 生成报告结构...") 548 logger.info(f"\n[步骤 1] 生成报告结构...")
412 - 549 +
413 # 创建报告结构节点 550 # 创建报告结构节点
414 report_structure_node = ReportStructureNode(self.llm_client, query) 551 report_structure_node = ReportStructureNode(self.llm_client, query)
415 - 552 +
416 # 生成结构并更新状态 553 # 生成结构并更新状态
417 self.state = report_structure_node.mutate_state(state=self.state) 554 self.state = report_structure_node.mutate_state(state=self.state)
418 - 555 +
419 _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" 556 _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:"
420 for i, paragraph in enumerate(self.state.paragraphs, 1): 557 for i, paragraph in enumerate(self.state.paragraphs, 1):
421 _message += f"\n {i}. {paragraph.title}" 558 _message += f"\n {i}. {paragraph.title}"
422 logger.info(_message) 559 logger.info(_message)
423 - 560 +
424 def _process_paragraphs(self): 561 def _process_paragraphs(self):
425 """处理所有段落""" 562 """处理所有段落"""
426 total_paragraphs = len(self.state.paragraphs) 563 total_paragraphs = len(self.state.paragraphs)
427 - 564 +
428 for i in range(total_paragraphs): 565 for i in range(total_paragraphs):
429 - logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") 566 + logger.info(
  567 + f"\n[步骤 2.{i + 1}] 处理段落: {self.state.paragraphs[i].title}"
  568 + )
430 logger.info("-" * 50) 569 logger.info("-" * 50)
431 - 570 +
432 # 初始搜索和总结 571 # 初始搜索和总结
433 self._initial_search_and_summary(i) 572 self._initial_search_and_summary(i)
434 - 573 +
435 # 反思循环 574 # 反思循环
436 self._reflection_loop(i) 575 self._reflection_loop(i)
437 - 576 +
438 # 标记段落完成 577 # 标记段落完成
439 self.state.paragraphs[i].research.mark_completed() 578 self.state.paragraphs[i].research.mark_completed()
440 - 579 +
441 progress = (i + 1) / total_paragraphs * 100 580 progress = (i + 1) / total_paragraphs * 100
442 logger.info(f"段落处理完成 ({progress:.1f}%)") 581 logger.info(f"段落处理完成 ({progress:.1f}%)")
443 - 582 +
444 def _initial_search_and_summary(self, paragraph_index: int): 583 def _initial_search_and_summary(self, paragraph_index: int):
445 """执行初始搜索和总结""" 584 """执行初始搜索和总结"""
446 paragraph = self.state.paragraphs[paragraph_index] 585 paragraph = self.state.paragraphs[paragraph_index]
447 - 586 +
448 # 准备搜索输入 587 # 准备搜索输入
449 - search_input = {  
450 - "title": paragraph.title,  
451 - "content": paragraph.content  
452 - }  
453 - 588 + search_input = {"title": paragraph.title, "content": paragraph.content}
  589 +
454 # 生成搜索查询和工具选择 590 # 生成搜索查询和工具选择
455 logger.info(" - 生成搜索查询...") 591 logger.info(" - 生成搜索查询...")
456 search_output = self.first_search_node.run(search_input) 592 search_output = self.first_search_node.run(search_input)
457 search_query = search_output["search_query"] 593 search_query = search_output["search_query"]
458 - search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 594 + search_tool = search_output.get(
  595 + "search_tool", "search_topic_globally"
  596 + ) # 默认工具
459 reasoning = search_output["reasoning"] 597 reasoning = search_output["reasoning"]
460 - 598 +
461 logger.info(f" - 搜索查询: {search_query}") 599 logger.info(f" - 搜索查询: {search_query}")
462 logger.info(f" - 选择的工具: {search_tool}") 600 logger.info(f" - 选择的工具: {search_tool}")
463 logger.info(f" - 推理: {reasoning}") 601 logger.info(f" - 推理: {reasoning}")
464 - 602 +
465 # 执行搜索 603 # 执行搜索
466 logger.info(" - 执行数据库查询...") 604 logger.info(" - 执行数据库查询...")
467 - 605 +
468 # 处理特殊参数 606 # 处理特殊参数
469 search_kwargs = {} 607 search_kwargs = {}
470 - 608 +
471 # 处理需要日期的工具 609 # 处理需要日期的工具
472 if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: 610 if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
473 start_date = search_output.get("start_date") 611 start_date = search_output.get("start_date")
474 end_date = search_output.get("end_date") 612 end_date = search_output.get("end_date")
475 - 613 +
476 if start_date and end_date: 614 if start_date and end_date:
477 # 验证日期格式 615 # 验证日期格式
478 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 616 + if self._validate_date_format(
  617 + start_date
  618 + ) and self._validate_date_format(end_date):
479 search_kwargs["start_date"] = start_date 619 search_kwargs["start_date"] = start_date
480 search_kwargs["end_date"] = end_date 620 search_kwargs["end_date"] = end_date
481 logger.info(f" - 时间范围: {start_date} 到 {end_date}") 621 logger.info(f" - 时间范围: {start_date} 到 {end_date}")
482 else: 622 else:
483 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") 623 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
484 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 624 + logger.info(
  625 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  626 + )
485 search_tool = "search_topic_globally" 627 search_tool = "search_topic_globally"
486 elif search_tool == "search_topic_by_date": 628 elif search_tool == "search_topic_by_date":
487 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 629 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
488 search_tool = "search_topic_globally" 630 search_tool = "search_topic_globally"
489 - 631 +
490 # 处理需要平台参数的工具 632 # 处理需要平台参数的工具
491 if search_tool == "search_topic_on_platform": 633 if search_tool == "search_topic_on_platform":
492 platform = search_output.get("platform") 634 platform = search_output.get("platform")
@@ -494,9 +636,11 @@ class DeepSearchAgent: @@ -494,9 +636,11 @@ class DeepSearchAgent:
494 search_kwargs["platform"] = platform 636 search_kwargs["platform"] = platform
495 logger.info(f" - 指定平台: {platform}") 637 logger.info(f" - 指定平台: {platform}")
496 else: 638 else:
497 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 639 + logger.warning(
  640 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  641 + )
498 search_tool = "search_topic_globally" 642 search_tool = "search_topic_globally"
499 - 643 +
500 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 644 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
501 if search_tool == "search_hot_content": 645 if search_tool == "search_hot_content":
502 time_period = search_output.get("time_period", "week") 646 time_period = search_output.get("time_period", "week")
@@ -505,9 +649,13 @@ class DeepSearchAgent: @@ -505,9 +649,13 @@ class DeepSearchAgent:
505 search_kwargs["limit"] = limit 649 search_kwargs["limit"] = limit
506 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 650 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
507 if search_tool == "search_topic_globally": 651 if search_tool == "search_topic_globally":
508 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 652 + limit_per_table = (
  653 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  654 + )
509 else: # search_topic_by_date 655 else: # search_topic_by_date
510 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 656 + limit_per_table = (
  657 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  658 + )
511 search_kwargs["limit_per_table"] = limit_per_table 659 search_kwargs["limit_per_table"] = limit_per_table
512 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 660 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
513 if search_tool == "get_comments_for_topic": 661 if search_tool == "get_comments_for_topic":
@@ -515,43 +663,55 @@ class DeepSearchAgent: @@ -515,43 +663,55 @@ class DeepSearchAgent:
515 else: # search_topic_on_platform 663 else: # search_topic_on_platform
516 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 664 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
517 search_kwargs["limit"] = limit 665 search_kwargs["limit"] = limit
518 -  
519 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)  
520 - 666 +
  667 + search_response = self.execute_search_tool(
  668 + search_tool, search_query, **search_kwargs
  669 + )
  670 +
521 # 转换为兼容格式 671 # 转换为兼容格式
522 search_results = [] 672 search_results = []
523 if search_response and search_response.results: 673 if search_response and search_response.results:
524 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 674 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
525 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 675 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
526 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 676 + max_results = min(
  677 + len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM
  678 + )
527 else: 679 else:
528 max_results = len(search_response.results) # 不限制,传递所有结果 680 max_results = len(search_response.results) # 不限制,传递所有结果
529 for result in search_response.results[:max_results]: 681 for result in search_response.results[:max_results]:
530 - search_results.append({  
531 - 'title': result.title_or_content,  
532 - 'url': result.url or "",  
533 - 'content': result.title_or_content,  
534 - 'score': result.hotness_score,  
535 - 'raw_content': result.title_or_content,  
536 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
537 - 'platform': result.platform,  
538 - 'content_type': result.content_type,  
539 - 'author': result.author_nickname,  
540 - 'engagement': result.engagement  
541 - })  
542 - 682 + search_results.append(
  683 + {
  684 + "title": result.title_or_content,
  685 + "url": result.url or "",
  686 + "content": result.title_or_content,
  687 + "score": result.hotness_score,
  688 + "raw_content": result.title_or_content,
  689 + "published_date": result.publish_time.isoformat()
  690 + if result.publish_time
  691 + else None,
  692 + "platform": result.platform,
  693 + "content_type": result.content_type,
  694 + "author": result.author_nickname,
  695 + "engagement": result.engagement,
  696 + }
  697 + )
  698 +
543 if search_results: 699 if search_results:
544 _message = f" - 找到 {len(search_results)} 个搜索结果" 700 _message = f" - 找到 {len(search_results)} 个搜索结果"
545 for j, result in enumerate(search_results, 1): 701 for j, result in enumerate(search_results, 1):
546 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 702 + date_info = (
  703 + f" (发布于: {result.get('published_date', 'N/A')})"
  704 + if result.get("published_date")
  705 + else ""
  706 + )
547 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 707 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
548 logger.info(_message) 708 logger.info(_message)
549 else: 709 else:
550 logger.info(" - 未找到搜索结果") 710 logger.info(" - 未找到搜索结果")
551 - 711 +
552 # 更新状态中的搜索历史 712 # 更新状态中的搜索历史
553 paragraph.research.add_search_results(search_query, search_results) 713 paragraph.research.add_search_results(search_query, search_results)
554 - 714 +
555 # 生成初始总结 715 # 生成初始总结
556 logger.info(" - 生成初始总结...") 716 logger.info(" - 生成初始总结...")
557 summary_input = { 717 summary_input = {
@@ -560,63 +720,73 @@ class DeepSearchAgent: @@ -560,63 +720,73 @@ class DeepSearchAgent:
560 "search_query": search_query, 720 "search_query": search_query,
561 "search_results": format_search_results_for_prompt( 721 "search_results": format_search_results_for_prompt(
562 search_results, self.config.MAX_CONTENT_LENGTH 722 search_results, self.config.MAX_CONTENT_LENGTH
563 - ) 723 + ),
564 } 724 }
565 - 725 +
566 # 更新状态 726 # 更新状态
567 self.state = self.first_summary_node.mutate_state( 727 self.state = self.first_summary_node.mutate_state(
568 summary_input, self.state, paragraph_index 728 summary_input, self.state, paragraph_index
569 ) 729 )
570 - 730 +
571 logger.info(" - 初始总结完成") 731 logger.info(" - 初始总结完成")
572 - 732 +
573 def _reflection_loop(self, paragraph_index: int): 733 def _reflection_loop(self, paragraph_index: int):
574 """执行反思循环""" 734 """执行反思循环"""
575 paragraph = self.state.paragraphs[paragraph_index] 735 paragraph = self.state.paragraphs[paragraph_index]
576 - 736 +
577 for reflection_i in range(self.config.MAX_REFLECTIONS): 737 for reflection_i in range(self.config.MAX_REFLECTIONS):
578 logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") 738 logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")
579 - 739 +
580 # 准备反思输入 740 # 准备反思输入
581 reflection_input = { 741 reflection_input = {
582 "title": paragraph.title, 742 "title": paragraph.title,
583 "content": paragraph.content, 743 "content": paragraph.content,
584 - "paragraph_latest_state": paragraph.research.latest_summary 744 + "paragraph_latest_state": paragraph.research.latest_summary,
585 } 745 }
586 - 746 +
587 # 生成反思搜索查询 747 # 生成反思搜索查询
588 reflection_output = self.reflection_node.run(reflection_input) 748 reflection_output = self.reflection_node.run(reflection_input)
589 search_query = reflection_output["search_query"] 749 search_query = reflection_output["search_query"]
590 - search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 750 + search_tool = reflection_output.get(
  751 + "search_tool", "search_topic_globally"
  752 + ) # 默认工具
591 reasoning = reflection_output["reasoning"] 753 reasoning = reflection_output["reasoning"]
592 - 754 +
593 logger.info(f" 反思查询: {search_query}") 755 logger.info(f" 反思查询: {search_query}")
594 logger.info(f" 选择的工具: {search_tool}") 756 logger.info(f" 选择的工具: {search_tool}")
595 logger.info(f" 反思推理: {reasoning}") 757 logger.info(f" 反思推理: {reasoning}")
596 - 758 +
597 # 执行反思搜索 759 # 执行反思搜索
598 # 处理特殊参数 760 # 处理特殊参数
599 search_kwargs = {} 761 search_kwargs = {}
600 - 762 +
601 # 处理需要日期的工具 763 # 处理需要日期的工具
602 if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: 764 if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
603 start_date = reflection_output.get("start_date") 765 start_date = reflection_output.get("start_date")
604 end_date = reflection_output.get("end_date") 766 end_date = reflection_output.get("end_date")
605 - 767 +
606 if start_date and end_date: 768 if start_date and end_date:
607 # 验证日期格式 769 # 验证日期格式
608 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 770 + if self._validate_date_format(
  771 + start_date
  772 + ) and self._validate_date_format(end_date):
609 search_kwargs["start_date"] = start_date 773 search_kwargs["start_date"] = start_date
610 search_kwargs["end_date"] = end_date 774 search_kwargs["end_date"] = end_date
611 logger.info(f" 时间范围: {start_date} 到 {end_date}") 775 logger.info(f" 时间范围: {start_date} 到 {end_date}")
612 else: 776 else:
613 - logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")  
614 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 777 + logger.info(
  778 + f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索"
  779 + )
  780 + logger.info(
  781 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  782 + )
615 search_tool = "search_topic_globally" 783 search_tool = "search_topic_globally"
616 elif search_tool == "search_topic_by_date": 784 elif search_tool == "search_topic_by_date":
617 - logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 785 + logger.warning(
  786 + f" search_topic_by_date工具缺少时间参数,改用全局搜索"
  787 + )
618 search_tool = "search_topic_globally" 788 search_tool = "search_topic_globally"
619 - 789 +
620 # 处理需要平台参数的工具 790 # 处理需要平台参数的工具
621 if search_tool == "search_topic_on_platform": 791 if search_tool == "search_topic_on_platform":
622 platform = reflection_output.get("platform") 792 platform = reflection_output.get("platform")
@@ -624,9 +794,11 @@ class DeepSearchAgent: @@ -624,9 +794,11 @@ class DeepSearchAgent:
624 search_kwargs["platform"] = platform 794 search_kwargs["platform"] = platform
625 logger.info(f" 指定平台: {platform}") 795 logger.info(f" 指定平台: {platform}")
626 else: 796 else:
627 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 797 + logger.warning(
  798 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  799 + )
628 search_tool = "search_topic_globally" 800 search_tool = "search_topic_globally"
629 - 801 +
630 # 处理限制参数 802 # 处理限制参数
631 if search_tool == "search_hot_content": 803 if search_tool == "search_hot_content":
632 time_period = reflection_output.get("time_period", "week") 804 time_period = reflection_output.get("time_period", "week")
@@ -637,9 +809,13 @@ class DeepSearchAgent: @@ -637,9 +809,13 @@ class DeepSearchAgent:
637 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 809 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
638 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 810 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数
639 if search_tool == "search_topic_globally": 811 if search_tool == "search_topic_globally":
640 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 812 + limit_per_table = (
  813 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  814 + )
641 else: # search_topic_by_date 815 else: # search_topic_by_date
642 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 816 + limit_per_table = (
  817 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  818 + )
643 search_kwargs["limit_per_table"] = limit_per_table 819 search_kwargs["limit_per_table"] = limit_per_table
644 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 820 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
645 # 使用配置文件中的默认值,不允许agent控制limit参数 821 # 使用配置文件中的默认值,不允许agent控制limit参数
@@ -648,43 +824,56 @@ class DeepSearchAgent: @@ -648,43 +824,56 @@ class DeepSearchAgent:
648 else: # search_topic_on_platform 824 else: # search_topic_on_platform
649 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 825 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
650 search_kwargs["limit"] = limit 826 search_kwargs["limit"] = limit
651 -  
652 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs)  
653 - 827 +
  828 + search_response = self.execute_search_tool(
  829 + search_tool, search_query, **search_kwargs
  830 + )
  831 +
654 # 转换为兼容格式 832 # 转换为兼容格式
655 search_results = [] 833 search_results = []
656 if search_response and search_response.results: 834 if search_response and search_response.results:
657 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 835 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
658 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 836 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
659 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 837 + max_results = min(
  838 + len(search_response.results),
  839 + self.config.MAX_SEARCH_RESULTS_FOR_LLM,
  840 + )
660 else: 841 else:
661 max_results = len(search_response.results) # 不限制,传递所有结果 842 max_results = len(search_response.results) # 不限制,传递所有结果
662 for result in search_response.results[:max_results]: 843 for result in search_response.results[:max_results]:
663 - search_results.append({  
664 - 'title': result.title_or_content,  
665 - 'url': result.url or "",  
666 - 'content': result.title_or_content,  
667 - 'score': result.hotness_score,  
668 - 'raw_content': result.title_or_content,  
669 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
670 - 'platform': result.platform,  
671 - 'content_type': result.content_type,  
672 - 'author': result.author_nickname,  
673 - 'engagement': result.engagement  
674 - })  
675 - 844 + search_results.append(
  845 + {
  846 + "title": result.title_or_content,
  847 + "url": result.url or "",
  848 + "content": result.title_or_content,
  849 + "score": result.hotness_score,
  850 + "raw_content": result.title_or_content,
  851 + "published_date": result.publish_time.isoformat()
  852 + if result.publish_time
  853 + else None,
  854 + "platform": result.platform,
  855 + "content_type": result.content_type,
  856 + "author": result.author_nickname,
  857 + "engagement": result.engagement,
  858 + }
  859 + )
  860 +
676 if search_results: 861 if search_results:
677 _message = f" 找到 {len(search_results)} 个反思搜索结果" 862 _message = f" 找到 {len(search_results)} 个反思搜索结果"
678 for j, result in enumerate(search_results, 1): 863 for j, result in enumerate(search_results, 1):
679 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 864 + date_info = (
  865 + f" (发布于: {result.get('published_date', 'N/A')})"
  866 + if result.get("published_date")
  867 + else ""
  868 + )
680 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 869 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
681 logger.info(_message) 870 logger.info(_message)
682 else: 871 else:
683 logger.info(" 未找到反思搜索结果") 872 logger.info(" 未找到反思搜索结果")
684 - 873 +
685 # 更新搜索历史 874 # 更新搜索历史
686 paragraph.research.add_search_results(search_query, search_results) 875 paragraph.research.add_search_results(search_query, search_results)
687 - 876 +
688 # 生成反思总结 877 # 生成反思总结
689 reflection_summary_input = { 878 reflection_summary_input = {
690 "title": paragraph.title, 879 "title": paragraph.title,
@@ -693,28 +882,30 @@ class DeepSearchAgent: @@ -693,28 +882,30 @@ class DeepSearchAgent:
693 "search_results": format_search_results_for_prompt( 882 "search_results": format_search_results_for_prompt(
694 search_results, self.config.MAX_CONTENT_LENGTH 883 search_results, self.config.MAX_CONTENT_LENGTH
695 ), 884 ),
696 - "paragraph_latest_state": paragraph.research.latest_summary 885 + "paragraph_latest_state": paragraph.research.latest_summary,
697 } 886 }
698 - 887 +
699 # 更新状态 888 # 更新状态
700 self.state = self.reflection_summary_node.mutate_state( 889 self.state = self.reflection_summary_node.mutate_state(
701 reflection_summary_input, self.state, paragraph_index 890 reflection_summary_input, self.state, paragraph_index
702 ) 891 )
703 - 892 +
704 logger.info(f" 反思 {reflection_i + 1} 完成") 893 logger.info(f" 反思 {reflection_i + 1} 完成")
705 - 894 +
706 def _generate_final_report(self) -> str: 895 def _generate_final_report(self) -> str:
707 """生成最终报告""" 896 """生成最终报告"""
708 logger.info(f"\n[步骤 3] 生成最终报告...") 897 logger.info(f"\n[步骤 3] 生成最终报告...")
709 - 898 +
710 # 准备报告数据 899 # 准备报告数据
711 report_data = [] 900 report_data = []
712 for paragraph in self.state.paragraphs: 901 for paragraph in self.state.paragraphs:
713 - report_data.append({  
714 - "title": paragraph.title,  
715 - "paragraph_latest_state": paragraph.research.latest_summary  
716 - })  
717 - 902 + report_data.append(
  903 + {
  904 + "title": paragraph.title,
  905 + "paragraph_latest_state": paragraph.research.latest_summary,
  906 + }
  907 + )
  908 +
718 # 格式化报告 909 # 格式化报告
719 try: 910 try:
720 final_report = self.report_formatting_node.run(report_data) 911 final_report = self.report_formatting_node.run(report_data)
@@ -723,46 +914,48 @@ class DeepSearchAgent: @@ -723,46 +914,48 @@ class DeepSearchAgent:
723 final_report = self.report_formatting_node.format_report_manually( 914 final_report = self.report_formatting_node.format_report_manually(
724 report_data, self.state.report_title 915 report_data, self.state.report_title
725 ) 916 )
726 - 917 +
727 # 更新状态 918 # 更新状态
728 self.state.final_report = final_report 919 self.state.final_report = final_report
729 self.state.mark_completed() 920 self.state.mark_completed()
730 - 921 +
731 logger.info("最终报告生成完成") 922 logger.info("最终报告生成完成")
732 return final_report 923 return final_report
733 - 924 +
734 def _save_report(self, report_content: str): 925 def _save_report(self, report_content: str):
735 """保存报告到文件""" 926 """保存报告到文件"""
736 # 生成文件名 927 # 生成文件名
737 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 928 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
738 - query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()  
739 - query_safe = query_safe.replace(' ', '_')[:30]  
740 - 929 + query_safe = "".join(
  930 + c for c in self.state.query if c.isalnum() or c in (" ", "-", "_")
  931 + ).rstrip()
  932 + query_safe = query_safe.replace(" ", "_")[:30]
  933 +
741 filename = f"deep_search_report_{query_safe}_{timestamp}.md" 934 filename = f"deep_search_report_{query_safe}_{timestamp}.md"
742 filepath = os.path.join(self.config.OUTPUT_DIR, filename) 935 filepath = os.path.join(self.config.OUTPUT_DIR, filename)
743 - 936 +
744 # 保存报告 937 # 保存报告
745 - with open(filepath, 'w', encoding='utf-8') as f: 938 + with open(filepath, "w", encoding="utf-8") as f:
746 f.write(report_content) 939 f.write(report_content)
747 - 940 +
748 logger.info(f"报告已保存到: {filepath}") 941 logger.info(f"报告已保存到: {filepath}")
749 - 942 +
750 # 保存状态(如果配置允许) 943 # 保存状态(如果配置允许)
751 if self.config.SAVE_INTERMEDIATE_STATES: 944 if self.config.SAVE_INTERMEDIATE_STATES:
752 state_filename = f"state_{query_safe}_{timestamp}.json" 945 state_filename = f"state_{query_safe}_{timestamp}.json"
753 state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) 946 state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename)
754 self.state.save_to_file(state_filepath) 947 self.state.save_to_file(state_filepath)
755 logger.info(f"状态已保存到: {state_filepath}") 948 logger.info(f"状态已保存到: {state_filepath}")
756 - 949 +
757 def get_progress_summary(self) -> Dict[str, Any]: 950 def get_progress_summary(self) -> Dict[str, Any]:
758 """获取进度摘要""" 951 """获取进度摘要"""
759 return self.state.get_progress_summary() 952 return self.state.get_progress_summary()
760 - 953 +
761 def load_state(self, filepath: str): 954 def load_state(self, filepath: str):
762 """从文件加载状态""" 955 """从文件加载状态"""
763 self.state = State.load_from_file(filepath) 956 self.state = State.load_from_file(filepath)
764 logger.info(f"状态已从 {filepath} 加载") 957 logger.info(f"状态已从 {filepath} 加载")
765 - 958 +
    def save_state(self, filepath: str):
        """Persist the current agent state to *filepath* via ``State.save_to_file``."""
        self.state.save_to_file(filepath)
@@ -772,12 +965,12 @@ class DeepSearchAgent: @@ -772,12 +965,12 @@ class DeepSearchAgent:
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
    """Convenience factory for creating a DeepSearchAgent instance.

    Args:
        config_file: Path to a configuration file.
            NOTE(review): currently unused — a fresh ``Settings()`` is always
            constructed below; confirm whether this parameter should be honored.

    Returns:
        A DeepSearchAgent wired with a default Settings object.
    """
    config = Settings()  # initialize with an empty config, not from environment variables
    return DeepSearchAgent(config)