Committed by
GitHub
Merge pull request #350 from 666ghj/feature/insight_agent_cluster
feat(insight_agent): search results cluster
Showing
2 changed files
with
305 additions
and
108 deletions
| @@ -7,22 +7,35 @@ import json | @@ -7,22 +7,35 @@ import json | ||
| 7 | import os | 7 | import os |
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | -from typing import Optional, Dict, Any, List, Union | 10 | +from typing import Any, Dict, List, Optional, Union |
| 11 | + | ||
| 12 | +import numpy as np | ||
| 11 | from loguru import logger | 13 | from loguru import logger |
| 14 | +from sentence_transformers import SentenceTransformer | ||
| 15 | +from sklearn.cluster import KMeans | ||
| 12 | 16 | ||
| 13 | from .llms import LLMClient | 17 | from .llms import LLMClient |
| 14 | from .nodes import ( | 18 | from .nodes import ( |
| 15 | - ReportStructureNode, | ||
| 16 | FirstSearchNode, | 19 | FirstSearchNode, |
| 17 | - ReflectionNode, | ||
| 18 | FirstSummaryNode, | 20 | FirstSummaryNode, |
| 21 | + ReflectionNode, | ||
| 19 | ReflectionSummaryNode, | 22 | ReflectionSummaryNode, |
| 20 | - ReportFormattingNode | 23 | + ReportFormattingNode, |
| 24 | + ReportStructureNode, | ||
| 21 | ) | 25 | ) |
| 22 | from .state import State | 26 | from .state import State |
| 23 | -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer | ||
| 24 | -from .utils.config import settings, Settings | 27 | +from .tools import ( |
| 28 | + DBResponse, | ||
| 29 | + MediaCrawlerDB, | ||
| 30 | + keyword_optimizer, | ||
| 31 | + multilingual_sentiment_analyzer, | ||
| 32 | +) | ||
| 25 | from .utils import format_search_results_for_prompt | 33 | from .utils import format_search_results_for_prompt |
| 34 | +from .utils.config import Settings, settings | ||
| 35 | + | ||
| 36 | +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样 | ||
| 37 | +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数 | ||
| 38 | +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数 | ||
| 26 | 39 | ||
| 27 | 40 | ||
| 28 | class DeepSearchAgent: | 41 | class DeepSearchAgent: |
| @@ -40,10 +53,12 @@ class DeepSearchAgent: | @@ -40,10 +53,12 @@ class DeepSearchAgent: | ||
| 40 | # 初始化LLM客户端 | 53 | # 初始化LLM客户端 |
| 41 | self.llm_client = self._initialize_llm() | 54 | self.llm_client = self._initialize_llm() |
| 42 | 55 | ||
| 43 | - | ||
| 44 | # 初始化搜索工具集 | 56 | # 初始化搜索工具集 |
| 45 | self.search_agency = MediaCrawlerDB() | 57 | self.search_agency = MediaCrawlerDB() |
| 46 | 58 | ||
| 59 | + # 初始化聚类小模型(懒加载) | ||
| 60 | + self._clustering_model = None | ||
| 61 | + | ||
| 47 | # 初始化情感分析器 | 62 | # 初始化情感分析器 |
| 48 | self.sentiment_analyzer = multilingual_sentiment_analyzer | 63 | self.sentiment_analyzer = multilingual_sentiment_analyzer |
| 49 | 64 | ||
| @@ -77,6 +92,15 @@ class DeepSearchAgent: | @@ -77,6 +92,15 @@ class DeepSearchAgent: | ||
| 77 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) | 92 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) |
| 78 | self.report_formatting_node = ReportFormattingNode(self.llm_client) | 93 | self.report_formatting_node = ReportFormattingNode(self.llm_client) |
| 79 | 94 | ||
| 95 | + def _get_clustering_model(self): | ||
| 96 | + """懒加载聚类模型""" | ||
| 97 | + if self._clustering_model is None: | ||
| 98 | + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...") | ||
| 99 | + self._clustering_model = SentenceTransformer( | ||
| 100 | + "paraphrase-multilingual-MiniLM-L12-v2" | ||
| 101 | + ) | ||
| 102 | + return self._clustering_model | ||
| 103 | + | ||
| 80 | def _validate_date_format(self, date_str: str) -> bool: | 104 | def _validate_date_format(self, date_str: str) -> bool: |
| 81 | """ | 105 | """ |
| 82 | 验证日期格式是否为YYYY-MM-DD | 106 | 验证日期格式是否为YYYY-MM-DD |
| @@ -91,17 +115,78 @@ class DeepSearchAgent: | @@ -91,17 +115,78 @@ class DeepSearchAgent: | ||
| 91 | return False | 115 | return False |
| 92 | 116 | ||
| 93 | # 检查格式 | 117 | # 检查格式 |
| 94 | - pattern = r'^\d{4}-\d{2}-\d{2}$' | 118 | + pattern = r"^\d{4}-\d{2}-\d{2}$" |
| 95 | if not re.match(pattern, date_str): | 119 | if not re.match(pattern, date_str): |
| 96 | return False | 120 | return False |
| 97 | 121 | ||
| 98 | # 检查日期是否有效 | 122 | # 检查日期是否有效 |
| 99 | try: | 123 | try: |
| 100 | - datetime.strptime(date_str, '%Y-%m-%d') | 124 | + datetime.strptime(date_str, "%Y-%m-%d") |
| 101 | return True | 125 | return True |
| 102 | except ValueError: | 126 | except ValueError: |
| 103 | return False | 127 | return False |
| 104 | 128 | ||
| 129 | + def _cluster_and_sample_results( | ||
| 130 | + self, | ||
| 131 | + results: List, | ||
| 132 | + max_results: int = MAX_CLUSTERED_RESULTS, | ||
| 133 | + results_per_cluster: int = RESULTS_PER_CLUSTER, | ||
| 134 | + ) -> List: | ||
| 135 | + """ | ||
| 136 | + 对搜索结果进行聚类并采样 | ||
| 137 | + | ||
| 138 | + Args: | ||
| 139 | + results: 搜索结果列表 | ||
| 140 | + max_results: 最大返回结果数 | ||
| 141 | + results_per_cluster: 每个聚类返回的结果数 | ||
| 142 | + | ||
| 143 | + Returns: | ||
| 144 | + 采样后的结果列表 | ||
| 145 | + """ | ||
| 146 | + if len(results) <= max_results: | ||
| 147 | + return results | ||
| 148 | + | ||
| 149 | + try: | ||
| 150 | + # 提取文本 | ||
| 151 | + texts = [r.title_or_content[:500] for r in results] | ||
| 152 | + | ||
| 153 | + # 获取模型并编码 | ||
| 154 | + model = self._get_clustering_model() | ||
| 155 | + embeddings = model.encode(texts, show_progress_bar=False) | ||
| 156 | + | ||
| 157 | + # 计算聚类数 | ||
| 158 | + n_clusters = min(max(2, max_results // results_per_cluster), len(results)) | ||
| 159 | + | ||
| 160 | + # KMeans聚类 | ||
| 161 | + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) | ||
| 162 | + labels = kmeans.fit_predict(embeddings) | ||
| 163 | + | ||
| 164 | + # 从每个聚类采样 | ||
| 165 | + sampled_results = [] | ||
| 166 | + for cluster_id in range(n_clusters): | ||
| 167 | + cluster_indices = np.flatnonzero(labels == cluster_id) | ||
| 168 | + cluster_results = [(results[i], i) for i in cluster_indices] | ||
| 169 | + cluster_results.sort( | ||
| 170 | + key=lambda x: x[0].hotness_score or 0, reverse=True | ||
| 171 | + ) | ||
| 172 | + | ||
| 173 | + for result, _ in cluster_results[:results_per_cluster]: | ||
| 174 | + sampled_results.append(result) | ||
| 175 | + if len(sampled_results) >= max_results: | ||
| 176 | + break | ||
| 177 | + | ||
| 178 | + if len(sampled_results) >= max_results: | ||
| 179 | + break | ||
| 180 | + | ||
| 181 | + logger.info( | ||
| 182 | + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果" | ||
| 183 | + ) | ||
| 184 | + return sampled_results | ||
| 185 | + | ||
| 186 | + except Exception as e: | ||
| 187 | + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}") | ||
| 188 | + return results[:max_results] | ||
| 189 | + | ||
| 105 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: | 190 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: |
| 106 | """ | 191 | """ |
| 107 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) | 192 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) |
| @@ -127,7 +212,9 @@ class DeepSearchAgent: | @@ -127,7 +212,9 @@ class DeepSearchAgent: | ||
| 127 | if tool_name == "search_hot_content": | 212 | if tool_name == "search_hot_content": |
| 128 | time_period = kwargs.get("time_period", "week") | 213 | time_period = kwargs.get("time_period", "week") |
| 129 | limit = kwargs.get("limit", 100) | 214 | limit = kwargs.get("limit", 100) |
| 130 | - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) | 215 | + response = self.search_agency.search_hot_content( |
| 216 | + time_period=time_period, limit=limit | ||
| 217 | + ) | ||
| 131 | 218 | ||
| 132 | # 检查是否需要进行情感分析 | 219 | # 检查是否需要进行情感分析 |
| 133 | enable_sentiment = kwargs.get("enable_sentiment", True) | 220 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| @@ -151,17 +238,16 @@ class DeepSearchAgent: | @@ -151,17 +238,16 @@ class DeepSearchAgent: | ||
| 151 | tool_name="analyze_sentiment", | 238 | tool_name="analyze_sentiment", |
| 152 | parameters={ | 239 | parameters={ |
| 153 | "texts": texts if isinstance(texts, list) else [texts], | 240 | "texts": texts if isinstance(texts, list) else [texts], |
| 154 | - **kwargs | 241 | + **kwargs, |
| 155 | }, | 242 | }, |
| 156 | results=[], # 情感分析不返回搜索结果 | 243 | results=[], # 情感分析不返回搜索结果 |
| 157 | results_count=0, | 244 | results_count=0, |
| 158 | - metadata=sentiment_result | 245 | + metadata=sentiment_result, |
| 159 | ) | 246 | ) |
| 160 | 247 | ||
| 161 | # 对于需要搜索词的工具,使用关键词优化中间件 | 248 | # 对于需要搜索词的工具,使用关键词优化中间件 |
| 162 | optimized_response = keyword_optimizer.optimize_keywords( | 249 | optimized_response = keyword_optimizer.optimize_keywords( |
| 163 | - original_query=query, | ||
| 164 | - context=f"使用{tool_name}工具进行查询" | 250 | + original_query=query, context=f"使用{tool_name}工具进行查询" |
| 165 | ) | 251 | ) |
| 166 | 252 | ||
| 167 | logger.info(f" 🔍 原始查询: '{query}'") | 253 | logger.info(f" 🔍 原始查询: '{query}'") |
| @@ -177,34 +263,62 @@ class DeepSearchAgent: | @@ -177,34 +263,62 @@ class DeepSearchAgent: | ||
| 177 | try: | 263 | try: |
| 178 | if tool_name == "search_topic_globally": | 264 | if tool_name == "search_topic_globally": |
| 179 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 265 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 180 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 181 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | 266 | + limit_per_table = ( |
| 267 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 268 | + ) | ||
| 269 | + response = self.search_agency.search_topic_globally( | ||
| 270 | + topic=keyword, limit_per_table=limit_per_table | ||
| 271 | + ) | ||
| 182 | elif tool_name == "search_topic_by_date": | 272 | elif tool_name == "search_topic_by_date": |
| 183 | start_date = kwargs.get("start_date") | 273 | start_date = kwargs.get("start_date") |
| 184 | end_date = kwargs.get("end_date") | 274 | end_date = kwargs.get("end_date") |
| 185 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 275 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 186 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 276 | + limit_per_table = ( |
| 277 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 278 | + ) | ||
| 187 | if not start_date or not end_date: | 279 | if not start_date or not end_date: |
| 188 | - raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | ||
| 189 | - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | 280 | + raise ValueError( |
| 281 | + "search_topic_by_date工具需要start_date和end_date参数" | ||
| 282 | + ) | ||
| 283 | + response = self.search_agency.search_topic_by_date( | ||
| 284 | + topic=keyword, | ||
| 285 | + start_date=start_date, | ||
| 286 | + end_date=end_date, | ||
| 287 | + limit_per_table=limit_per_table, | ||
| 288 | + ) | ||
| 190 | elif tool_name == "get_comments_for_topic": | 289 | elif tool_name == "get_comments_for_topic": |
| 191 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 290 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 192 | - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) | 291 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len( |
| 292 | + optimized_response.optimized_keywords | ||
| 293 | + ) | ||
| 193 | limit = max(limit, 50) | 294 | limit = max(limit, 50) |
| 194 | - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | 295 | + response = self.search_agency.get_comments_for_topic( |
| 296 | + topic=keyword, limit=limit | ||
| 297 | + ) | ||
| 195 | elif tool_name == "search_topic_on_platform": | 298 | elif tool_name == "search_topic_on_platform": |
| 196 | platform = kwargs.get("platform") | 299 | platform = kwargs.get("platform") |
| 197 | start_date = kwargs.get("start_date") | 300 | start_date = kwargs.get("start_date") |
| 198 | end_date = kwargs.get("end_date") | 301 | end_date = kwargs.get("end_date") |
| 199 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 302 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 200 | - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) | 303 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len( |
| 304 | + optimized_response.optimized_keywords | ||
| 305 | + ) | ||
| 201 | limit = max(limit, 30) | 306 | limit = max(limit, 30) |
| 202 | if not platform: | 307 | if not platform: |
| 203 | raise ValueError("search_topic_on_platform工具需要platform参数") | 308 | raise ValueError("search_topic_on_platform工具需要platform参数") |
| 204 | - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | 309 | + response = self.search_agency.search_topic_on_platform( |
| 310 | + platform=platform, | ||
| 311 | + topic=keyword, | ||
| 312 | + start_date=start_date, | ||
| 313 | + end_date=end_date, | ||
| 314 | + limit=limit, | ||
| 315 | + ) | ||
| 205 | else: | 316 | else: |
| 206 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | 317 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") |
| 207 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) | 318 | + response = self.search_agency.search_topic_globally( |
| 319 | + topic=keyword, | ||
| 320 | + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE, | ||
| 321 | + ) | ||
| 208 | 322 | ||
| 209 | # 收集结果 | 323 | # 收集结果 |
| 210 | if response.results: | 324 | if response.results: |
| @@ -222,6 +336,13 @@ class DeepSearchAgent: | @@ -222,6 +336,13 @@ class DeepSearchAgent: | ||
| 222 | unique_results = self._deduplicate_results(all_results) | 336 | unique_results = self._deduplicate_results(all_results) |
| 223 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") | 337 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") |
| 224 | 338 | ||
| 339 | + if ENABLE_CLUSTERING: | ||
| 340 | + unique_results = self._cluster_and_sample_results( | ||
| 341 | + unique_results, | ||
| 342 | + max_results=MAX_CLUSTERED_RESULTS, | ||
| 343 | + results_per_cluster=RESULTS_PER_CLUSTER, | ||
| 344 | + ) | ||
| 345 | + | ||
| 225 | # 构建整合后的响应 | 346 | # 构建整合后的响应 |
| 226 | integrated_response = DBResponse( | 347 | integrated_response = DBResponse( |
| 227 | tool_name=f"{tool_name}_optimized", | 348 | tool_name=f"{tool_name}_optimized", |
| @@ -229,10 +350,10 @@ class DeepSearchAgent: | @@ -229,10 +350,10 @@ class DeepSearchAgent: | ||
| 229 | "original_query": query, | 350 | "original_query": query, |
| 230 | "optimized_keywords": optimized_response.optimized_keywords, | 351 | "optimized_keywords": optimized_response.optimized_keywords, |
| 231 | "optimization_reasoning": optimized_response.reasoning, | 352 | "optimization_reasoning": optimized_response.reasoning, |
| 232 | - **kwargs | 353 | + **kwargs, |
| 233 | }, | 354 | }, |
| 234 | results=unique_results, | 355 | results=unique_results, |
| 235 | - results_count=len(unique_results) | 356 | + results_count=len(unique_results), |
| 236 | ) | 357 | ) |
| 237 | 358 | ||
| 238 | # 检查是否需要进行情感分析 | 359 | # 检查是否需要进行情感分析 |
| @@ -242,7 +363,9 @@ class DeepSearchAgent: | @@ -242,7 +363,9 @@ class DeepSearchAgent: | ||
| 242 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) | 363 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) |
| 243 | if sentiment_analysis: | 364 | if sentiment_analysis: |
| 244 | # 将情感分析结果添加到响应的parameters中 | 365 | # 将情感分析结果添加到响应的parameters中 |
| 245 | - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis | 366 | + integrated_response.parameters["sentiment_analysis"] = ( |
| 367 | + sentiment_analysis | ||
| 368 | + ) | ||
| 246 | logger.info(f" ✅ 情感分析完成") | 369 | logger.info(f" ✅ 情感分析完成") |
| 247 | 370 | ||
| 248 | return integrated_response | 371 | return integrated_response |
| @@ -275,7 +398,10 @@ class DeepSearchAgent: | @@ -275,7 +398,10 @@ class DeepSearchAgent: | ||
| 275 | """ | 398 | """ |
| 276 | try: | 399 | try: |
| 277 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 400 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 278 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 401 | + if ( |
| 402 | + not self.sentiment_analyzer.is_initialized | ||
| 403 | + and not self.sentiment_analyzer.is_disabled | ||
| 404 | + ): | ||
| 279 | logger.info(" 初始化情感分析模型...") | 405 | logger.info(" 初始化情感分析模型...") |
| 280 | if not self.sentiment_analyzer.initialize(): | 406 | if not self.sentiment_analyzer.initialize(): |
| 281 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 407 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| @@ -290,15 +416,15 @@ class DeepSearchAgent: | @@ -290,15 +416,15 @@ class DeepSearchAgent: | ||
| 290 | "platform": result.platform, | 416 | "platform": result.platform, |
| 291 | "author": result.author_nickname, | 417 | "author": result.author_nickname, |
| 292 | "url": result.url, | 418 | "url": result.url, |
| 293 | - "publish_time": str(result.publish_time) if result.publish_time else None | 419 | + "publish_time": str(result.publish_time) |
| 420 | + if result.publish_time | ||
| 421 | + else None, | ||
| 294 | } | 422 | } |
| 295 | results_dict.append(result_dict) | 423 | results_dict.append(result_dict) |
| 296 | 424 | ||
| 297 | # 执行情感分析 | 425 | # 执行情感分析 |
| 298 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( | 426 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( |
| 299 | - query_results=results_dict, | ||
| 300 | - text_field="content", | ||
| 301 | - min_confidence=0.5 | 427 | + query_results=results_dict, text_field="content", min_confidence=0.5 |
| 302 | ) | 428 | ) |
| 303 | 429 | ||
| 304 | return sentiment_analysis.get("sentiment_analysis") | 430 | return sentiment_analysis.get("sentiment_analysis") |
| @@ -321,7 +447,10 @@ class DeepSearchAgent: | @@ -321,7 +447,10 @@ class DeepSearchAgent: | ||
| 321 | 447 | ||
| 322 | try: | 448 | try: |
| 323 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 449 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 324 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 450 | + if ( |
| 451 | + not self.sentiment_analyzer.is_initialized | ||
| 452 | + and not self.sentiment_analyzer.is_disabled | ||
| 453 | + ): | ||
| 325 | logger.info(" 初始化情感分析模型...") | 454 | logger.info(" 初始化情感分析模型...") |
| 326 | if not self.sentiment_analyzer.initialize(): | 455 | if not self.sentiment_analyzer.initialize(): |
| 327 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 456 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| @@ -334,28 +463,43 @@ class DeepSearchAgent: | @@ -334,28 +463,43 @@ class DeepSearchAgent: | ||
| 334 | result_dict = result.__dict__ | 463 | result_dict = result.__dict__ |
| 335 | response = { | 464 | response = { |
| 336 | "success": result.success and result.analysis_performed, | 465 | "success": result.success and result.analysis_performed, |
| 337 | - "total_analyzed": 1 if result.analysis_performed and result.success else 0, | ||
| 338 | - "results": [result_dict] | 466 | + "total_analyzed": 1 |
| 467 | + if result.analysis_performed and result.success | ||
| 468 | + else 0, | ||
| 469 | + "results": [result_dict], | ||
| 339 | } | 470 | } |
| 340 | if not result.analysis_performed: | 471 | if not result.analysis_performed: |
| 341 | response["success"] = False | 472 | response["success"] = False |
| 342 | - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" | 473 | + response["warning"] = ( |
| 474 | + result.error_message or "情感分析功能不可用,已直接返回原始文本" | ||
| 475 | + ) | ||
| 343 | return response | 476 | return response |
| 344 | else: | 477 | else: |
| 345 | texts_list = list(texts) | 478 | texts_list = list(texts) |
| 346 | - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) | 479 | + batch_result = self.sentiment_analyzer.analyze_batch( |
| 480 | + texts_list, show_progress=True | ||
| 481 | + ) | ||
| 347 | response = { | 482 | response = { |
| 348 | - "success": batch_result.analysis_performed and batch_result.success_count > 0, | ||
| 349 | - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, | 483 | + "success": batch_result.analysis_performed |
| 484 | + and batch_result.success_count > 0, | ||
| 485 | + "total_analyzed": batch_result.total_processed | ||
| 486 | + if batch_result.analysis_performed | ||
| 487 | + else 0, | ||
| 350 | "success_count": batch_result.success_count, | 488 | "success_count": batch_result.success_count, |
| 351 | "failed_count": batch_result.failed_count, | 489 | "failed_count": batch_result.failed_count, |
| 352 | - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0, | ||
| 353 | - "results": [result.__dict__ for result in batch_result.results] | 490 | + "average_confidence": batch_result.average_confidence |
| 491 | + if batch_result.analysis_performed | ||
| 492 | + else 0.0, | ||
| 493 | + "results": [result.__dict__ for result in batch_result.results], | ||
| 354 | } | 494 | } |
| 355 | if not batch_result.analysis_performed: | 495 | if not batch_result.analysis_performed: |
| 356 | warning = next( | 496 | warning = next( |
| 357 | - (r.error_message for r in batch_result.results if r.error_message), | ||
| 358 | - "情感分析功能不可用,已直接返回原始文本" | 497 | + ( |
| 498 | + r.error_message | ||
| 499 | + for r in batch_result.results | ||
| 500 | + if r.error_message | ||
| 501 | + ), | ||
| 502 | + "情感分析功能不可用,已直接返回原始文本", | ||
| 359 | ) | 503 | ) |
| 360 | response["success"] = False | 504 | response["success"] = False |
| 361 | response["warning"] = warning | 505 | response["warning"] = warning |
| @@ -363,11 +507,7 @@ class DeepSearchAgent: | @@ -363,11 +507,7 @@ class DeepSearchAgent: | ||
| 363 | 507 | ||
| 364 | except Exception as e: | 508 | except Exception as e: |
| 365 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") | 509 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 366 | - return { | ||
| 367 | - "success": False, | ||
| 368 | - "error": str(e), | ||
| 369 | - "results": [] | ||
| 370 | - } | 510 | + return {"success": False, "error": str(e), "results": []} |
| 371 | 511 | ||
| 372 | def research(self, query: str, save_report: bool = True) -> str: | 512 | def research(self, query: str, save_report: bool = True) -> str: |
| 373 | """ | 513 | """ |
| @@ -380,9 +520,9 @@ class DeepSearchAgent: | @@ -380,9 +520,9 @@ class DeepSearchAgent: | ||
| 380 | Returns: | 520 | Returns: |
| 381 | 最终报告内容 | 521 | 最终报告内容 |
| 382 | """ | 522 | """ |
| 383 | - logger.info(f"\n{'='*60}") | 523 | + logger.info(f"\n{'=' * 60}") |
| 384 | logger.info(f"开始深度研究: {query}") | 524 | logger.info(f"开始深度研究: {query}") |
| 385 | - logger.info(f"{'='*60}") | 525 | + logger.info(f"{'=' * 60}") |
| 386 | 526 | ||
| 387 | try: | 527 | try: |
| 388 | # Step 1: 生成报告结构 | 528 | # Step 1: 生成报告结构 |
| @@ -426,7 +566,9 @@ class DeepSearchAgent: | @@ -426,7 +566,9 @@ class DeepSearchAgent: | ||
| 426 | total_paragraphs = len(self.state.paragraphs) | 566 | total_paragraphs = len(self.state.paragraphs) |
| 427 | 567 | ||
| 428 | for i in range(total_paragraphs): | 568 | for i in range(total_paragraphs): |
| 429 | - logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") | 569 | + logger.info( |
| 570 | + f"\n[步骤 2.{i + 1}] 处理段落: {self.state.paragraphs[i].title}" | ||
| 571 | + ) | ||
| 430 | logger.info("-" * 50) | 572 | logger.info("-" * 50) |
| 431 | 573 | ||
| 432 | # 初始搜索和总结 | 574 | # 初始搜索和总结 |
| @@ -446,16 +588,15 @@ class DeepSearchAgent: | @@ -446,16 +588,15 @@ class DeepSearchAgent: | ||
| 446 | paragraph = self.state.paragraphs[paragraph_index] | 588 | paragraph = self.state.paragraphs[paragraph_index] |
| 447 | 589 | ||
| 448 | # 准备搜索输入 | 590 | # 准备搜索输入 |
| 449 | - search_input = { | ||
| 450 | - "title": paragraph.title, | ||
| 451 | - "content": paragraph.content | ||
| 452 | - } | 591 | + search_input = {"title": paragraph.title, "content": paragraph.content} |
| 453 | 592 | ||
| 454 | # 生成搜索查询和工具选择 | 593 | # 生成搜索查询和工具选择 |
| 455 | logger.info(" - 生成搜索查询...") | 594 | logger.info(" - 生成搜索查询...") |
| 456 | search_output = self.first_search_node.run(search_input) | 595 | search_output = self.first_search_node.run(search_input) |
| 457 | search_query = search_output["search_query"] | 596 | search_query = search_output["search_query"] |
| 458 | - search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 | 597 | + search_tool = search_output.get( |
| 598 | + "search_tool", "search_topic_globally" | ||
| 599 | + ) # 默认工具 | ||
| 459 | reasoning = search_output["reasoning"] | 600 | reasoning = search_output["reasoning"] |
| 460 | 601 | ||
| 461 | logger.info(f" - 搜索查询: {search_query}") | 602 | logger.info(f" - 搜索查询: {search_query}") |
| @@ -475,13 +616,17 @@ class DeepSearchAgent: | @@ -475,13 +616,17 @@ class DeepSearchAgent: | ||
| 475 | 616 | ||
| 476 | if start_date and end_date: | 617 | if start_date and end_date: |
| 477 | # 验证日期格式 | 618 | # 验证日期格式 |
| 478 | - if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 619 | + if self._validate_date_format( |
| 620 | + start_date | ||
| 621 | + ) and self._validate_date_format(end_date): | ||
| 479 | search_kwargs["start_date"] = start_date | 622 | search_kwargs["start_date"] = start_date |
| 480 | search_kwargs["end_date"] = end_date | 623 | search_kwargs["end_date"] = end_date |
| 481 | logger.info(f" - 时间范围: {start_date} 到 {end_date}") | 624 | logger.info(f" - 时间范围: {start_date} 到 {end_date}") |
| 482 | else: | 625 | else: |
| 483 | logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | 626 | logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") |
| 484 | - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 627 | + logger.info( |
| 628 | + f" 提供的日期: start_date={start_date}, end_date={end_date}" | ||
| 629 | + ) | ||
| 485 | search_tool = "search_topic_globally" | 630 | search_tool = "search_topic_globally" |
| 486 | elif search_tool == "search_topic_by_date": | 631 | elif search_tool == "search_topic_by_date": |
| 487 | logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 632 | logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") |
| @@ -494,7 +639,9 @@ class DeepSearchAgent: | @@ -494,7 +639,9 @@ class DeepSearchAgent: | ||
| 494 | search_kwargs["platform"] = platform | 639 | search_kwargs["platform"] = platform |
| 495 | logger.info(f" - 指定平台: {platform}") | 640 | logger.info(f" - 指定平台: {platform}") |
| 496 | else: | 641 | else: |
| 497 | - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 642 | + logger.warning( |
| 643 | + f" search_topic_on_platform工具缺少平台参数,改用全局搜索" | ||
| 644 | + ) | ||
| 498 | search_tool = "search_topic_globally" | 645 | search_tool = "search_topic_globally" |
| 499 | 646 | ||
| 500 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 | 647 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 |
| @@ -505,9 +652,13 @@ class DeepSearchAgent: | @@ -505,9 +652,13 @@ class DeepSearchAgent: | ||
| 505 | search_kwargs["limit"] = limit | 652 | search_kwargs["limit"] = limit |
| 506 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 653 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 507 | if search_tool == "search_topic_globally": | 654 | if search_tool == "search_topic_globally": |
| 508 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | 655 | + limit_per_table = ( |
| 656 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 657 | + ) | ||
| 509 | else: # search_topic_by_date | 658 | else: # search_topic_by_date |
| 510 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 659 | + limit_per_table = ( |
| 660 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 661 | + ) | ||
| 511 | search_kwargs["limit_per_table"] = limit_per_table | 662 | search_kwargs["limit_per_table"] = limit_per_table |
| 512 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 663 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 513 | if search_tool == "get_comments_for_topic": | 664 | if search_tool == "get_comments_for_topic": |
| @@ -516,34 +667,46 @@ class DeepSearchAgent: | @@ -516,34 +667,46 @@ class DeepSearchAgent: | ||
| 516 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT | 667 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 517 | search_kwargs["limit"] = limit | 668 | search_kwargs["limit"] = limit |
| 518 | 669 | ||
| 519 | - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 670 | + search_response = self.execute_search_tool( |
| 671 | + search_tool, search_query, **search_kwargs | ||
| 672 | + ) | ||
| 520 | 673 | ||
| 521 | # 转换为兼容格式 | 674 | # 转换为兼容格式 |
| 522 | search_results = [] | 675 | search_results = [] |
| 523 | if search_response and search_response.results: | 676 | if search_response and search_response.results: |
| 524 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 677 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 525 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: | 678 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 526 | - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | 679 | + max_results = min( |
| 680 | + len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM | ||
| 681 | + ) | ||
| 527 | else: | 682 | else: |
| 528 | max_results = len(search_response.results) # 不限制,传递所有结果 | 683 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 529 | for result in search_response.results[:max_results]: | 684 | for result in search_response.results[:max_results]: |
| 530 | - search_results.append({ | ||
| 531 | - 'title': result.title_or_content, | ||
| 532 | - 'url': result.url or "", | ||
| 533 | - 'content': result.title_or_content, | ||
| 534 | - 'score': result.hotness_score, | ||
| 535 | - 'raw_content': result.title_or_content, | ||
| 536 | - 'published_date': result.publish_time.isoformat() if result.publish_time else None, | ||
| 537 | - 'platform': result.platform, | ||
| 538 | - 'content_type': result.content_type, | ||
| 539 | - 'author': result.author_nickname, | ||
| 540 | - 'engagement': result.engagement | ||
| 541 | - }) | 685 | + search_results.append( |
| 686 | + { | ||
| 687 | + "title": result.title_or_content, | ||
| 688 | + "url": result.url or "", | ||
| 689 | + "content": result.title_or_content, | ||
| 690 | + "score": result.hotness_score, | ||
| 691 | + "raw_content": result.title_or_content, | ||
| 692 | + "published_date": result.publish_time.isoformat() | ||
| 693 | + if result.publish_time | ||
| 694 | + else None, | ||
| 695 | + "platform": result.platform, | ||
| 696 | + "content_type": result.content_type, | ||
| 697 | + "author": result.author_nickname, | ||
| 698 | + "engagement": result.engagement, | ||
| 699 | + } | ||
| 700 | + ) | ||
| 542 | 701 | ||
| 543 | if search_results: | 702 | if search_results: |
| 544 | _message = f" - 找到 {len(search_results)} 个搜索结果" | 703 | _message = f" - 找到 {len(search_results)} 个搜索结果" |
| 545 | for j, result in enumerate(search_results, 1): | 704 | for j, result in enumerate(search_results, 1): |
| 546 | - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 705 | + date_info = ( |
| 706 | + f" (发布于: {result.get('published_date', 'N/A')})" | ||
| 707 | + if result.get("published_date") | ||
| 708 | + else "" | ||
| 709 | + ) | ||
| 547 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" | 710 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 548 | logger.info(_message) | 711 | logger.info(_message) |
| 549 | else: | 712 | else: |
| @@ -560,7 +723,7 @@ class DeepSearchAgent: | @@ -560,7 +723,7 @@ class DeepSearchAgent: | ||
| 560 | "search_query": search_query, | 723 | "search_query": search_query, |
| 561 | "search_results": format_search_results_for_prompt( | 724 | "search_results": format_search_results_for_prompt( |
| 562 | search_results, self.config.MAX_CONTENT_LENGTH | 725 | search_results, self.config.MAX_CONTENT_LENGTH |
| 563 | - ) | 726 | + ), |
| 564 | } | 727 | } |
| 565 | 728 | ||
| 566 | # 更新状态 | 729 | # 更新状态 |
| @@ -581,13 +744,15 @@ class DeepSearchAgent: | @@ -581,13 +744,15 @@ class DeepSearchAgent: | ||
| 581 | reflection_input = { | 744 | reflection_input = { |
| 582 | "title": paragraph.title, | 745 | "title": paragraph.title, |
| 583 | "content": paragraph.content, | 746 | "content": paragraph.content, |
| 584 | - "paragraph_latest_state": paragraph.research.latest_summary | 747 | + "paragraph_latest_state": paragraph.research.latest_summary, |
| 585 | } | 748 | } |
| 586 | 749 | ||
| 587 | # 生成反思搜索查询 | 750 | # 生成反思搜索查询 |
| 588 | reflection_output = self.reflection_node.run(reflection_input) | 751 | reflection_output = self.reflection_node.run(reflection_input) |
| 589 | search_query = reflection_output["search_query"] | 752 | search_query = reflection_output["search_query"] |
| 590 | - search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 | 753 | + search_tool = reflection_output.get( |
| 754 | + "search_tool", "search_topic_globally" | ||
| 755 | + ) # 默认工具 | ||
| 591 | reasoning = reflection_output["reasoning"] | 756 | reasoning = reflection_output["reasoning"] |
| 592 | 757 | ||
| 593 | logger.info(f" 反思查询: {search_query}") | 758 | logger.info(f" 反思查询: {search_query}") |
| @@ -605,16 +770,24 @@ class DeepSearchAgent: | @@ -605,16 +770,24 @@ class DeepSearchAgent: | ||
| 605 | 770 | ||
| 606 | if start_date and end_date: | 771 | if start_date and end_date: |
| 607 | # 验证日期格式 | 772 | # 验证日期格式 |
| 608 | - if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 773 | + if self._validate_date_format( |
| 774 | + start_date | ||
| 775 | + ) and self._validate_date_format(end_date): | ||
| 609 | search_kwargs["start_date"] = start_date | 776 | search_kwargs["start_date"] = start_date |
| 610 | search_kwargs["end_date"] = end_date | 777 | search_kwargs["end_date"] = end_date |
| 611 | logger.info(f" 时间范围: {start_date} 到 {end_date}") | 778 | logger.info(f" 时间范围: {start_date} 到 {end_date}") |
| 612 | else: | 779 | else: |
| 613 | - logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | ||
| 614 | - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 780 | + logger.info( |
| 781 | + f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索" | ||
| 782 | + ) | ||
| 783 | + logger.info( | ||
| 784 | + f" 提供的日期: start_date={start_date}, end_date={end_date}" | ||
| 785 | + ) | ||
| 615 | search_tool = "search_topic_globally" | 786 | search_tool = "search_topic_globally" |
| 616 | elif search_tool == "search_topic_by_date": | 787 | elif search_tool == "search_topic_by_date": |
| 617 | - logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 788 | + logger.warning( |
| 789 | + f" search_topic_by_date工具缺少时间参数,改用全局搜索" | ||
| 790 | + ) | ||
| 618 | search_tool = "search_topic_globally" | 791 | search_tool = "search_topic_globally" |
| 619 | 792 | ||
| 620 | # 处理需要平台参数的工具 | 793 | # 处理需要平台参数的工具 |
| @@ -624,7 +797,9 @@ class DeepSearchAgent: | @@ -624,7 +797,9 @@ class DeepSearchAgent: | ||
| 624 | search_kwargs["platform"] = platform | 797 | search_kwargs["platform"] = platform |
| 625 | logger.info(f" 指定平台: {platform}") | 798 | logger.info(f" 指定平台: {platform}") |
| 626 | else: | 799 | else: |
| 627 | - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 800 | + logger.warning( |
| 801 | + f" search_topic_on_platform工具缺少平台参数,改用全局搜索" | ||
| 802 | + ) | ||
| 628 | search_tool = "search_topic_globally" | 803 | search_tool = "search_topic_globally" |
| 629 | 804 | ||
| 630 | # 处理限制参数 | 805 | # 处理限制参数 |
| @@ -637,9 +812,13 @@ class DeepSearchAgent: | @@ -637,9 +812,13 @@ class DeepSearchAgent: | ||
| 637 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 812 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 638 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 | 813 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 |
| 639 | if search_tool == "search_topic_globally": | 814 | if search_tool == "search_topic_globally": |
| 640 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | 815 | + limit_per_table = ( |
| 816 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 817 | + ) | ||
| 641 | else: # search_topic_by_date | 818 | else: # search_topic_by_date |
| 642 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 819 | + limit_per_table = ( |
| 820 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 821 | + ) | ||
| 643 | search_kwargs["limit_per_table"] = limit_per_table | 822 | search_kwargs["limit_per_table"] = limit_per_table |
| 644 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 823 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 645 | # 使用配置文件中的默认值,不允许agent控制limit参数 | 824 | # 使用配置文件中的默认值,不允许agent控制limit参数 |
| @@ -649,34 +828,47 @@ class DeepSearchAgent: | @@ -649,34 +828,47 @@ class DeepSearchAgent: | ||
| 649 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT | 828 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 650 | search_kwargs["limit"] = limit | 829 | search_kwargs["limit"] = limit |
| 651 | 830 | ||
| 652 | - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 831 | + search_response = self.execute_search_tool( |
| 832 | + search_tool, search_query, **search_kwargs | ||
| 833 | + ) | ||
| 653 | 834 | ||
| 654 | # 转换为兼容格式 | 835 | # 转换为兼容格式 |
| 655 | search_results = [] | 836 | search_results = [] |
| 656 | if search_response and search_response.results: | 837 | if search_response and search_response.results: |
| 657 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 838 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 658 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: | 839 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 659 | - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | 840 | + max_results = min( |
| 841 | + len(search_response.results), | ||
| 842 | + self.config.MAX_SEARCH_RESULTS_FOR_LLM, | ||
| 843 | + ) | ||
| 660 | else: | 844 | else: |
| 661 | max_results = len(search_response.results) # 不限制,传递所有结果 | 845 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 662 | for result in search_response.results[:max_results]: | 846 | for result in search_response.results[:max_results]: |
| 663 | - search_results.append({ | ||
| 664 | - 'title': result.title_or_content, | ||
| 665 | - 'url': result.url or "", | ||
| 666 | - 'content': result.title_or_content, | ||
| 667 | - 'score': result.hotness_score, | ||
| 668 | - 'raw_content': result.title_or_content, | ||
| 669 | - 'published_date': result.publish_time.isoformat() if result.publish_time else None, | ||
| 670 | - 'platform': result.platform, | ||
| 671 | - 'content_type': result.content_type, | ||
| 672 | - 'author': result.author_nickname, | ||
| 673 | - 'engagement': result.engagement | ||
| 674 | - }) | 847 | + search_results.append( |
| 848 | + { | ||
| 849 | + "title": result.title_or_content, | ||
| 850 | + "url": result.url or "", | ||
| 851 | + "content": result.title_or_content, | ||
| 852 | + "score": result.hotness_score, | ||
| 853 | + "raw_content": result.title_or_content, | ||
| 854 | + "published_date": result.publish_time.isoformat() | ||
| 855 | + if result.publish_time | ||
| 856 | + else None, | ||
| 857 | + "platform": result.platform, | ||
| 858 | + "content_type": result.content_type, | ||
| 859 | + "author": result.author_nickname, | ||
| 860 | + "engagement": result.engagement, | ||
| 861 | + } | ||
| 862 | + ) | ||
| 675 | 863 | ||
| 676 | if search_results: | 864 | if search_results: |
| 677 | _message = f" 找到 {len(search_results)} 个反思搜索结果" | 865 | _message = f" 找到 {len(search_results)} 个反思搜索结果" |
| 678 | for j, result in enumerate(search_results, 1): | 866 | for j, result in enumerate(search_results, 1): |
| 679 | - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 867 | + date_info = ( |
| 868 | + f" (发布于: {result.get('published_date', 'N/A')})" | ||
| 869 | + if result.get("published_date") | ||
| 870 | + else "" | ||
| 871 | + ) | ||
| 680 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" | 872 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 681 | logger.info(_message) | 873 | logger.info(_message) |
| 682 | else: | 874 | else: |
| @@ -693,7 +885,7 @@ class DeepSearchAgent: | @@ -693,7 +885,7 @@ class DeepSearchAgent: | ||
| 693 | "search_results": format_search_results_for_prompt( | 885 | "search_results": format_search_results_for_prompt( |
| 694 | search_results, self.config.MAX_CONTENT_LENGTH | 886 | search_results, self.config.MAX_CONTENT_LENGTH |
| 695 | ), | 887 | ), |
| 696 | - "paragraph_latest_state": paragraph.research.latest_summary | 888 | + "paragraph_latest_state": paragraph.research.latest_summary, |
| 697 | } | 889 | } |
| 698 | 890 | ||
| 699 | # 更新状态 | 891 | # 更新状态 |
| @@ -710,10 +902,12 @@ class DeepSearchAgent: | @@ -710,10 +902,12 @@ class DeepSearchAgent: | ||
| 710 | # 准备报告数据 | 902 | # 准备报告数据 |
| 711 | report_data = [] | 903 | report_data = [] |
| 712 | for paragraph in self.state.paragraphs: | 904 | for paragraph in self.state.paragraphs: |
| 713 | - report_data.append({ | 905 | + report_data.append( |
| 906 | + { | ||
| 714 | "title": paragraph.title, | 907 | "title": paragraph.title, |
| 715 | - "paragraph_latest_state": paragraph.research.latest_summary | ||
| 716 | - }) | 908 | + "paragraph_latest_state": paragraph.research.latest_summary, |
| 909 | + } | ||
| 910 | + ) | ||
| 717 | 911 | ||
| 718 | # 格式化报告 | 912 | # 格式化报告 |
| 719 | try: | 913 | try: |
| @@ -735,14 +929,16 @@ class DeepSearchAgent: | @@ -735,14 +929,16 @@ class DeepSearchAgent: | ||
| 735 | """保存报告到文件""" | 929 | """保存报告到文件""" |
| 736 | # 生成文件名 | 930 | # 生成文件名 |
| 737 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | 931 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| 738 | - query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip() | ||
| 739 | - query_safe = query_safe.replace(' ', '_')[:30] | 932 | + query_safe = "".join( |
| 933 | + c for c in self.state.query if c.isalnum() or c in (" ", "-", "_") | ||
| 934 | + ).rstrip() | ||
| 935 | + query_safe = query_safe.replace(" ", "_")[:30] | ||
| 740 | 936 | ||
| 741 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" | 937 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" |
| 742 | filepath = os.path.join(self.config.OUTPUT_DIR, filename) | 938 | filepath = os.path.join(self.config.OUTPUT_DIR, filename) |
| 743 | 939 | ||
| 744 | # 保存报告 | 940 | # 保存报告 |
| 745 | - with open(filepath, 'w', encoding='utf-8') as f: | 941 | + with open(filepath, "w", encoding="utf-8") as f: |
| 746 | f.write(report_content) | 942 | f.write(report_content) |
| 747 | 943 | ||
| 748 | logger.info(f"报告已保存到: {filepath}") | 944 | logger.info(f"报告已保存到: {filepath}") |
| @@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13 | @@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13 | ||
| 61 | # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) ===== | 61 | # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) ===== |
| 62 | torch>=2.0.0 # CPU版本 | 62 | torch>=2.0.0 # CPU版本 |
| 63 | transformers>=4.30.0 | 63 | transformers>=4.30.0 |
| 64 | +sentence-transformers>=2.2.2 | ||
| 64 | scikit-learn>=1.3.0 | 65 | scikit-learn>=1.3.0 |
| 65 | xgboost>=2.0.0 | 66 | xgboost>=2.0.0 |
| 66 | # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126 | 67 | # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126 |
-
Please register or login to post a comment