马一丁
Committed by GitHub

Merge pull request #350 from 666ghj/feature/insight_agent_cluster

feat(insight_agent): search results cluster
@@ -7,22 +7,35 @@ import json
7 import os 7 import os
8 import re 8 import re
9 from datetime import datetime 9 from datetime import datetime
10 -from typing import Optional, Dict, Any, List, Union 10 +from typing import Any, Dict, List, Optional, Union
  11 +
  12 +import numpy as np
11 from loguru import logger 13 from loguru import logger
  14 +from sentence_transformers import SentenceTransformer
  15 +from sklearn.cluster import KMeans
12 16
13 from .llms import LLMClient 17 from .llms import LLMClient
14 from .nodes import ( 18 from .nodes import (
15 - ReportStructureNode,  
16 FirstSearchNode, 19 FirstSearchNode,
17 - ReflectionNode,  
18 FirstSummaryNode, 20 FirstSummaryNode,
  21 + ReflectionNode,
19 ReflectionSummaryNode, 22 ReflectionSummaryNode,
20 - ReportFormattingNode 23 + ReportFormattingNode,
  24 + ReportStructureNode,
21 ) 25 )
22 from .state import State 26 from .state import State
23 -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer  
24 -from .utils.config import settings, Settings 27 +from .tools import (
  28 + DBResponse,
  29 + MediaCrawlerDB,
  30 + keyword_optimizer,
  31 + multilingual_sentiment_analyzer,
  32 +)
25 from .utils import format_search_results_for_prompt 33 from .utils import format_search_results_for_prompt
  34 +from .utils.config import Settings, settings
  35 +
  36 +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样
  37 +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数
  38 +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数
26 39
27 40
28 class DeepSearchAgent: 41 class DeepSearchAgent:
@@ -40,10 +53,12 @@ class DeepSearchAgent:
40 # 初始化LLM客户端 53 # 初始化LLM客户端
41 self.llm_client = self._initialize_llm() 54 self.llm_client = self._initialize_llm()
42 55
43 -  
44 # 初始化搜索工具集 56 # 初始化搜索工具集
45 self.search_agency = MediaCrawlerDB() 57 self.search_agency = MediaCrawlerDB()
46 58
  59 + # 初始化聚类小模型(懒加载)
  60 + self._clustering_model = None
  61 +
47 # 初始化情感分析器 62 # 初始化情感分析器
48 self.sentiment_analyzer = multilingual_sentiment_analyzer 63 self.sentiment_analyzer = multilingual_sentiment_analyzer
49 64
@@ -77,6 +92,15 @@ class DeepSearchAgent:
77 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) 92 self.reflection_summary_node = ReflectionSummaryNode(self.llm_client)
78 self.report_formatting_node = ReportFormattingNode(self.llm_client) 93 self.report_formatting_node = ReportFormattingNode(self.llm_client)
79 94
  95 + def _get_clustering_model(self):
  96 + """懒加载聚类模型"""
  97 + if self._clustering_model is None:
  98 + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...")
  99 + self._clustering_model = SentenceTransformer(
  100 + "paraphrase-multilingual-MiniLM-L12-v2"
  101 + )
  102 + return self._clustering_model
  103 +
80 def _validate_date_format(self, date_str: str) -> bool: 104 def _validate_date_format(self, date_str: str) -> bool:
81 """ 105 """
82 验证日期格式是否为YYYY-MM-DD 106 验证日期格式是否为YYYY-MM-DD
@@ -91,17 +115,78 @@ class DeepSearchAgent:
91 return False 115 return False
92 116
93 # 检查格式 117 # 检查格式
94 - pattern = r'^\d{4}-\d{2}-\d{2}$' 118 + pattern = r"^\d{4}-\d{2}-\d{2}$"
95 if not re.match(pattern, date_str): 119 if not re.match(pattern, date_str):
96 return False 120 return False
97 121
98 # 检查日期是否有效 122 # 检查日期是否有效
99 try: 123 try:
100 - datetime.strptime(date_str, '%Y-%m-%d') 124 + datetime.strptime(date_str, "%Y-%m-%d")
101 return True 125 return True
102 except ValueError: 126 except ValueError:
103 return False 127 return False
104 128
  129 + def _cluster_and_sample_results(
  130 + self,
  131 + results: List,
  132 + max_results: int = MAX_CLUSTERED_RESULTS,
  133 + results_per_cluster: int = RESULTS_PER_CLUSTER,
  134 + ) -> List:
  135 + """
  136 + 对搜索结果进行聚类并采样
  137 +
  138 + Args:
  139 + results: 搜索结果列表
  140 + max_results: 最大返回结果数
  141 + results_per_cluster: 每个聚类返回的结果数
  142 +
  143 + Returns:
  144 + 采样后的结果列表
  145 + """
  146 + if len(results) <= max_results:
  147 + return results
  148 +
  149 + try:
  150 + # 提取文本
  151 + texts = [r.title_or_content[:500] for r in results]
  152 +
  153 + # 获取模型并编码
  154 + model = self._get_clustering_model()
  155 + embeddings = model.encode(texts, show_progress_bar=False)
  156 +
  157 + # 计算聚类数
  158 + n_clusters = min(max(2, max_results // results_per_cluster), len(results))
  159 +
  160 + # KMeans聚类
  161 + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
  162 + labels = kmeans.fit_predict(embeddings)
  163 +
  164 + # 从每个聚类采样
  165 + sampled_results = []
  166 + for cluster_id in range(n_clusters):
  167 + cluster_indices = np.flatnonzero(labels == cluster_id)
  168 + cluster_results = [(results[i], i) for i in cluster_indices]
  169 + cluster_results.sort(
  170 + key=lambda x: x[0].hotness_score or 0, reverse=True
  171 + )
  172 +
  173 + for result, _ in cluster_results[:results_per_cluster]:
  174 + sampled_results.append(result)
  175 + if len(sampled_results) >= max_results:
  176 + break
  177 +
  178 + if len(sampled_results) >= max_results:
  179 + break
  180 +
  181 + logger.info(
  182 + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果"
  183 + )
  184 + return sampled_results
  185 +
  186 + except Exception as e:
  187 + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}")
  188 + return results[:max_results]
  189 +
105 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: 190 def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse:
106 """ 191 """
107 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) 192 执行指定的数据库查询工具(集成关键词优化中间件和情感分析)
@@ -127,7 +212,9 @@ class DeepSearchAgent:
127 if tool_name == "search_hot_content": 212 if tool_name == "search_hot_content":
128 time_period = kwargs.get("time_period", "week") 213 time_period = kwargs.get("time_period", "week")
129 limit = kwargs.get("limit", 100) 214 limit = kwargs.get("limit", 100)
130 - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) 215 + response = self.search_agency.search_hot_content(
  216 + time_period=time_period, limit=limit
  217 + )
131 218
132 # 检查是否需要进行情感分析 219 # 检查是否需要进行情感分析
133 enable_sentiment = kwargs.get("enable_sentiment", True) 220 enable_sentiment = kwargs.get("enable_sentiment", True)
@@ -151,17 +238,16 @@ class DeepSearchAgent:
151 tool_name="analyze_sentiment", 238 tool_name="analyze_sentiment",
152 parameters={ 239 parameters={
153 "texts": texts if isinstance(texts, list) else [texts], 240 "texts": texts if isinstance(texts, list) else [texts],
154 - **kwargs 241 + **kwargs,
155 }, 242 },
156 results=[], # 情感分析不返回搜索结果 243 results=[], # 情感分析不返回搜索结果
157 results_count=0, 244 results_count=0,
158 - metadata=sentiment_result 245 + metadata=sentiment_result,
159 ) 246 )
160 247
161 # 对于需要搜索词的工具,使用关键词优化中间件 248 # 对于需要搜索词的工具,使用关键词优化中间件
162 optimized_response = keyword_optimizer.optimize_keywords( 249 optimized_response = keyword_optimizer.optimize_keywords(
163 - original_query=query,  
164 - context=f"使用{tool_name}工具进行查询" 250 + original_query=query, context=f"使用{tool_name}工具进行查询"
165 ) 251 )
166 252
167 logger.info(f" 🔍 原始查询: '{query}'") 253 logger.info(f" 🔍 原始查询: '{query}'")
@@ -177,34 +263,62 @@ class DeepSearchAgent:
177 try: 263 try:
178 if tool_name == "search_topic_globally": 264 if tool_name == "search_topic_globally":
179 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 265 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
180 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE  
181 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) 266 + limit_per_table = (
  267 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  268 + )
  269 + response = self.search_agency.search_topic_globally(
  270 + topic=keyword, limit_per_table=limit_per_table
  271 + )
182 elif tool_name == "search_topic_by_date": 272 elif tool_name == "search_topic_by_date":
183 start_date = kwargs.get("start_date") 273 start_date = kwargs.get("start_date")
184 end_date = kwargs.get("end_date") 274 end_date = kwargs.get("end_date")
185 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 275 # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数
186 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 276 + limit_per_table = (
  277 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  278 + )
187 if not start_date or not end_date: 279 if not start_date or not end_date:
188 - raise ValueError("search_topic_by_date工具需要start_date和end_date参数")  
189 - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) 280 + raise ValueError(
  281 + "search_topic_by_date工具需要start_date和end_date参数"
  282 + )
  283 + response = self.search_agency.search_topic_by_date(
  284 + topic=keyword,
  285 + start_date=start_date,
  286 + end_date=end_date,
  287 + limit_per_table=limit_per_table,
  288 + )
190 elif tool_name == "get_comments_for_topic": 289 elif tool_name == "get_comments_for_topic":
191 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 290 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
192 - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) 291 + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(
  292 + optimized_response.optimized_keywords
  293 + )
193 limit = max(limit, 50) 294 limit = max(limit, 50)
194 - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) 295 + response = self.search_agency.get_comments_for_topic(
  296 + topic=keyword, limit=limit
  297 + )
195 elif tool_name == "search_topic_on_platform": 298 elif tool_name == "search_topic_on_platform":
196 platform = kwargs.get("platform") 299 platform = kwargs.get("platform")
197 start_date = kwargs.get("start_date") 300 start_date = kwargs.get("start_date")
198 end_date = kwargs.get("end_date") 301 end_date = kwargs.get("end_date")
199 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 302 # 使用配置文件中的默认值,按关键词数量分配,但保证最小值
200 - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) 303 + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(
  304 + optimized_response.optimized_keywords
  305 + )
201 limit = max(limit, 30) 306 limit = max(limit, 30)
202 if not platform: 307 if not platform:
203 raise ValueError("search_topic_on_platform工具需要platform参数") 308 raise ValueError("search_topic_on_platform工具需要platform参数")
204 - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) 309 + response = self.search_agency.search_topic_on_platform(
  310 + platform=platform,
  311 + topic=keyword,
  312 + start_date=start_date,
  313 + end_date=end_date,
  314 + limit=limit,
  315 + )
205 else: 316 else:
206 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") 317 logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索")
207 - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) 318 + response = self.search_agency.search_topic_globally(
  319 + topic=keyword,
  320 + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE,
  321 + )
208 322
209 # 收集结果 323 # 收集结果
210 if response.results: 324 if response.results:
@@ -222,6 +336,13 @@ class DeepSearchAgent:
222 unique_results = self._deduplicate_results(all_results) 336 unique_results = self._deduplicate_results(all_results)
223 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") 337 logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条")
224 338
  339 + if ENABLE_CLUSTERING:
  340 + unique_results = self._cluster_and_sample_results(
  341 + unique_results,
  342 + max_results=MAX_CLUSTERED_RESULTS,
  343 + results_per_cluster=RESULTS_PER_CLUSTER,
  344 + )
  345 +
225 # 构建整合后的响应 346 # 构建整合后的响应
226 integrated_response = DBResponse( 347 integrated_response = DBResponse(
227 tool_name=f"{tool_name}_optimized", 348 tool_name=f"{tool_name}_optimized",
@@ -229,10 +350,10 @@ class DeepSearchAgent:
229 "original_query": query, 350 "original_query": query,
230 "optimized_keywords": optimized_response.optimized_keywords, 351 "optimized_keywords": optimized_response.optimized_keywords,
231 "optimization_reasoning": optimized_response.reasoning, 352 "optimization_reasoning": optimized_response.reasoning,
232 - **kwargs 353 + **kwargs,
233 }, 354 },
234 results=unique_results, 355 results=unique_results,
235 - results_count=len(unique_results) 356 + results_count=len(unique_results),
236 ) 357 )
237 358
238 # 检查是否需要进行情感分析 359 # 检查是否需要进行情感分析
@@ -242,7 +363,9 @@ class DeepSearchAgent:
242 sentiment_analysis = self._perform_sentiment_analysis(unique_results) 363 sentiment_analysis = self._perform_sentiment_analysis(unique_results)
243 if sentiment_analysis: 364 if sentiment_analysis:
244 # 将情感分析结果添加到响应的parameters中 365 # 将情感分析结果添加到响应的parameters中
245 - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis 366 + integrated_response.parameters["sentiment_analysis"] = (
  367 + sentiment_analysis
  368 + )
246 logger.info(f" ✅ 情感分析完成") 369 logger.info(f" ✅ 情感分析完成")
247 370
248 return integrated_response 371 return integrated_response
@@ -275,7 +398,10 @@ class DeepSearchAgent:
275 """ 398 """
276 try: 399 try:
277 # 初始化情感分析器(如果尚未初始化且未被禁用) 400 # 初始化情感分析器(如果尚未初始化且未被禁用)
278 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 401 + if (
  402 + not self.sentiment_analyzer.is_initialized
  403 + and not self.sentiment_analyzer.is_disabled
  404 + ):
279 logger.info(" 初始化情感分析模型...") 405 logger.info(" 初始化情感分析模型...")
280 if not self.sentiment_analyzer.initialize(): 406 if not self.sentiment_analyzer.initialize():
281 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 407 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
@@ -290,15 +416,15 @@ class DeepSearchAgent:
290 "platform": result.platform, 416 "platform": result.platform,
291 "author": result.author_nickname, 417 "author": result.author_nickname,
292 "url": result.url, 418 "url": result.url,
293 - "publish_time": str(result.publish_time) if result.publish_time else None 419 + "publish_time": str(result.publish_time)
  420 + if result.publish_time
  421 + else None,
294 } 422 }
295 results_dict.append(result_dict) 423 results_dict.append(result_dict)
296 424
297 # 执行情感分析 425 # 执行情感分析
298 sentiment_analysis = self.sentiment_analyzer.analyze_query_results( 426 sentiment_analysis = self.sentiment_analyzer.analyze_query_results(
299 - query_results=results_dict,  
300 - text_field="content",  
301 - min_confidence=0.5 427 + query_results=results_dict, text_field="content", min_confidence=0.5
302 ) 428 )
303 429
304 return sentiment_analysis.get("sentiment_analysis") 430 return sentiment_analysis.get("sentiment_analysis")
@@ -321,7 +447,10 @@ class DeepSearchAgent:
321 447
322 try: 448 try:
323 # 初始化情感分析器(如果尚未初始化且未被禁用) 449 # 初始化情感分析器(如果尚未初始化且未被禁用)
324 - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: 450 + if (
  451 + not self.sentiment_analyzer.is_initialized
  452 + and not self.sentiment_analyzer.is_disabled
  453 + ):
325 logger.info(" 初始化情感分析模型...") 454 logger.info(" 初始化情感分析模型...")
326 if not self.sentiment_analyzer.initialize(): 455 if not self.sentiment_analyzer.initialize():
327 logger.info(" 情感分析模型初始化失败,将直接透传原始文本") 456 logger.info(" 情感分析模型初始化失败,将直接透传原始文本")
@@ -334,28 +463,43 @@ class DeepSearchAgent:
334 result_dict = result.__dict__ 463 result_dict = result.__dict__
335 response = { 464 response = {
336 "success": result.success and result.analysis_performed, 465 "success": result.success and result.analysis_performed,
337 - "total_analyzed": 1 if result.analysis_performed and result.success else 0,  
338 - "results": [result_dict] 466 + "total_analyzed": 1
  467 + if result.analysis_performed and result.success
  468 + else 0,
  469 + "results": [result_dict],
339 } 470 }
340 if not result.analysis_performed: 471 if not result.analysis_performed:
341 response["success"] = False 472 response["success"] = False
342 - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" 473 + response["warning"] = (
  474 + result.error_message or "情感分析功能不可用,已直接返回原始文本"
  475 + )
343 return response 476 return response
344 else: 477 else:
345 texts_list = list(texts) 478 texts_list = list(texts)
346 - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) 479 + batch_result = self.sentiment_analyzer.analyze_batch(
  480 + texts_list, show_progress=True
  481 + )
347 response = { 482 response = {
348 - "success": batch_result.analysis_performed and batch_result.success_count > 0,  
349 - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, 483 + "success": batch_result.analysis_performed
  484 + and batch_result.success_count > 0,
  485 + "total_analyzed": batch_result.total_processed
  486 + if batch_result.analysis_performed
  487 + else 0,
350 "success_count": batch_result.success_count, 488 "success_count": batch_result.success_count,
351 "failed_count": batch_result.failed_count, 489 "failed_count": batch_result.failed_count,
352 - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0,  
353 - "results": [result.__dict__ for result in batch_result.results] 490 + "average_confidence": batch_result.average_confidence
  491 + if batch_result.analysis_performed
  492 + else 0.0,
  493 + "results": [result.__dict__ for result in batch_result.results],
354 } 494 }
355 if not batch_result.analysis_performed: 495 if not batch_result.analysis_performed:
356 warning = next( 496 warning = next(
357 - (r.error_message for r in batch_result.results if r.error_message),  
358 - "情感分析功能不可用,已直接返回原始文本" 497 + (
  498 + r.error_message
  499 + for r in batch_result.results
  500 + if r.error_message
  501 + ),
  502 + "情感分析功能不可用,已直接返回原始文本",
359 ) 503 )
360 response["success"] = False 504 response["success"] = False
361 response["warning"] = warning 505 response["warning"] = warning
@@ -363,11 +507,7 @@ class DeepSearchAgent:
363 507
364 except Exception as e: 508 except Exception as e:
365 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") 509 logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}")
366 - return {  
367 - "success": False,  
368 - "error": str(e),  
369 - "results": []  
370 - } 510 + return {"success": False, "error": str(e), "results": []}
371 511
372 def research(self, query: str, save_report: bool = True) -> str: 512 def research(self, query: str, save_report: bool = True) -> str:
373 """ 513 """
@@ -380,9 +520,9 @@ class DeepSearchAgent:
380 Returns: 520 Returns:
381 最终报告内容 521 最终报告内容
382 """ 522 """
383 - logger.info(f"\n{'='*60}") 523 + logger.info(f"\n{'=' * 60}")
384 logger.info(f"开始深度研究: {query}") 524 logger.info(f"开始深度研究: {query}")
385 - logger.info(f"{'='*60}") 525 + logger.info(f"{'=' * 60}")
386 526
387 try: 527 try:
388 # Step 1: 生成报告结构 528 # Step 1: 生成报告结构
@@ -426,7 +566,9 @@ class DeepSearchAgent:
426 total_paragraphs = len(self.state.paragraphs) 566 total_paragraphs = len(self.state.paragraphs)
427 567
428 for i in range(total_paragraphs): 568 for i in range(total_paragraphs):
429 - logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") 569 + logger.info(
  570 + f"\n[步骤 2.{i + 1}] 处理段落: {self.state.paragraphs[i].title}"
  571 + )
430 logger.info("-" * 50) 572 logger.info("-" * 50)
431 573
432 # 初始搜索和总结 574 # 初始搜索和总结
@@ -446,16 +588,15 @@ class DeepSearchAgent:
446 paragraph = self.state.paragraphs[paragraph_index] 588 paragraph = self.state.paragraphs[paragraph_index]
447 589
448 # 准备搜索输入 590 # 准备搜索输入
449 - search_input = {  
450 - "title": paragraph.title,  
451 - "content": paragraph.content  
452 - } 591 + search_input = {"title": paragraph.title, "content": paragraph.content}
453 592
454 # 生成搜索查询和工具选择 593 # 生成搜索查询和工具选择
455 logger.info(" - 生成搜索查询...") 594 logger.info(" - 生成搜索查询...")
456 search_output = self.first_search_node.run(search_input) 595 search_output = self.first_search_node.run(search_input)
457 search_query = search_output["search_query"] 596 search_query = search_output["search_query"]
458 - search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 597 + search_tool = search_output.get(
  598 + "search_tool", "search_topic_globally"
  599 + ) # 默认工具
459 reasoning = search_output["reasoning"] 600 reasoning = search_output["reasoning"]
460 601
461 logger.info(f" - 搜索查询: {search_query}") 602 logger.info(f" - 搜索查询: {search_query}")
@@ -475,13 +616,17 @@ class DeepSearchAgent:
475 616
476 if start_date and end_date: 617 if start_date and end_date:
477 # 验证日期格式 618 # 验证日期格式
478 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 619 + if self._validate_date_format(
  620 + start_date
  621 + ) and self._validate_date_format(end_date):
479 search_kwargs["start_date"] = start_date 622 search_kwargs["start_date"] = start_date
480 search_kwargs["end_date"] = end_date 623 search_kwargs["end_date"] = end_date
481 logger.info(f" - 时间范围: {start_date} 到 {end_date}") 624 logger.info(f" - 时间范围: {start_date} 到 {end_date}")
482 else: 625 else:
483 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") 626 logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
484 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 627 + logger.info(
  628 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  629 + )
485 search_tool = "search_topic_globally" 630 search_tool = "search_topic_globally"
486 elif search_tool == "search_topic_by_date": 631 elif search_tool == "search_topic_by_date":
487 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 632 logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
@@ -494,7 +639,9 @@ class DeepSearchAgent:
494 search_kwargs["platform"] = platform 639 search_kwargs["platform"] = platform
495 logger.info(f" - 指定平台: {platform}") 640 logger.info(f" - 指定平台: {platform}")
496 else: 641 else:
497 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 642 + logger.warning(
  643 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  644 + )
498 search_tool = "search_topic_globally" 645 search_tool = "search_topic_globally"
499 646
500 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 647 # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数
@@ -505,9 +652,13 @@ class DeepSearchAgent:
505 search_kwargs["limit"] = limit 652 search_kwargs["limit"] = limit
506 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 653 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
507 if search_tool == "search_topic_globally": 654 if search_tool == "search_topic_globally":
508 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 655 + limit_per_table = (
  656 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  657 + )
509 else: # search_topic_by_date 658 else: # search_topic_by_date
510 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 659 + limit_per_table = (
  660 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  661 + )
511 search_kwargs["limit_per_table"] = limit_per_table 662 search_kwargs["limit_per_table"] = limit_per_table
512 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 663 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
513 if search_tool == "get_comments_for_topic": 664 if search_tool == "get_comments_for_topic":
@@ -516,34 +667,46 @@ class DeepSearchAgent:
516 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 667 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
517 search_kwargs["limit"] = limit 668 search_kwargs["limit"] = limit
518 669
519 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) 670 + search_response = self.execute_search_tool(
  671 + search_tool, search_query, **search_kwargs
  672 + )
520 673
521 # 转换为兼容格式 674 # 转换为兼容格式
522 search_results = [] 675 search_results = []
523 if search_response and search_response.results: 676 if search_response and search_response.results:
524 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 677 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
525 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 678 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
526 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 679 + max_results = min(
  680 + len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM
  681 + )
527 else: 682 else:
528 max_results = len(search_response.results) # 不限制,传递所有结果 683 max_results = len(search_response.results) # 不限制,传递所有结果
529 for result in search_response.results[:max_results]: 684 for result in search_response.results[:max_results]:
530 - search_results.append({  
531 - 'title': result.title_or_content,  
532 - 'url': result.url or "",  
533 - 'content': result.title_or_content,  
534 - 'score': result.hotness_score,  
535 - 'raw_content': result.title_or_content,  
536 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
537 - 'platform': result.platform,  
538 - 'content_type': result.content_type,  
539 - 'author': result.author_nickname,  
540 - 'engagement': result.engagement  
541 - }) 685 + search_results.append(
  686 + {
  687 + "title": result.title_or_content,
  688 + "url": result.url or "",
  689 + "content": result.title_or_content,
  690 + "score": result.hotness_score,
  691 + "raw_content": result.title_or_content,
  692 + "published_date": result.publish_time.isoformat()
  693 + if result.publish_time
  694 + else None,
  695 + "platform": result.platform,
  696 + "content_type": result.content_type,
  697 + "author": result.author_nickname,
  698 + "engagement": result.engagement,
  699 + }
  700 + )
542 701
543 if search_results: 702 if search_results:
544 _message = f" - 找到 {len(search_results)} 个搜索结果" 703 _message = f" - 找到 {len(search_results)} 个搜索结果"
545 for j, result in enumerate(search_results, 1): 704 for j, result in enumerate(search_results, 1):
546 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 705 + date_info = (
  706 + f" (发布于: {result.get('published_date', 'N/A')})"
  707 + if result.get("published_date")
  708 + else ""
  709 + )
547 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 710 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
548 logger.info(_message) 711 logger.info(_message)
549 else: 712 else:
@@ -560,7 +723,7 @@ class DeepSearchAgent:
560 "search_query": search_query, 723 "search_query": search_query,
561 "search_results": format_search_results_for_prompt( 724 "search_results": format_search_results_for_prompt(
562 search_results, self.config.MAX_CONTENT_LENGTH 725 search_results, self.config.MAX_CONTENT_LENGTH
563 - ) 726 + ),
564 } 727 }
565 728
566 # 更新状态 729 # 更新状态
@@ -581,13 +744,15 @@ class DeepSearchAgent:
581 reflection_input = { 744 reflection_input = {
582 "title": paragraph.title, 745 "title": paragraph.title,
583 "content": paragraph.content, 746 "content": paragraph.content,
584 - "paragraph_latest_state": paragraph.research.latest_summary 747 + "paragraph_latest_state": paragraph.research.latest_summary,
585 } 748 }
586 749
587 # 生成反思搜索查询 750 # 生成反思搜索查询
588 reflection_output = self.reflection_node.run(reflection_input) 751 reflection_output = self.reflection_node.run(reflection_input)
589 search_query = reflection_output["search_query"] 752 search_query = reflection_output["search_query"]
590 - search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 753 + search_tool = reflection_output.get(
  754 + "search_tool", "search_topic_globally"
  755 + ) # 默认工具
591 reasoning = reflection_output["reasoning"] 756 reasoning = reflection_output["reasoning"]
592 757
593 logger.info(f" 反思查询: {search_query}") 758 logger.info(f" 反思查询: {search_query}")
@@ -605,16 +770,24 @@ class DeepSearchAgent:
605 770
606 if start_date and end_date: 771 if start_date and end_date:
607 # 验证日期格式 772 # 验证日期格式
608 - if self._validate_date_format(start_date) and self._validate_date_format(end_date): 773 + if self._validate_date_format(
  774 + start_date
  775 + ) and self._validate_date_format(end_date):
609 search_kwargs["start_date"] = start_date 776 search_kwargs["start_date"] = start_date
610 search_kwargs["end_date"] = end_date 777 search_kwargs["end_date"] = end_date
611 logger.info(f" 时间范围: {start_date} 到 {end_date}") 778 logger.info(f" 时间范围: {start_date} 到 {end_date}")
612 else: 779 else:
613 - logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")  
614 - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") 780 + logger.info(
  781 + f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索"
  782 + )
  783 + logger.info(
  784 + f" 提供的日期: start_date={start_date}, end_date={end_date}"
  785 + )
615 search_tool = "search_topic_globally" 786 search_tool = "search_topic_globally"
616 elif search_tool == "search_topic_by_date": 787 elif search_tool == "search_topic_by_date":
617 - logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") 788 + logger.warning(
  789 + f" search_topic_by_date工具缺少时间参数,改用全局搜索"
  790 + )
618 search_tool = "search_topic_globally" 791 search_tool = "search_topic_globally"
619 792
620 # 处理需要平台参数的工具 793 # 处理需要平台参数的工具
@@ -624,7 +797,9 @@ class DeepSearchAgent:
624 search_kwargs["platform"] = platform 797 search_kwargs["platform"] = platform
625 logger.info(f" 指定平台: {platform}") 798 logger.info(f" 指定平台: {platform}")
626 else: 799 else:
627 - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") 800 + logger.warning(
  801 + f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
  802 + )
628 search_tool = "search_topic_globally" 803 search_tool = "search_topic_globally"
629 804
630 # 处理限制参数 805 # 处理限制参数
@@ -637,9 +812,13 @@ class DeepSearchAgent:
637 elif search_tool in ["search_topic_globally", "search_topic_by_date"]: 812 elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
638 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 813 # 使用配置文件中的默认值,不允许agent控制limit_per_table参数
639 if search_tool == "search_topic_globally": 814 if search_tool == "search_topic_globally":
640 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE 815 + limit_per_table = (
  816 + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
  817 + )
641 else: # search_topic_by_date 818 else: # search_topic_by_date
642 - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE 819 + limit_per_table = (
  820 + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
  821 + )
643 search_kwargs["limit_per_table"] = limit_per_table 822 search_kwargs["limit_per_table"] = limit_per_table
644 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: 823 elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
645 # 使用配置文件中的默认值,不允许agent控制limit参数 824 # 使用配置文件中的默认值,不允许agent控制limit参数
@@ -649,34 +828,47 @@ class DeepSearchAgent: @@ -649,34 +828,47 @@ class DeepSearchAgent:
649 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT 828 limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
650 search_kwargs["limit"] = limit 829 search_kwargs["limit"] = limit
651 830
652 - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) 831 + search_response = self.execute_search_tool(
  832 + search_tool, search_query, **search_kwargs
  833 + )
653 834
654 # 转换为兼容格式 835 # 转换为兼容格式
655 search_results = [] 836 search_results = []
656 if search_response and search_response.results: 837 if search_response and search_response.results:
657 # 使用配置文件控制传递给LLM的结果数量,0表示不限制 838 # 使用配置文件控制传递给LLM的结果数量,0表示不限制
658 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: 839 if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
659 - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) 840 + max_results = min(
  841 + len(search_response.results),
  842 + self.config.MAX_SEARCH_RESULTS_FOR_LLM,
  843 + )
660 else: 844 else:
661 max_results = len(search_response.results) # 不限制,传递所有结果 845 max_results = len(search_response.results) # 不限制,传递所有结果
662 for result in search_response.results[:max_results]: 846 for result in search_response.results[:max_results]:
663 - search_results.append({  
664 - 'title': result.title_or_content,  
665 - 'url': result.url or "",  
666 - 'content': result.title_or_content,  
667 - 'score': result.hotness_score,  
668 - 'raw_content': result.title_or_content,  
669 - 'published_date': result.publish_time.isoformat() if result.publish_time else None,  
670 - 'platform': result.platform,  
671 - 'content_type': result.content_type,  
672 - 'author': result.author_nickname,  
673 - 'engagement': result.engagement  
674 - }) 847 + search_results.append(
  848 + {
  849 + "title": result.title_or_content,
  850 + "url": result.url or "",
  851 + "content": result.title_or_content,
  852 + "score": result.hotness_score,
  853 + "raw_content": result.title_or_content,
  854 + "published_date": result.publish_time.isoformat()
  855 + if result.publish_time
  856 + else None,
  857 + "platform": result.platform,
  858 + "content_type": result.content_type,
  859 + "author": result.author_nickname,
  860 + "engagement": result.engagement,
  861 + }
  862 + )
675 863
676 if search_results: 864 if search_results:
677 _message = f" 找到 {len(search_results)} 个反思搜索结果" 865 _message = f" 找到 {len(search_results)} 个反思搜索结果"
678 for j, result in enumerate(search_results, 1): 866 for j, result in enumerate(search_results, 1):
679 - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" 867 + date_info = (
  868 + f" (发布于: {result.get('published_date', 'N/A')})"
  869 + if result.get("published_date")
  870 + else ""
  871 + )
680 _message += f"\n {j}. {result['title'][:50]}...{date_info}" 872 _message += f"\n {j}. {result['title'][:50]}...{date_info}"
681 logger.info(_message) 873 logger.info(_message)
682 else: 874 else:
@@ -693,7 +885,7 @@ class DeepSearchAgent: @@ -693,7 +885,7 @@ class DeepSearchAgent:
693 "search_results": format_search_results_for_prompt( 885 "search_results": format_search_results_for_prompt(
694 search_results, self.config.MAX_CONTENT_LENGTH 886 search_results, self.config.MAX_CONTENT_LENGTH
695 ), 887 ),
696 - "paragraph_latest_state": paragraph.research.latest_summary 888 + "paragraph_latest_state": paragraph.research.latest_summary,
697 } 889 }
698 890
699 # 更新状态 891 # 更新状态
@@ -710,10 +902,12 @@ class DeepSearchAgent: @@ -710,10 +902,12 @@ class DeepSearchAgent:
710 # 准备报告数据 902 # 准备报告数据
711 report_data = [] 903 report_data = []
712 for paragraph in self.state.paragraphs: 904 for paragraph in self.state.paragraphs:
713 - report_data.append({ 905 + report_data.append(
  906 + {
714 "title": paragraph.title, 907 "title": paragraph.title,
715 - "paragraph_latest_state": paragraph.research.latest_summary  
716 - }) 908 + "paragraph_latest_state": paragraph.research.latest_summary,
  909 + }
  910 + )
717 911
718 # 格式化报告 912 # 格式化报告
719 try: 913 try:
@@ -735,14 +929,16 @@ class DeepSearchAgent: @@ -735,14 +929,16 @@ class DeepSearchAgent:
735 """保存报告到文件""" 929 """保存报告到文件"""
736 # 生成文件名 930 # 生成文件名
737 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 931 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
738 - query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip()  
739 - query_safe = query_safe.replace(' ', '_')[:30] 932 + query_safe = "".join(
  933 + c for c in self.state.query if c.isalnum() or c in (" ", "-", "_")
  934 + ).rstrip()
  935 + query_safe = query_safe.replace(" ", "_")[:30]
740 936
741 filename = f"deep_search_report_{query_safe}_{timestamp}.md" 937 filename = f"deep_search_report_{query_safe}_{timestamp}.md"
742 filepath = os.path.join(self.config.OUTPUT_DIR, filename) 938 filepath = os.path.join(self.config.OUTPUT_DIR, filename)
743 939
744 # 保存报告 940 # 保存报告
745 - with open(filepath, 'w', encoding='utf-8') as f: 941 + with open(filepath, "w", encoding="utf-8") as f:
746 f.write(report_content) 942 f.write(report_content)
747 943
748 logger.info(f"报告已保存到: {filepath}") 944 logger.info(f"报告已保存到: {filepath}")
@@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13 @@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13
61 # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) ===== 61 # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) =====
62 torch>=2.0.0 # CPU版本 62 torch>=2.0.0 # CPU版本
63 transformers>=4.30.0 63 transformers>=4.30.0
  64 +sentence-transformers>=2.2.2
64 scikit-learn>=1.3.0 65 scikit-learn>=1.3.0
65 xgboost>=2.0.0 66 xgboost>=2.0.0
66 # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126 67 # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126