ghmark675

feat(insight_agent): cluster search results before summarization

@@ -7,22 +7,35 @@ import json @@ -7,22 +7,35 @@ import json
7 import os 7 import os
8 import re 8 import re
9 from datetime import datetime 9 from datetime import datetime
10 -from typing import Optional, Dict, Any, List, Union 10 +from typing import Any, Dict, List, Optional, Union
  11 +
  12 +import numpy as np
11 from loguru import logger 13 from loguru import logger
  14 +from sentence_transformers import SentenceTransformer
  15 +from sklearn.cluster import KMeans
12 16
13 from .llms import LLMClient 17 from .llms import LLMClient
14 from .nodes import ( 18 from .nodes import (
15 - ReportStructureNode,  
16 FirstSearchNode, 19 FirstSearchNode,
17 - ReflectionNode,  
18 FirstSummaryNode, 20 FirstSummaryNode,
  21 + ReflectionNode,
19 ReflectionSummaryNode, 22 ReflectionSummaryNode,
20 - ReportFormattingNode 23 + ReportFormattingNode,
  24 + ReportStructureNode,
21 ) 25 )
22 from .state import State 26 from .state import State
23 -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer  
24 -from .utils.config import settings, Settings 27 +from .tools import (
  28 + DBResponse,
  29 + MediaCrawlerDB,
  30 + keyword_optimizer,
  31 + multilingual_sentiment_analyzer,
  32 +)
25 from .utils import format_search_results_for_prompt 33 from .utils import format_search_results_for_prompt
  34 +from .utils.config import Settings, settings
  35 +
  36 +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样
  37 +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数
  38 +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数
26 39
27 40
28 class DeepSearchAgent: 41 class DeepSearchAgent:
@@ -40,10 +53,12 @@ class DeepSearchAgent: @@ -40,10 +53,12 @@ class DeepSearchAgent:
40 # 初始化LLM客户端 53 # 初始化LLM客户端
41 self.llm_client = self._initialize_llm() 54 self.llm_client = self._initialize_llm()
42 55
43 -  
44 # 初始化搜索工具集 56 # 初始化搜索工具集
45 self.search_agency = MediaCrawlerDB() 57 self.search_agency = MediaCrawlerDB()
46 58
  59 + # 初始化聚类小模型(懒加载)
  60 + self._clustering_model = None
  61 +
47 # 初始化情感分析器 62 # 初始化情感分析器
48 self.sentiment_analyzer = multilingual_sentiment_analyzer 63 self.sentiment_analyzer = multilingual_sentiment_analyzer
49 64
@@ -77,6 +92,15 @@ class DeepSearchAgent: @@ -77,6 +92,15 @@ class DeepSearchAgent:
77 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) 92 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
78 self.report_formatting_node = ReportFormattingNode(self.llm_client) 93 self.report_formatting_node = ReportFormattingNode(self.llm_client)
79 94
  95 + def _get_clustering_model(self):
  96 + """懒加载聚类模型"""
  97 + if self._clustering_model is None:
  98 + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...")
  99 + self._clustering_model = SentenceTransformer(
  100 + "paraphrase-multilingual-MiniLM-L12-v2"
  101 + )
  102 + return self._clustering_model
  103 +
80 def _validate_date_format(self, date_str: str) -> bool: 104 def _validate_date_format(self, date_str: str) -> bool:
81 """ 105 """
82 验证日期格式是否为YYYY-MM-DD 106 验证日期格式是否为YYYY-MM-DD
@@ -91,17 +115,75 @@ class DeepSearchAgent: @@ -91,17 +115,75 @@ class DeepSearchAgent:
91 return False 115 return False
92 116
93 # 检查格式 117 # 检查格式
94 - pattern = r'^\d{4}-\d{2}-\d{2}$' 118 + pattern = r"^\d{4}-\d{2}-\d{2}$"
95 if not re.match(pattern, date_str): 119 if not re.match(pattern, date_str):
96 return False 120 return False
97 121
98 # 检查日期是否有效 122 # 检查日期是否有效
99 try: 123 try:
100 - datetime.strptime(date_str, '%Y-%m-%d') 124 + datetime.strptime(date_str, "%Y-%m-%d")
101 return True 125 return True
102 except ValueError: 126 except ValueError:
103 return False 127 return False
104 128
  129 + def _cluster_and_sample_results(
  130 + self, results: List, max_results: int = 50, results_per_cluster: int = 5
  131 + ) -> List:
  132 + """
  133 + 对搜索结果进行聚类并采样
  134 +
  135 + Args:
  136 + results: 搜索结果列表
  137 + max_results: 最大返回结果数
  138 + results_per_cluster: 每个聚类返回的结果数
  139 +
  140 + Returns:
  141 + 采样后的结果列表
  142 + """
  143 + if len(results) <= max_results:
  144 + return results
  145 +
  146 + try:
  147 + # 提取文本
  148 + texts = [r.title_or_content[:500] for r in results]
  149 +
  150 + # 获取模型并编码
  151 + model = self._get_clustering_model()
  152 + embeddings = model.encode(texts, show_progress_bar=False)
  153 +
  154 + # 计算聚类数
  155 + n_clusters = min(max(2, max_results // results_per_cluster), len(results))
  156 +
  157 + # KMeans聚类
  158 + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
  159 + labels = kmeans.fit_predict(embeddings)
  160 +
  161 + # 从每个聚类采样
  162 + sampled_results = []
  163 + for cluster_id in range(n_clusters):
  164 + cluster_indices = np.where(labels == cluster_id)[0]
  165 + cluster_results = [(results[i], i) for i in cluster_indices]
  166 + cluster_results.sort(
  167 + key=lambda x: x[0].hotness_score or 0, reverse=True
  168 + )
  169 +
  170 + for result, _ in cluster_results[:results_per_cluster]:
  171 + sampled_results.append(result)
  172 + if len(sampled_results) >= max_results:
  173 + break
  174 +
  175 + if len(sampled_results) >= max_results:
  176 + break
  177 +
  178 + logger.info(
  179 + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果"
  180 + )
  181 + return sampled_results
  182 +
  183 + except Exception as e:
  184 + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}")
  185 + return results[:max_results]
  186 +
105 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: 187 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
106 """ 188 """
107 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) 189 执行指定的数据库查询工具(集成关键词优化中间件和情感分析)
@@ -127,7 +209,9 @@ class DeepSearchAgent: @@ -127,7 +209,9 @@ class DeepSearchAgent:
127 if tool_name == "search_hot_content": 209 if tool_name == "search_hot_content":
128 time_period = kwargs.get("time_period", "week") 210 time_period = kwargs.get("time_period", "week")
129 limit = kwargs.get("limit", 100) 211 limit = kwargs.get("limit", 100)
130 - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) 212 + response = self.search_agency.search_hot_content(
  213 + time_period=time_period, limit=limit
  214 + )
131 215
132 # 检查是否需要进行情感分析 216 # 检查是否需要进行情感分析
133 enable_sentiment = kwargs.get("enable_sentiment", True) 217 enable_sentiment = kwargs.get("enable_sentiment", True)
@@ -151,17 +235,16 @@ class DeepSearchAgent: @@ -151,17 +235,16 @@ class DeepSearchAgent:
151 tool_name="analyze_sentiment", 235 tool_name="analyze_sentiment",
152 parameters={ 236 parameters={
153 "texts": texts if isinstance(texts, list) else [texts], 237 "texts": texts if isinstance(texts, list) else [texts],
154 - **kwargs 238 + **kwargs,
155 }, 239 },
156 results=[], # 情感分析不返回搜索结果 240 results=[], # 情感分析不返回搜索结果
157 results_count=0, 241 results_count=0,
158 - metadata=sentiment_result 242 + metadata=sentiment_result,
159 ) 243 )
160 244
161 # 对于需要搜索词的工具,使用关键词优化中间件 245 # 对于需要搜索词的工具,使用关键词优化中间件
162 optimized_response = keyword_optimizer.optimize_keywords( 246 optimized_response = keyword_optimizer.optimize_keywords(
163 - original_query=query,  
164 - context=f"使用{tool_name}工具进行查询" 247 + original_query=query, context=f"使用{tool_name}工具进行查询"
165 ) 248 )
166 249
167 logger.info(f" 🔍 原始查询: '{query}'") 250 logger.info(f" 🔍 原始查询: '{query}'")
@@ -177,34 +260,62 @@ class DeepSearchAgent: @@ -177,34 +260,62 @@ class DeepSearchAgent:
177 try: 260 try:
178 if tool_name == "search_topic_globally": 261 if tool_name == "search_topic_globally":
179 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 262 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
180 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE  
181 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) 263 + limit_per_table = (
  264 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  265 + )
  266 + response = self.search_agency.search_topic_globally(
  267 + topic=keyword, limit_per_table=limit_per_table
  268 + )
182 elif tool_name == "search_topic_by_date": 269 elif tool_name == "search_topic_by_date":
183 start_date = kwargs.get("start_date") 270 start_date = kwargs.get("start_date")
184 end_date = kwargs.get("end_date") 271 end_date = kwargs.get("end_date")
185 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 272 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
186 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 273 + limit_per_table = (
  274 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  275 + )
187 if not start_date or not end_date: 276 if not start_date or not end_date:
188 - raise ValueError("search_topic_by_date工具需要start_date和end_date参数")  
189 - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) 277 + raise ValueError(
  278 + "search_topic_by_date工具需要start_date和end_date参数"
  279 + )
  280 + response = self.search_agency.search_topic_by_date(
  281 + topic=keyword,
  282 + start_date=start_date,
  283 + end_date=end_date,
  284 + limit_per_table=limit_per_table,
  285 + )
190 elif tool_name == "get_comments_for_topic": 286 elif tool_name == "get_comments_for_topic":
191 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 287 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
192 - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) 288 + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(
  289 + optimized_response.optimized_keywords
  290 + )
193 limit = max(limit, 50) 291 limit = max(limit, 50)
194 - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) 292 + response = self.search_agency.get_comments_for_topic(
  293 + topic=keyword, limit=limit
  294 + )
195 elif tool_name == "search_topic_on_platform": 295 elif tool_name == "search_topic_on_platform":
196 platform = kwargs.get("platform") 296 platform = kwargs.get("platform")
197 start_date = kwargs.get("start_date") 297 start_date = kwargs.get("start_date")
198 end_date = kwargs.get("end_date") 298 end_date = kwargs.get("end_date")
199 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 299 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
200 - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) 300 + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(
  301 + optimized_response.optimized_keywords
  302 + )
201 limit = max(limit, 30) 303 limit = max(limit, 30)
202 if not platform: 304 if not platform:
203 raise ValueError("search_topic_on_platform工具需要platform参数") 305 raise ValueError("search_topic_on_platform工具需要platform参数")
204 - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) 306 + response = self.search_agency.search_topic_on_platform(
  307 + platform=platform,
  308 + topic=keyword,
  309 + start_date=start_date,
  310 + end_date=end_date,
  311 + limit=limit,
  312 + )
205 else: 313 else:
206 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") 314 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
207 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) 315 + response = self.search_agency.search_topic_globally(
  316 + topic=keyword,
  317 + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE,
  318 + )
208 319
209 # 收集结果 320 # 收集结果
210 if response.results: 321 if response.results:
@@ -222,6 +333,13 @@ class DeepSearchAgent: @@ -222,6 +333,13 @@ class DeepSearchAgent:
222 unique_results = self._deduplicate_results(all_results) 333 unique_results = self._deduplicate_results(all_results)
223 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") 334 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
224 335
  336 + if ENABLE_CLUSTERING:
  337 + unique_results = self._cluster_and_sample_results(
  338 + unique_results,
  339 + max_results=MAX_CLUSTERED_RESULTS,
  340 + results_per_cluster=RESULTS_PER_CLUSTER,
  341 + )
  342 +
225 # 构建整合后的响应 343 # 构建整合后的响应
226 integrated_response = DBResponse( 344 integrated_response = DBResponse(
227 tool_name=f"{tool_name}_optimized", 345 tool_name=f"{tool_name}_optimized",
@@ -229,10 +347,10 @@ class DeepSearchAgent: @@ -229,10 +347,10 @@ class DeepSearchAgent:
229 "original_query": query, 347 "original_query": query,
230 "optimized_keywords": optimized_response.optimized_keywords, 348 "optimized_keywords": optimized_response.optimized_keywords,
231 "optimization_reasoning": optimized_response.reasoning, 349 "optimization_reasoning": optimized_response.reasoning,
232 - **kwargs 350 + **kwargs,
233 }, 351 },
234 results=unique_results, 352 results=unique_results,
235 - results_count=len(unique_results) 353 + results_count=len(unique_results),
236 ) 354 )
237 355
238 # 检查是否需要进行情感分析 356 # 检查是否需要进行情感分析
@@ -242,7 +360,9 @@ class DeepSearchAgent: @@ -242,7 +360,9 @@ class DeepSearchAgent:
242 sentiment_analysis = self._perform_sentiment_analysis(unique_results) 360 sentiment_analysis = self._perform_sentiment_analysis(unique_results)
243 if sentiment_analysis: 361 if sentiment_analysis:
244 # 将情感分析结果添加到响应的parameters中 362 # 将情感分析结果添加到响应的parameters中
245 - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis 363 + integrated_response.parameters["sentiment_analysis"] = (
  364 + sentiment_analysis
  365 + )
246 logger.info(f" ✅ 情感分析完成") 366 logger.info(f" ✅ 情感分析完成")
247 367
248 return integrated_response 368 return integrated_response
@@ -275,7 +395,10 @@ class DeepSearchAgent: @@ -275,7 +395,10 @@ class DeepSearchAgent:
275 """ 395 """
276 try: 396 try:
277 # 初始化情感分析器(如果尚未初始化且未被禁用) 397 # 初始化情感分析器(如果尚未初始化且未被禁用)
278 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 398 + if (
  399 + not self.sentiment_analyzer.is_initialized
  400 + and not self.sentiment_analyzer.is_disabled
  401 + ):
279 logger.info(" 初始化情感分析模型...") 402 logger.info(" 初始化情感分析模型...")
280 if not self.sentiment_analyzer.initialize(): 403 if not self.sentiment_analyzer.initialize():
281 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 404 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
@@ -290,15 +413,15 @@ class DeepSearchAgent: @@ -290,15 +413,15 @@ class DeepSearchAgent:
290 "platform": result.platform, 413 "platform": result.platform,
291 "author": result.author_nickname, 414 "author": result.author_nickname,
292 "url": result.url, 415 "url": result.url,
293 - "publish_time": str(result.publish_time) if result.publish_time else None 416 + "publish_time": str(result.publish_time)
  417 + if result.publish_time
  418 + else None,
294 } 419 }
295 results_dict.append(result_dict) 420 results_dict.append(result_dict)
296 421
297 # 执行情感分析 422 # 执行情感分析
298 sentiment_analysis = self.sentiment_analyzer.analyze_query_results( 423 sentiment_analysis = self.sentiment_analyzer.analyze_query_results(
299 - query_results=results_dict,  
300 - text_field="content",  
301 - min_confidence=0.5 424 + query_results=results_dict, text_field="content", min_confidence=0.5
302 ) 425 )
303 426
304 return sentiment_analysis.get("sentiment_analysis") 427 return sentiment_analysis.get("sentiment_analysis")
@@ -321,7 +444,10 @@ class DeepSearchAgent: @@ -321,7 +444,10 @@ class DeepSearchAgent:
321 444
322 try: 445 try:
323 # 初始化情感分析器(如果尚未初始化且未被禁用) 446 # 初始化情感分析器(如果尚未初始化且未被禁用)
324 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 447 + if (
  448 + not self.sentiment_analyzer.is_initialized
  449 + and not self.sentiment_analyzer.is_disabled
  450 + ):
325 logger.info(" 初始化情感分析模型...") 451 logger.info(" 初始化情感分析模型...")
326 if not self.sentiment_analyzer.initialize(): 452 if not self.sentiment_analyzer.initialize():
327 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 453 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
@@ -334,28 +460,43 @@ class DeepSearchAgent: @@ -334,28 +460,43 @@ class DeepSearchAgent:
334 result_dict = result.__dict__ 460 result_dict = result.__dict__
335 response = { 461 response = {
336 "success": result.success and result.analysis_performed, 462 "success": result.success and result.analysis_performed,
337 - "total_analyzed": 1 if result.analysis_performed and result.success else 0,  
338 - "results": [result_dict] 463 + "total_analyzed": 1
  464 + if result.analysis_performed and result.success
  465 + else 0,
  466 + "results": [result_dict],
339 } 467 }
340 if not result.analysis_performed: 468 if not result.analysis_performed:
341 response["success"] = False 469 response["success"] = False
342 - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" 470 + response["warning"] = (
  471 + result.error_message or "情感分析功能不可用,已直接返回原始文本"
  472 + )
343 return response 473 return response
344 else: 474 else:
345 texts_list = list(texts) 475 texts_list = list(texts)
346 - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) 476 + batch_result = self.sentiment_analyzer.analyze_batch(
  477 + texts_list, show_progress=True
  478 + )
347 response = { 479 response = {
348 - "success": batch_result.analysis_performed and batch_result.success_count > 0,  
349 - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, 480 + "success": batch_result.analysis_performed
  481 + and batch_result.success_count > 0,
  482 + "total_analyzed": batch_result.total_processed
  483 + if batch_result.analysis_performed
  484 + else 0,
350 "success_count": batch_result.success_count, 485 "success_count": batch_result.success_count,
351 "failed_count": batch_result.failed_count, 486 "failed_count": batch_result.failed_count,
352 - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0,  
353 - "results": [result.__dict__ for result in batch_result.results] 487 + "average_confidence": batch_result.average_confidence
  488 + if batch_result.analysis_performed
  489 + else 0.0,
  490 + "results": [result.__dict__ for result in batch_result.results],
354 } 491 }
355 if not batch_result.analysis_performed: 492 if not batch_result.analysis_performed:
356 warning = next( 493 warning = next(
357 - (r.error_message for r in batch_result.results if r.error_message),  
358 - "情感分析功能不可用,已直接返回原始文本" 494 + (
  495 + r.error_message
  496 + for r in batch_result.results
  497 + if r.error_message
  498 + ),
  499 + "情感分析功能不可用,已直接返回原始文本",
359 ) 500 )
360 response["success"] = False 501 response["success"] = False
361 response["warning"] = warning 502 response["warning"] = warning
@@ -363,11 +504,7 @@ class DeepSearchAgent: @@ -363,11 +504,7 @@ class DeepSearchAgent:
363 504
364 except Exception as e: 505 except Exception as e:
365 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") 506 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
366 - return {  
367 - "success": False,  
368 - "error": str(e),  
369 - "results": []  
370 - } 507 + return {"success": False, "error": str(e), "results": []}
371 508
372 def research(self, query: str, save_report: bool = True) -> str: 509 def research(self, query: str, save_report: bool = True) -> str:
373 """ 510 """
@@ -380,9 +517,9 @@ class DeepSearchAgent: @@ -380,9 +517,9 @@ class DeepSearchAgent:
380 Returns: 517 Returns:
381 最终报告内容 518 最终报告内容
382 """ 519 """
383 - logger.info(f"\n{'='*60}") 520 + logger.info(f"\n{'=' * 60}")
384 logger.info(f"开始深度研究: {query}") 521 logger.info(f"开始深度研究: {query}")
385 - logger.info(f"{'='*60}") 522 + logger.info(f"{'=' * 60}")
386 523
387 try: 524 try:
388 # Step 1: 生成报告结构 525 # Step 1: 生成报告结构
@@ -426,7 +563,9 @@ class DeepSearchAgent: @@ -426,7 +563,9 @@ class DeepSearchAgent:
426 total_paragraphs = len(self.state.paragraphs) 563 total_paragraphs = len(self.state.paragraphs)
427 564
428 for i in range(total_paragraphs): 565 for i in range(total_paragraphs):
429 - logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") 566 + logger.info(
  567 + f"\n[步骤 2.{i + 1}] 处理段落: {self.state.paragraphs[i].title}"
  568 + )
430 logger.info("-" * 50) 569 logger.info("-" * 50)
431 570
432 # 初始搜索和总结 571 # 初始搜索和总结
@@ -446,16 +585,15 @@ class DeepSearchAgent: @@ -446,16 +585,15 @@ class DeepSearchAgent:
446 paragraph = self.state.paragraphs[paragraph_index] 585 paragraph = self.state.paragraphs[paragraph_index]
447 586
448 # 准备搜索输入 587 # 准备搜索输入
449 - search_input = {  
450 - "title": paragraph.title,  
451 - "content": paragraph.content  
452 - } 588 + search_input = {"title": paragraph.title, "content": paragraph.content}
453 589
454 # 生成搜索查询和工具选择 590 # 生成搜索查询和工具选择
455 logger.info(" - 生成搜索查询...") 591 logger.info(" - 生成搜索查询...")
456 search_output = self.first_search_node.run(search_input) 592 search_output = self.first_search_node.run(search_input)
457 search_query = search_output["search_query"] 593 search_query = search_output["search_query"]
458 - search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 594 + search_tool = search_output.get(
  595 + "search_tool", "search_topic_globally"
  596 + ) # 默认工具
459 reasoning = search_output["reasoning"] 597 reasoning = search_output["reasoning"]
460 598
461 logger.info(f" - 搜索查询: {search_query}") 599 logger.info(f" - 搜索查询: {search_query}")
@@ -475,13 +613,17 @@ class DeepSearchAgent: @@ -475,13 +613,17 @@ class DeepSearchAgent:
475 613
476 if start_date and end_date: 614 if start_date and end_date:
477 # 验证日期格式 615 # 验证日期格式
478 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 616 + if self._validate_date_format(
  617 + start_date
  618 + ) and self._validate_date_format(end_date):
479 search_kwargs["start_date"] = start_date 619 search_kwargs["start_date"] = start_date
480 search_kwargs["end_date"] = end_date 620 search_kwargs["end_date"] = end_date
481 logger.info(f" - 时间范围: {start_date} 到 {end_date}") 621 logger.info(f" - 时间范围: {start_date} 到 {end_date}")
482 else: 622 else:
483 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") 623 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
484 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 624 + logger.info(
  625 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  626 + )
485 search_tool = "search_topic_globally" 627 search_tool = "search_topic_globally"
486 elif search_tool == "search_topic_by_date": 628 elif search_tool == "search_topic_by_date":
487 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 629 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
@@ -494,7 +636,9 @@ class DeepSearchAgent: @@ -494,7 +636,9 @@ class DeepSearchAgent:
494 search_kwargs["platform"] = platform 636 search_kwargs["platform"] = platform
495 logger.info(f" - 指定平台: {platform}") 637 logger.info(f" - 指定平台: {platform}")
496 else: 638 else:
497 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 639 + logger.warning(
  640 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  641 + )
498 search_tool = "search_topic_globally" 642 search_tool = "search_topic_globally"
499 643
500 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 644 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
@@ -505,9 +649,13 @@ class DeepSearchAgent: @@ -505,9 +649,13 @@ class DeepSearchAgent:
505 search_kwargs["limit"] = limit 649 search_kwargs["limit"] = limit
506 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 650 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
507 if search_tool == "search_topic_globally": 651 if search_tool == "search_topic_globally":
508 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 652 + limit_per_table = (
  653 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  654 + )
509 else: # search_topic_by_date 655 else: # search_topic_by_date
510 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 656 + limit_per_table = (
  657 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  658 + )
511 search_kwargs["limit_per_table"] = limit_per_table 659 search_kwargs["limit_per_table"] = limit_per_table
512 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 660 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
513 if search_tool == "get_comments_for_topic": 661 if search_tool == "get_comments_for_topic":
@@ -516,34 +664,46 @@ class DeepSearchAgent: @@ -516,34 +664,46 @@ class DeepSearchAgent:
516 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 664 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
517 search_kwargs["limit"] = limit 665 search_kwargs["limit"] = limit
518 666
519 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) 667 + search_response = self.execute_search_tool(
  668 + search_tool, search_query, **search_kwargs
  669 + )
520 670
521 # 转换为兼容格式 671 # 转换为兼容格式
522 search_results = [] 672 search_results = []
523 if search_response and search_response.results: 673 if search_response and search_response.results:
524 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 674 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
525 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 675 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
526 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 676 + max_results = min(
  677 + len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM
  678 + )
527 else: 679 else:
528 max_results = len(search_response.results) # 不限制,传递所有结果 680 max_results = len(search_response.results) # 不限制,传递所有结果
529 for result in search_response.results[:max_results]: 681 for result in search_response.results[:max_results]:
530 - search_results.append({  
531 - 'title': result.title_or_content,  
532 - 'url': result.url or "",  
533 - 'content': result.title_or_content,  
534 - 'score': result.hotness_score,  
535 - 'raw_content': result.title_or_content,  
536 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
537 - 'platform': result.platform,  
538 - 'content_type': result.content_type,  
539 - 'author': result.author_nickname,  
540 - 'engagement': result.engagement  
541 - }) 682 + search_results.append(
  683 + {
  684 + "title": result.title_or_content,
  685 + "url": result.url or "",
  686 + "content": result.title_or_content,
  687 + "score": result.hotness_score,
  688 + "raw_content": result.title_or_content,
  689 + "published_date": result.publish_time.isoformat()
  690 + if result.publish_time
  691 + else None,
  692 + "platform": result.platform,
  693 + "content_type": result.content_type,
  694 + "author": result.author_nickname,
  695 + "engagement": result.engagement,
  696 + }
  697 + )
542 698
543 if search_results: 699 if search_results:
544 _message = f" - 找到 {len(search_results)} 个搜索结果" 700 _message = f" - 找到 {len(search_results)} 个搜索结果"
545 for j, result in enumerate(search_results, 1): 701 for j, result in enumerate(search_results, 1):
546 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 702 + date_info = (
  703 + f" (发布于: {result.get('published_date', 'N/A')})"
  704 + if result.get("published_date")
  705 + else ""
  706 + )
547 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 707 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
548 logger.info(_message) 708 logger.info(_message)
549 else: 709 else:
@@ -560,7 +720,7 @@ class DeepSearchAgent: @@ -560,7 +720,7 @@ class DeepSearchAgent:
560 "search_query": search_query, 720 "search_query": search_query,
561 "search_results": format_search_results_for_prompt( 721 "search_results": format_search_results_for_prompt(
562 search_results, self.config.MAX_CONTENT_LENGTH 722 search_results, self.config.MAX_CONTENT_LENGTH
563 - ) 723 + ),
564 } 724 }
565 725
566 # 更新状态 726 # 更新状态
@@ -581,13 +741,15 @@ class DeepSearchAgent: @@ -581,13 +741,15 @@ class DeepSearchAgent:
581 reflection_input = { 741 reflection_input = {
582 "title": paragraph.title, 742 "title": paragraph.title,
583 "content": paragraph.content, 743 "content": paragraph.content,
584 - "paragraph_latest_state": paragraph.research.latest_summary 744 + "paragraph_latest_state": paragraph.research.latest_summary,
585 } 745 }
586 746
587 # 生成反思搜索查询 747 # 生成反思搜索查询
588 reflection_output = self.reflection_node.run(reflection_input) 748 reflection_output = self.reflection_node.run(reflection_input)
589 search_query = reflection_output["search_query"] 749 search_query = reflection_output["search_query"]
590 - search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 750 + search_tool = reflection_output.get(
  751 + "search_tool", "search_topic_globally"
  752 + ) # 默认工具
591 reasoning = reflection_output["reasoning"] 753 reasoning = reflection_output["reasoning"]
592 754
593 logger.info(f" 反思查询: {search_query}") 755 logger.info(f" 反思查询: {search_query}")
@@ -605,16 +767,24 @@ class DeepSearchAgent: @@ -605,16 +767,24 @@ class DeepSearchAgent:
605 767
606 if start_date and end_date: 768 if start_date and end_date:
607 # 验证日期格式 769 # 验证日期格式
608 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 770 + if self._validate_date_format(
  771 + start_date
  772 + ) and self._validate_date_format(end_date):
609 search_kwargs["start_date"] = start_date 773 search_kwargs["start_date"] = start_date
610 search_kwargs["end_date"] = end_date 774 search_kwargs["end_date"] = end_date
611 logger.info(f" 时间范围: {start_date} 到 {end_date}") 775 logger.info(f" 时间范围: {start_date} 到 {end_date}")
612 else: 776 else:
613 - logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")  
614 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 777 + logger.info(
  778 + f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索"
  779 + )
  780 + logger.info(
  781 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  782 + )
615 search_tool = "search_topic_globally" 783 search_tool = "search_topic_globally"
616 elif search_tool == "search_topic_by_date": 784 elif search_tool == "search_topic_by_date":
617 - logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 785 + logger.warning(
  786 + f" search_topic_by_date工具缺少时间参数,改用全局搜索"
  787 + )
618 search_tool = "search_topic_globally" 788 search_tool = "search_topic_globally"
619 789
620 # 处理需要平台参数的工具 790 # 处理需要平台参数的工具
@@ -624,7 +794,9 @@ class DeepSearchAgent: @@ -624,7 +794,9 @@ class DeepSearchAgent:
624 search_kwargs["platform"] = platform 794 search_kwargs["platform"] = platform
625 logger.info(f" 指定平台: {platform}") 795 logger.info(f" 指定平台: {platform}")
626 else: 796 else:
627 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 797 + logger.warning(
  798 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  799 + )
628 search_tool = "search_topic_globally" 800 search_tool = "search_topic_globally"
629 801
630 # 处理限制参数 802 # 处理限制参数
@@ -637,9 +809,13 @@ class DeepSearchAgent: @@ -637,9 +809,13 @@ class DeepSearchAgent:
637 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 809 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
638 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 810 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数
639 if search_tool == "search_topic_globally": 811 if search_tool == "search_topic_globally":
640 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 812 + limit_per_table = (
  813 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  814 + )
641 else: # search_topic_by_date 815 else: # search_topic_by_date
642 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 816 + limit_per_table = (
  817 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  818 + )
643 search_kwargs["limit_per_table"] = limit_per_table 819 search_kwargs["limit_per_table"] = limit_per_table
644 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 820 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
645 # 使用配置文件中的默认值,不允许agent控制limit参数 821 # 使用配置文件中的默认值,不允许agent控制limit参数
@@ -649,34 +825,47 @@ class DeepSearchAgent: @@ -649,34 +825,47 @@ class DeepSearchAgent:
649 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 825 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
650 search_kwargs["limit"] = limit 826 search_kwargs["limit"] = limit
651 827
652 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) 828 + search_response = self.execute_search_tool(
  829 + search_tool, search_query, **search_kwargs
  830 + )
653 831
654 # 转换为兼容格式 832 # 转换为兼容格式
655 search_results = [] 833 search_results = []
656 if search_response and search_response.results: 834 if search_response and search_response.results:
657 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 835 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
658 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 836 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
659 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 837 + max_results = min(
  838 + len(search_response.results),
  839 + self.config.MAX_SEARCH_RESULTS_FOR_LLM,
  840 + )
660 else: 841 else:
661 max_results = len(search_response.results) # 不限制,传递所有结果 842 max_results = len(search_response.results) # 不限制,传递所有结果
662 for result in search_response.results[:max_results]: 843 for result in search_response.results[:max_results]:
663 - search_results.append({  
664 - 'title': result.title_or_content,  
665 - 'url': result.url or "",  
666 - 'content': result.title_or_content,  
667 - 'score': result.hotness_score,  
668 - 'raw_content': result.title_or_content,  
669 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
670 - 'platform': result.platform,  
671 - 'content_type': result.content_type,  
672 - 'author': result.author_nickname,  
673 - 'engagement': result.engagement  
674 - }) 844 + search_results.append(
  845 + {
  846 + "title": result.title_or_content,
  847 + "url": result.url or "",
  848 + "content": result.title_or_content,
  849 + "score": result.hotness_score,
  850 + "raw_content": result.title_or_content,
  851 + "published_date": result.publish_time.isoformat()
  852 + if result.publish_time
  853 + else None,
  854 + "platform": result.platform,
  855 + "content_type": result.content_type,
  856 + "author": result.author_nickname,
  857 + "engagement": result.engagement,
  858 + }
  859 + )
675 860
676 if search_results: 861 if search_results:
677 _message = f" 找到 {len(search_results)} 个反思搜索结果" 862 _message = f" 找到 {len(search_results)} 个反思搜索结果"
678 for j, result in enumerate(search_results, 1): 863 for j, result in enumerate(search_results, 1):
679 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 864 + date_info = (
  865 + f" (发布于: {result.get('published_date', 'N/A')})"
  866 + if result.get("published_date")
  867 + else ""
  868 + )
680 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 869 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
681 logger.info(_message) 870 logger.info(_message)
682 else: 871 else:
@@ -693,7 +882,7 @@ class DeepSearchAgent: @@ -693,7 +882,7 @@ class DeepSearchAgent:
693 "search_results": format_search_results_for_prompt( 882 "search_results": format_search_results_for_prompt(
694 search_results, self.config.MAX_CONTENT_LENGTH 883 search_results, self.config.MAX_CONTENT_LENGTH
695 ), 884 ),
696 - "paragraph_latest_state": paragraph.research.latest_summary 885 + "paragraph_latest_state": paragraph.research.latest_summary,
697 } 886 }
698 887
699 # 更新状态 888 # 更新状态
@@ -710,10 +899,12 @@ class DeepSearchAgent: @@ -710,10 +899,12 @@ class DeepSearchAgent:
710 # 准备报告数据 899 # 准备报告数据
711 report_data = [] 900 report_data = []
712 for paragraph in self.state.paragraphs: 901 for paragraph in self.state.paragraphs:
713 - report_data.append({ 902 + report_data.append(
  903 + {
714 "title": paragraph.title, 904 "title": paragraph.title,
715 - "paragraph_latest_state": paragraph.research.latest_summary  
716 - }) 905 + "paragraph_latest_state": paragraph.research.latest_summary,
  906 + }
  907 + )
717 908
718 # 格式化报告 909 # 格式化报告
719 try: 910 try:
@@ -735,14 +926,16 @@ class DeepSearchAgent: @@ -735,14 +926,16 @@ class DeepSearchAgent:
735 """保存报告到文件""" 926 """保存报告到文件"""
736 # 生成文件名 927 # 生成文件名
737 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 928 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
738 - query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()  
739 - query_safe = query_safe.replace(' ', '_')[:30] 929 + query_safe = "".join(
  930 + c for c in self.state.query if c.isalnum() or c in (" ", "-", "_")
  931 + ).rstrip()
  932 + query_safe = query_safe.replace(" ", "_")[:30]
740 933
741 filename = f"deep_search_report_{query_safe}_{timestamp}.md" 934 filename = f"deep_search_report_{query_safe}_{timestamp}.md"
742 filepath = os.path.join(self.config.OUTPUT_DIR, filename) 935 filepath = os.path.join(self.config.OUTPUT_DIR, filename)
743 936
744 # 保存报告 937 # 保存报告
745 - with open(filepath, 'w', encoding='utf-8') as f: 938 + with open(filepath, "w", encoding="utf-8") as f:
746 f.write(report_content) 939 f.write(report_content)
747 940
748 logger.info(f"报告已保存到: {filepath}") 941 logger.info(f"报告已保存到: {filepath}")