Showing
1 changed file
with
422 additions
and
229 deletions
| @@ -7,60 +7,75 @@ import json | @@ -7,60 +7,75 @@ import json | ||
| 7 | import os | 7 | import os |
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | -from typing import Optional, Dict, Any, List, Union | 10 | +from typing import Any, Dict, List, Optional, Union |
| 11 | + | ||
| 12 | +import numpy as np | ||
| 11 | from loguru import logger | 13 | from loguru import logger |
| 14 | +from sentence_transformers import SentenceTransformer | ||
| 15 | +from sklearn.cluster import KMeans | ||
| 12 | 16 | ||
| 13 | from .llms import LLMClient | 17 | from .llms import LLMClient |
| 14 | from .nodes import ( | 18 | from .nodes import ( |
| 15 | - ReportStructureNode, | ||
| 16 | - FirstSearchNode, | ||
| 17 | - ReflectionNode, | 19 | + FirstSearchNode, |
| 18 | FirstSummaryNode, | 20 | FirstSummaryNode, |
| 21 | + ReflectionNode, | ||
| 19 | ReflectionSummaryNode, | 22 | ReflectionSummaryNode, |
| 20 | - ReportFormattingNode | 23 | + ReportFormattingNode, |
| 24 | + ReportStructureNode, | ||
| 21 | ) | 25 | ) |
| 22 | from .state import State | 26 | from .state import State |
| 23 | -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer | ||
| 24 | -from .utils.config import settings, Settings | 27 | +from .tools import ( |
| 28 | + DBResponse, | ||
| 29 | + MediaCrawlerDB, | ||
| 30 | + keyword_optimizer, | ||
| 31 | + multilingual_sentiment_analyzer, | ||
| 32 | +) | ||
| 25 | from .utils import format_search_results_for_prompt | 33 | from .utils import format_search_results_for_prompt |
| 34 | +from .utils.config import Settings, settings | ||
| 35 | + | ||
| 36 | +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样 | ||
| 37 | +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数 | ||
| 38 | +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数 | ||
| 26 | 39 | ||
| 27 | 40 | ||
| 28 | class DeepSearchAgent: | 41 | class DeepSearchAgent: |
| 29 | """Deep Search Agent主类""" | 42 | """Deep Search Agent主类""" |
| 30 | - | 43 | + |
| 31 | def __init__(self, config: Optional[Settings] = None): | 44 | def __init__(self, config: Optional[Settings] = None): |
| 32 | """ | 45 | """ |
| 33 | 初始化Deep Search Agent | 46 | 初始化Deep Search Agent |
| 34 | - | 47 | + |
| 35 | Args: | 48 | Args: |
| 36 | config: 可选配置对象(不填则用全局settings) | 49 | config: 可选配置对象(不填则用全局settings) |
| 37 | """ | 50 | """ |
| 38 | self.config = config or settings | 51 | self.config = config or settings |
| 39 | - | 52 | + |
| 40 | # 初始化LLM客户端 | 53 | # 初始化LLM客户端 |
| 41 | self.llm_client = self._initialize_llm() | 54 | self.llm_client = self._initialize_llm() |
| 42 | - | ||
| 43 | - | 55 | + |
| 44 | # 初始化搜索工具集 | 56 | # 初始化搜索工具集 |
| 45 | self.search_agency = MediaCrawlerDB() | 57 | self.search_agency = MediaCrawlerDB() |
| 46 | - | 58 | + |
| 59 | + # 初始化聚类小模型(懒加载) | ||
| 60 | + self._clustering_model = None | ||
| 61 | + | ||
| 47 | # 初始化情感分析器 | 62 | # 初始化情感分析器 |
| 48 | self.sentiment_analyzer = multilingual_sentiment_analyzer | 63 | self.sentiment_analyzer = multilingual_sentiment_analyzer |
| 49 | - | 64 | + |
| 50 | # 初始化节点 | 65 | # 初始化节点 |
| 51 | self._initialize_nodes() | 66 | self._initialize_nodes() |
| 52 | - | 67 | + |
| 53 | # 状态 | 68 | # 状态 |
| 54 | self.state = State() | 69 | self.state = State() |
| 55 | - | 70 | + |
| 56 | # 确保输出目录存在 | 71 | # 确保输出目录存在 |
| 57 | os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) | 72 | os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) |
| 58 | - | 73 | + |
| 59 | logger.info(f"Insight Agent已初始化") | 74 | logger.info(f"Insight Agent已初始化") |
| 60 | logger.info(f"使用LLM: {self.llm_client.get_model_info()}") | 75 | logger.info(f"使用LLM: {self.llm_client.get_model_info()}") |
| 61 | logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") | 76 | logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") |
| 62 | logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") | 77 | logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") |
| 63 | - | 78 | + |
| 64 | def _initialize_llm(self) -> LLMClient: | 79 | def _initialize_llm(self) -> LLMClient: |
| 65 | """初始化LLM客户端""" | 80 | """初始化LLM客户端""" |
| 66 | return LLMClient( | 81 | return LLMClient( |
| @@ -68,7 +83,7 @@ class DeepSearchAgent: | @@ -68,7 +83,7 @@ class DeepSearchAgent: | ||
| 68 | model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, | 83 | model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, |
| 69 | base_url=self.config.INSIGHT_ENGINE_BASE_URL, | 84 | base_url=self.config.INSIGHT_ENGINE_BASE_URL, |
| 70 | ) | 85 | ) |
| 71 | - | 86 | + |
| 72 | def _initialize_nodes(self): | 87 | def _initialize_nodes(self): |
| 73 | """初始化处理节点""" | 88 | """初始化处理节点""" |
| 74 | self.first_search_node = FirstSearchNode(self.llm_client) | 89 | self.first_search_node = FirstSearchNode(self.llm_client) |
| @@ -76,36 +91,103 @@ class DeepSearchAgent: | @@ -76,36 +91,103 @@ class DeepSearchAgent: | ||
| 76 | self.first_summary_node = FirstSummaryNode(self.llm_client) | 91 | self.first_summary_node = FirstSummaryNode(self.llm_client) |
| 77 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) | 92 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) |
| 78 | self.report_formatting_node = ReportFormattingNode(self.llm_client) | 93 | self.report_formatting_node = ReportFormattingNode(self.llm_client) |
| 79 | - | 94 | + |
| 95 | + def _get_clustering_model(self): | ||
| 96 | + """懒加载聚类模型""" | ||
| 97 | + if self._clustering_model is None: | ||
| 98 | + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...") | ||
| 99 | + self._clustering_model = SentenceTransformer( | ||
| 100 | + "paraphrase-multilingual-MiniLM-L12-v2" | ||
| 101 | + ) | ||
| 102 | + return self._clustering_model | ||
| 103 | + | ||
| 80 | def _validate_date_format(self, date_str: str) -> bool: | 104 | def _validate_date_format(self, date_str: str) -> bool: |
| 81 | """ | 105 | """ |
| 82 | 验证日期格式是否为YYYY-MM-DD | 106 | 验证日期格式是否为YYYY-MM-DD |
| 83 | - | 107 | + |
| 84 | Args: | 108 | Args: |
| 85 | date_str: 日期字符串 | 109 | date_str: 日期字符串 |
| 86 | - | 110 | + |
| 87 | Returns: | 111 | Returns: |
| 88 | 是否为有效格式 | 112 | 是否为有效格式 |
| 89 | """ | 113 | """ |
| 90 | if not date_str: | 114 | if not date_str: |
| 91 | return False | 115 | return False |
| 92 | - | 116 | + |
| 93 | # 检查格式 | 117 | # 检查格式 |
| 94 | - pattern = r'^\d{4}-\d{2}-\d{2}$' | 118 | + pattern = r"^\d{4}-\d{2}-\d{2}$" |
| 95 | if not re.match(pattern, date_str): | 119 | if not re.match(pattern, date_str): |
| 96 | return False | 120 | return False |
| 97 | - | 121 | + |
| 98 | # 检查日期是否有效 | 122 | # 检查日期是否有效 |
| 99 | try: | 123 | try: |
| 100 | - datetime.strptime(date_str, '%Y-%m-%d') | 124 | + datetime.strptime(date_str, "%Y-%m-%d") |
| 101 | return True | 125 | return True |
| 102 | except ValueError: | 126 | except ValueError: |
| 103 | return False | 127 | return False |
| 104 | - | 128 | + |
| 129 | + def _cluster_and_sample_results( | ||
| 130 | + self, results: List, max_results: int = 50, results_per_cluster: int = 5 | ||
| 131 | + ) -> List: | ||
| 132 | + """ | ||
| 133 | + 对搜索结果进行聚类并采样 | ||
| 134 | + | ||
| 135 | + Args: | ||
| 136 | + results: 搜索结果列表 | ||
| 137 | + max_results: 最大返回结果数 | ||
| 138 | + results_per_cluster: 每个聚类返回的结果数 | ||
| 139 | + | ||
| 140 | + Returns: | ||
| 141 | + 采样后的结果列表 | ||
| 142 | + """ | ||
| 143 | + if len(results) <= max_results: | ||
| 144 | + return results | ||
| 145 | + | ||
| 146 | + try: | ||
| 147 | + # 提取文本 | ||
| 148 | + texts = [r.title_or_content[:500] for r in results] | ||
| 149 | + | ||
| 150 | + # 获取模型并编码 | ||
| 151 | + model = self._get_clustering_model() | ||
| 152 | + embeddings = model.encode(texts, show_progress_bar=False) | ||
| 153 | + | ||
| 154 | + # 计算聚类数 | ||
| 155 | + n_clusters = min(max(2, max_results // results_per_cluster), len(results)) | ||
| 156 | + | ||
| 157 | + # KMeans聚类 | ||
| 158 | + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) | ||
| 159 | + labels = kmeans.fit_predict(embeddings) | ||
| 160 | + | ||
| 161 | + # 从每个聚类采样 | ||
| 162 | + sampled_results = [] | ||
| 163 | + for cluster_id in range(n_clusters): | ||
| 164 | + cluster_indices = np.where(labels == cluster_id)[0] | ||
| 165 | + cluster_results = [(results[i], i) for i in cluster_indices] | ||
| 166 | + cluster_results.sort( | ||
| 167 | + key=lambda x: x[0].hotness_score or 0, reverse=True | ||
| 168 | + ) | ||
| 169 | + | ||
| 170 | + for result, _ in cluster_results[:results_per_cluster]: | ||
| 171 | + sampled_results.append(result) | ||
| 172 | + if len(sampled_results) >= max_results: | ||
| 173 | + break | ||
| 174 | + | ||
| 175 | + if len(sampled_results) >= max_results: | ||
| 176 | + break | ||
| 177 | + | ||
| 178 | + logger.info( | ||
| 179 | + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果" | ||
| 180 | + ) | ||
| 181 | + return sampled_results | ||
| 182 | + | ||
| 183 | + except Exception as e: | ||
| 184 | + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}") | ||
| 185 | + return results[:max_results] | ||
| 186 | + | ||
| 105 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: | 187 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: |
| 106 | """ | 188 | """ |
| 107 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) | 189 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) |
| 108 | - | 190 | + |
| 109 | Args: | 191 | Args: |
| 110 | tool_name: 工具名称,可选值: | 192 | tool_name: 工具名称,可选值: |
| 111 | - "search_hot_content": 查找热点内容 | 193 | - "search_hot_content": 查找热点内容 |
| @@ -117,18 +199,20 @@ class DeepSearchAgent: | @@ -117,18 +199,20 @@ class DeepSearchAgent: | ||
| 117 | query: 搜索关键词/话题 | 199 | query: 搜索关键词/话题 |
| 118 | **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) | 200 | **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) |
| 119 | enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) | 201 | enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) |
| 120 | - | 202 | + |
| 121 | Returns: | 203 | Returns: |
| 122 | DBResponse对象(可能包含情感分析结果) | 204 | DBResponse对象(可能包含情感分析结果) |
| 123 | """ | 205 | """ |
| 124 | logger.info(f" → 执行数据库查询工具: {tool_name}") | 206 | logger.info(f" → 执行数据库查询工具: {tool_name}") |
| 125 | - | 207 | + |
| 126 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) | 208 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) |
| 127 | if tool_name == "search_hot_content": | 209 | if tool_name == "search_hot_content": |
| 128 | time_period = kwargs.get("time_period", "week") | 210 | time_period = kwargs.get("time_period", "week") |
| 129 | limit = kwargs.get("limit", 100) | 211 | limit = kwargs.get("limit", 100) |
| 130 | - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) | ||
| 131 | - | 212 | + response = self.search_agency.search_hot_content( |
| 213 | + time_period=time_period, limit=limit | ||
| 214 | + ) | ||
| 215 | + | ||
| 132 | # 检查是否需要进行情感分析 | 216 | # 检查是否需要进行情感分析 |
| 133 | enable_sentiment = kwargs.get("enable_sentiment", True) | 217 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 134 | if enable_sentiment and response.results and len(response.results) > 0: | 218 | if enable_sentiment and response.results and len(response.results) > 0: |
| @@ -138,74 +222,101 @@ class DeepSearchAgent: | @@ -138,74 +222,101 @@ class DeepSearchAgent: | ||
| 138 | # 将情感分析结果添加到响应的parameters中 | 222 | # 将情感分析结果添加到响应的parameters中 |
| 139 | response.parameters["sentiment_analysis"] = sentiment_analysis | 223 | response.parameters["sentiment_analysis"] = sentiment_analysis |
| 140 | logger.info(f" ✅ 情感分析完成") | 224 | logger.info(f" ✅ 情感分析完成") |
| 141 | - | 225 | + |
| 142 | return response | 226 | return response |
| 143 | - | 227 | + |
| 144 | # 独立情感分析工具 | 228 | # 独立情感分析工具 |
| 145 | if tool_name == "analyze_sentiment": | 229 | if tool_name == "analyze_sentiment": |
| 146 | texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query | 230 | texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query |
| 147 | sentiment_result = self.analyze_sentiment_only(texts) | 231 | sentiment_result = self.analyze_sentiment_only(texts) |
| 148 | - | 232 | + |
| 149 | # 构建DBResponse格式的响应 | 233 | # 构建DBResponse格式的响应 |
| 150 | return DBResponse( | 234 | return DBResponse( |
| 151 | tool_name="analyze_sentiment", | 235 | tool_name="analyze_sentiment", |
| 152 | parameters={ | 236 | parameters={ |
| 153 | "texts": texts if isinstance(texts, list) else [texts], | 237 | "texts": texts if isinstance(texts, list) else [texts], |
| 154 | - **kwargs | 238 | + **kwargs, |
| 155 | }, | 239 | }, |
| 156 | results=[], # 情感分析不返回搜索结果 | 240 | results=[], # 情感分析不返回搜索结果 |
| 157 | results_count=0, | 241 | results_count=0, |
| 158 | - metadata=sentiment_result | 242 | + metadata=sentiment_result, |
| 159 | ) | 243 | ) |
| 160 | - | 244 | + |
| 161 | # 对于需要搜索词的工具,使用关键词优化中间件 | 245 | # 对于需要搜索词的工具,使用关键词优化中间件 |
| 162 | optimized_response = keyword_optimizer.optimize_keywords( | 246 | optimized_response = keyword_optimizer.optimize_keywords( |
| 163 | - original_query=query, | ||
| 164 | - context=f"使用{tool_name}工具进行查询" | 247 | + original_query=query, context=f"使用{tool_name}工具进行查询" |
| 165 | ) | 248 | ) |
| 166 | - | 249 | + |
| 167 | logger.info(f" 🔍 原始查询: '{query}'") | 250 | logger.info(f" 🔍 原始查询: '{query}'") |
| 168 | logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") | 251 | logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") |
| 169 | - | 252 | + |
| 170 | # 使用优化后的关键词进行多次查询并整合结果 | 253 | # 使用优化后的关键词进行多次查询并整合结果 |
| 171 | all_results = [] | 254 | all_results = [] |
| 172 | total_count = 0 | 255 | total_count = 0 |
| 173 | - | 256 | + |
| 174 | for keyword in optimized_response.optimized_keywords: | 257 | for keyword in optimized_response.optimized_keywords: |
| 175 | logger.info(f" 查询关键词: '{keyword}'") | 258 | logger.info(f" 查询关键词: '{keyword}'") |
| 176 | - | 259 | + |
| 177 | try: | 260 | try: |
| 178 | if tool_name == "search_topic_globally": | 261 | if tool_name == "search_topic_globally": |
| 179 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 262 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 180 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 181 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | 263 | + limit_per_table = ( |
| 264 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 265 | + ) | ||
| 266 | + response = self.search_agency.search_topic_globally( | ||
| 267 | + topic=keyword, limit_per_table=limit_per_table | ||
| 268 | + ) | ||
| 182 | elif tool_name == "search_topic_by_date": | 269 | elif tool_name == "search_topic_by_date": |
| 183 | start_date = kwargs.get("start_date") | 270 | start_date = kwargs.get("start_date") |
| 184 | end_date = kwargs.get("end_date") | 271 | end_date = kwargs.get("end_date") |
| 185 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 272 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 186 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 273 | + limit_per_table = ( |
| 274 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 275 | + ) | ||
| 187 | if not start_date or not end_date: | 276 | if not start_date or not end_date: |
| 188 | - raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | ||
| 189 | - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | 277 | + raise ValueError( |
| 278 | + "search_topic_by_date工具需要start_date和end_date参数" | ||
| 279 | + ) | ||
| 280 | + response = self.search_agency.search_topic_by_date( | ||
| 281 | + topic=keyword, | ||
| 282 | + start_date=start_date, | ||
| 283 | + end_date=end_date, | ||
| 284 | + limit_per_table=limit_per_table, | ||
| 285 | + ) | ||
| 190 | elif tool_name == "get_comments_for_topic": | 286 | elif tool_name == "get_comments_for_topic": |
| 191 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 287 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 192 | - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) | 288 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len( |
| 289 | + optimized_response.optimized_keywords | ||
| 290 | + ) | ||
| 193 | limit = max(limit, 50) | 291 | limit = max(limit, 50) |
| 194 | - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | 292 | + response = self.search_agency.get_comments_for_topic( |
| 293 | + topic=keyword, limit=limit | ||
| 294 | + ) | ||
| 195 | elif tool_name == "search_topic_on_platform": | 295 | elif tool_name == "search_topic_on_platform": |
| 196 | platform = kwargs.get("platform") | 296 | platform = kwargs.get("platform") |
| 197 | start_date = kwargs.get("start_date") | 297 | start_date = kwargs.get("start_date") |
| 198 | end_date = kwargs.get("end_date") | 298 | end_date = kwargs.get("end_date") |
| 199 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 299 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 200 | - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) | 300 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len( |
| 301 | + optimized_response.optimized_keywords | ||
| 302 | + ) | ||
| 201 | limit = max(limit, 30) | 303 | limit = max(limit, 30) |
| 202 | if not platform: | 304 | if not platform: |
| 203 | raise ValueError("search_topic_on_platform工具需要platform参数") | 305 | raise ValueError("search_topic_on_platform工具需要platform参数") |
| 204 | - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | 306 | + response = self.search_agency.search_topic_on_platform( |
| 307 | + platform=platform, | ||
| 308 | + topic=keyword, | ||
| 309 | + start_date=start_date, | ||
| 310 | + end_date=end_date, | ||
| 311 | + limit=limit, | ||
| 312 | + ) | ||
| 205 | else: | 313 | else: |
| 206 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | 314 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") |
| 207 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) | ||
| 208 | - | 315 | + response = self.search_agency.search_topic_globally( |
| 316 | + topic=keyword, | ||
| 317 | + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE, | ||
| 318 | + ) | ||
| 319 | + | ||
| 209 | # 收集结果 | 320 | # 收集结果 |
| 210 | if response.results: | 321 | if response.results: |
| 211 | logger.info(f" 找到 {len(response.results)} 条结果") | 322 | logger.info(f" 找到 {len(response.results)} 条结果") |
| @@ -213,15 +324,22 @@ class DeepSearchAgent: | @@ -213,15 +324,22 @@ class DeepSearchAgent: | ||
| 213 | total_count += len(response.results) | 324 | total_count += len(response.results) |
| 214 | else: | 325 | else: |
| 215 | logger.info(f" 未找到结果") | 326 | logger.info(f" 未找到结果") |
| 216 | - | 327 | + |
| 217 | except Exception as e: | 328 | except Exception as e: |
| 218 | logger.error(f" 查询'{keyword}'时出错: {str(e)}") | 329 | logger.error(f" 查询'{keyword}'时出错: {str(e)}") |
| 219 | continue | 330 | continue |
| 220 | - | 331 | + |
| 221 | # 去重和整合结果 | 332 | # 去重和整合结果 |
| 222 | unique_results = self._deduplicate_results(all_results) | 333 | unique_results = self._deduplicate_results(all_results) |
| 223 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") | 334 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") |
| 224 | - | 335 | + |
| 336 | + if ENABLE_CLUSTERING: | ||
| 337 | + unique_results = self._cluster_and_sample_results( | ||
| 338 | + unique_results, | ||
| 339 | + max_results=MAX_CLUSTERED_RESULTS, | ||
| 340 | + results_per_cluster=RESULTS_PER_CLUSTER, | ||
| 341 | + ) | ||
| 342 | + | ||
| 225 | # 构建整合后的响应 | 343 | # 构建整合后的响应 |
| 226 | integrated_response = DBResponse( | 344 | integrated_response = DBResponse( |
| 227 | tool_name=f"{tool_name}_optimized", | 345 | tool_name=f"{tool_name}_optimized", |
| @@ -229,12 +347,12 @@ class DeepSearchAgent: | @@ -229,12 +347,12 @@ class DeepSearchAgent: | ||
| 229 | "original_query": query, | 347 | "original_query": query, |
| 230 | "optimized_keywords": optimized_response.optimized_keywords, | 348 | "optimized_keywords": optimized_response.optimized_keywords, |
| 231 | "optimization_reasoning": optimized_response.reasoning, | 349 | "optimization_reasoning": optimized_response.reasoning, |
| 232 | - **kwargs | 350 | + **kwargs, |
| 233 | }, | 351 | }, |
| 234 | results=unique_results, | 352 | results=unique_results, |
| 235 | - results_count=len(unique_results) | 353 | + results_count=len(unique_results), |
| 236 | ) | 354 | ) |
| 237 | - | 355 | + |
| 238 | # 检查是否需要进行情感分析 | 356 | # 检查是否需要进行情感分析 |
| 239 | enable_sentiment = kwargs.get("enable_sentiment", True) | 357 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 240 | if enable_sentiment and unique_results and len(unique_results) > 0: | 358 | if enable_sentiment and unique_results and len(unique_results) > 0: |
| @@ -242,40 +360,45 @@ class DeepSearchAgent: | @@ -242,40 +360,45 @@ class DeepSearchAgent: | ||
| 242 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) | 360 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) |
| 243 | if sentiment_analysis: | 361 | if sentiment_analysis: |
| 244 | # 将情感分析结果添加到响应的parameters中 | 362 | # 将情感分析结果添加到响应的parameters中 |
| 245 | - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis | 363 | + integrated_response.parameters["sentiment_analysis"] = ( |
| 364 | + sentiment_analysis | ||
| 365 | + ) | ||
| 246 | logger.info(f" ✅ 情感分析完成") | 366 | logger.info(f" ✅ 情感分析完成") |
| 247 | - | 367 | + |
| 248 | return integrated_response | 368 | return integrated_response |
| 249 | - | 369 | + |
| 250 | def _deduplicate_results(self, results: List) -> List: | 370 | def _deduplicate_results(self, results: List) -> List: |
| 251 | """ | 371 | """ |
| 252 | 去重搜索结果 | 372 | 去重搜索结果 |
| 253 | """ | 373 | """ |
| 254 | seen = set() | 374 | seen = set() |
| 255 | unique_results = [] | 375 | unique_results = [] |
| 256 | - | 376 | + |
| 257 | for result in results: | 377 | for result in results: |
| 258 | # 使用URL或内容作为去重标识 | 378 | # 使用URL或内容作为去重标识 |
| 259 | identifier = result.url if result.url else result.title_or_content[:100] | 379 | identifier = result.url if result.url else result.title_or_content[:100] |
| 260 | if identifier not in seen: | 380 | if identifier not in seen: |
| 261 | seen.add(identifier) | 381 | seen.add(identifier) |
| 262 | unique_results.append(result) | 382 | unique_results.append(result) |
| 263 | - | 383 | + |
| 264 | return unique_results | 384 | return unique_results |
| 265 | - | 385 | + |
| 266 | def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: | 386 | def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: |
| 267 | """ | 387 | """ |
| 268 | 对搜索结果执行情感分析 | 388 | 对搜索结果执行情感分析 |
| 269 | - | 389 | + |
| 270 | Args: | 390 | Args: |
| 271 | results: 搜索结果列表 | 391 | results: 搜索结果列表 |
| 272 | - | 392 | + |
| 273 | Returns: | 393 | Returns: |
| 274 | 情感分析结果字典,如果失败则返回None | 394 | 情感分析结果字典,如果失败则返回None |
| 275 | """ | 395 | """ |
| 276 | try: | 396 | try: |
| 277 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 397 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 278 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 398 | + if ( |
| 399 | + not self.sentiment_analyzer.is_initialized | ||
| 400 | + and not self.sentiment_analyzer.is_disabled | ||
| 401 | + ): | ||
| 279 | logger.info(" 初始化情感分析模型...") | 402 | logger.info(" 初始化情感分析模型...") |
| 280 | if not self.sentiment_analyzer.initialize(): | 403 | if not self.sentiment_analyzer.initialize(): |
| 281 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 404 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| @@ -290,203 +413,222 @@ class DeepSearchAgent: | @@ -290,203 +413,222 @@ class DeepSearchAgent: | ||
| 290 | "platform": result.platform, | 413 | "platform": result.platform, |
| 291 | "author": result.author_nickname, | 414 | "author": result.author_nickname, |
| 292 | "url": result.url, | 415 | "url": result.url, |
| 293 | - "publish_time": str(result.publish_time) if result.publish_time else None | 416 | + "publish_time": str(result.publish_time) |
| 417 | + if result.publish_time | ||
| 418 | + else None, | ||
| 294 | } | 419 | } |
| 295 | results_dict.append(result_dict) | 420 | results_dict.append(result_dict) |
| 296 | - | 421 | + |
| 297 | # 执行情感分析 | 422 | # 执行情感分析 |
| 298 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( | 423 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( |
| 299 | - query_results=results_dict, | ||
| 300 | - text_field="content", | ||
| 301 | - min_confidence=0.5 | 424 | + query_results=results_dict, text_field="content", min_confidence=0.5 |
| 302 | ) | 425 | ) |
| 303 | - | 426 | + |
| 304 | return sentiment_analysis.get("sentiment_analysis") | 427 | return sentiment_analysis.get("sentiment_analysis") |
| 305 | - | 428 | + |
| 306 | except Exception as e: | 429 | except Exception as e: |
| 307 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") | 430 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 308 | return None | 431 | return None |
| 309 | - | 432 | + |
| 310 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: | 433 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: |
| 311 | """ | 434 | """ |
| 312 | 独立的情感分析工具 | 435 | 独立的情感分析工具 |
| 313 | - | 436 | + |
| 314 | Args: | 437 | Args: |
| 315 | texts: 单个文本或文本列表 | 438 | texts: 单个文本或文本列表 |
| 316 | - | 439 | + |
| 317 | Returns: | 440 | Returns: |
| 318 | 情感分析结果 | 441 | 情感分析结果 |
| 319 | """ | 442 | """ |
| 320 | logger.info(f" → 执行独立情感分析") | 443 | logger.info(f" → 执行独立情感分析") |
| 321 | - | 444 | + |
| 322 | try: | 445 | try: |
| 323 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 446 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 324 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 447 | + if ( |
| 448 | + not self.sentiment_analyzer.is_initialized | ||
| 449 | + and not self.sentiment_analyzer.is_disabled | ||
| 450 | + ): | ||
| 325 | logger.info(" 初始化情感分析模型...") | 451 | logger.info(" 初始化情感分析模型...") |
| 326 | if not self.sentiment_analyzer.initialize(): | 452 | if not self.sentiment_analyzer.initialize(): |
| 327 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 453 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| 328 | elif self.sentiment_analyzer.is_disabled: | 454 | elif self.sentiment_analyzer.is_disabled: |
| 329 | logger.warning(" 情感分析功能已禁用,直接透传原始文本") | 455 | logger.warning(" 情感分析功能已禁用,直接透传原始文本") |
| 330 | - | 456 | + |
| 331 | # 执行分析 | 457 | # 执行分析 |
| 332 | if isinstance(texts, str): | 458 | if isinstance(texts, str): |
| 333 | result = self.sentiment_analyzer.analyze_single_text(texts) | 459 | result = self.sentiment_analyzer.analyze_single_text(texts) |
| 334 | result_dict = result.__dict__ | 460 | result_dict = result.__dict__ |
| 335 | response = { | 461 | response = { |
| 336 | "success": result.success and result.analysis_performed, | 462 | "success": result.success and result.analysis_performed, |
| 337 | - "total_analyzed": 1 if result.analysis_performed and result.success else 0, | ||
| 338 | - "results": [result_dict] | 463 | + "total_analyzed": 1 |
| 464 | + if result.analysis_performed and result.success | ||
| 465 | + else 0, | ||
| 466 | + "results": [result_dict], | ||
| 339 | } | 467 | } |
| 340 | if not result.analysis_performed: | 468 | if not result.analysis_performed: |
| 341 | response["success"] = False | 469 | response["success"] = False |
| 342 | - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" | 470 | + response["warning"] = ( |
| 471 | + result.error_message or "情感分析功能不可用,已直接返回原始文本" | ||
| 472 | + ) | ||
| 343 | return response | 473 | return response |
| 344 | else: | 474 | else: |
| 345 | texts_list = list(texts) | 475 | texts_list = list(texts) |
| 346 | - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) | 476 | + batch_result = self.sentiment_analyzer.analyze_batch( |
| 477 | + texts_list, show_progress=True | ||
| 478 | + ) | ||
| 347 | response = { | 479 | response = { |
| 348 | - "success": batch_result.analysis_performed and batch_result.success_count > 0, | ||
| 349 | - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, | 480 | + "success": batch_result.analysis_performed |
| 481 | + and batch_result.success_count > 0, | ||
| 482 | + "total_analyzed": batch_result.total_processed | ||
| 483 | + if batch_result.analysis_performed | ||
| 484 | + else 0, | ||
| 350 | "success_count": batch_result.success_count, | 485 | "success_count": batch_result.success_count, |
| 351 | "failed_count": batch_result.failed_count, | 486 | "failed_count": batch_result.failed_count, |
| 352 | - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0, | ||
| 353 | - "results": [result.__dict__ for result in batch_result.results] | 487 | + "average_confidence": batch_result.average_confidence |
| 488 | + if batch_result.analysis_performed | ||
| 489 | + else 0.0, | ||
| 490 | + "results": [result.__dict__ for result in batch_result.results], | ||
| 354 | } | 491 | } |
| 355 | if not batch_result.analysis_performed: | 492 | if not batch_result.analysis_performed: |
| 356 | warning = next( | 493 | warning = next( |
| 357 | - (r.error_message for r in batch_result.results if r.error_message), | ||
| 358 | - "情感分析功能不可用,已直接返回原始文本" | 494 | + ( |
| 495 | + r.error_message | ||
| 496 | + for r in batch_result.results | ||
| 497 | + if r.error_message | ||
| 498 | + ), | ||
| 499 | + "情感分析功能不可用,已直接返回原始文本", | ||
| 359 | ) | 500 | ) |
| 360 | response["success"] = False | 501 | response["success"] = False |
| 361 | response["warning"] = warning | 502 | response["warning"] = warning |
| 362 | return response | 503 | return response |
| 363 | - | 504 | + |
| 364 | except Exception as e: | 505 | except Exception as e: |
| 365 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") | 506 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 366 | - return { | ||
| 367 | - "success": False, | ||
| 368 | - "error": str(e), | ||
| 369 | - "results": [] | ||
| 370 | - } | ||
| 371 | - | 507 | + return {"success": False, "error": str(e), "results": []} |
| 508 | + | ||
| 372 | def research(self, query: str, save_report: bool = True) -> str: | 509 | def research(self, query: str, save_report: bool = True) -> str: |
| 373 | """ | 510 | """ |
| 374 | 执行深度研究 | 511 | 执行深度研究 |
| 375 | - | 512 | + |
| 376 | Args: | 513 | Args: |
| 377 | query: 研究查询 | 514 | query: 研究查询 |
| 378 | save_report: 是否保存报告到文件 | 515 | save_report: 是否保存报告到文件 |
| 379 | - | 516 | + |
| 380 | Returns: | 517 | Returns: |
| 381 | 最终报告内容 | 518 | 最终报告内容 |
| 382 | """ | 519 | """ |
| 383 | - logger.info(f"\n{'='*60}") | 520 | + logger.info(f"\n{'=' * 60}") |
| 384 | logger.info(f"开始深度研究: {query}") | 521 | logger.info(f"开始深度研究: {query}") |
| 385 | - logger.info(f"{'='*60}") | ||
| 386 | - | 522 | + logger.info(f"{'=' * 60}") |
| 523 | + | ||
| 387 | try: | 524 | try: |
| 388 | # Step 1: 生成报告结构 | 525 | # Step 1: 生成报告结构 |
| 389 | self._generate_report_structure(query) | 526 | self._generate_report_structure(query) |
| 390 | - | 527 | + |
| 391 | # Step 2: 处理每个段落 | 528 | # Step 2: 处理每个段落 |
| 392 | self._process_paragraphs() | 529 | self._process_paragraphs() |
| 393 | - | 530 | + |
| 394 | # Step 3: 生成最终报告 | 531 | # Step 3: 生成最终报告 |
| 395 | final_report = self._generate_final_report() | 532 | final_report = self._generate_final_report() |
| 396 | - | 533 | + |
| 397 | # Step 4: 保存报告 | 534 | # Step 4: 保存报告 |
| 398 | if save_report: | 535 | if save_report: |
| 399 | self._save_report(final_report) | 536 | self._save_report(final_report) |
| 400 | 537 | ||
| 401 | logger.info("深度研究完成!") | 538 | logger.info("深度研究完成!") |
| 402 | - | 539 | + |
| 403 | return final_report | 540 | return final_report |
| 404 | - | 541 | + |
| 405 | except Exception as e: | 542 | except Exception as e: |
| 406 | logger.exception(f"研究过程中发生错误: {str(e)}") | 543 | logger.exception(f"研究过程中发生错误: {str(e)}") |
| 407 | raise e | 544 | raise e |
| 408 | - | 545 | + |
| 409 | def _generate_report_structure(self, query: str): | 546 | def _generate_report_structure(self, query: str): |
| 410 | """生成报告结构""" | 547 | """生成报告结构""" |
| 411 | logger.info(f"\n[步骤 1] 生成报告结构...") | 548 | logger.info(f"\n[步骤 1] 生成报告结构...") |
| 412 | - | 549 | + |
| 413 | # 创建报告结构节点 | 550 | # 创建报告结构节点 |
| 414 | report_structure_node = ReportStructureNode(self.llm_client, query) | 551 | report_structure_node = ReportStructureNode(self.llm_client, query) |
| 415 | - | 552 | + |
| 416 | # 生成结构并更新状态 | 553 | # 生成结构并更新状态 |
| 417 | self.state = report_structure_node.mutate_state(state=self.state) | 554 | self.state = report_structure_node.mutate_state(state=self.state) |
| 418 | - | 555 | + |
| 419 | _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" | 556 | _message = f"报告结构已生成,共 {len(self.state.paragraphs)} 个段落:" |
| 420 | for i, paragraph in enumerate(self.state.paragraphs, 1): | 557 | for i, paragraph in enumerate(self.state.paragraphs, 1): |
| 421 | _message += f"\n {i}. {paragraph.title}" | 558 | _message += f"\n {i}. {paragraph.title}" |
| 422 | logger.info(_message) | 559 | logger.info(_message) |
| 423 | - | 560 | + |
| 424 | def _process_paragraphs(self): | 561 | def _process_paragraphs(self): |
| 425 | """处理所有段落""" | 562 | """处理所有段落""" |
| 426 | total_paragraphs = len(self.state.paragraphs) | 563 | total_paragraphs = len(self.state.paragraphs) |
| 427 | - | 564 | + |
| 428 | for i in range(total_paragraphs): | 565 | for i in range(total_paragraphs): |
| 429 | - logger.info(f"\n[步骤 2.{i+1}] 处理段落: {self.state.paragraphs[i].title}") | 566 | + logger.info( |
| 567 | + f"\n[步骤 2.{i + 1}] 处理段落: {self.state.paragraphs[i].title}" | ||
| 568 | + ) | ||
| 430 | logger.info("-" * 50) | 569 | logger.info("-" * 50) |
| 431 | - | 570 | + |
| 432 | # 初始搜索和总结 | 571 | # 初始搜索和总结 |
| 433 | self._initial_search_and_summary(i) | 572 | self._initial_search_and_summary(i) |
| 434 | - | 573 | + |
| 435 | # 反思循环 | 574 | # 反思循环 |
| 436 | self._reflection_loop(i) | 575 | self._reflection_loop(i) |
| 437 | - | 576 | + |
| 438 | # 标记段落完成 | 577 | # 标记段落完成 |
| 439 | self.state.paragraphs[i].research.mark_completed() | 578 | self.state.paragraphs[i].research.mark_completed() |
| 440 | - | 579 | + |
| 441 | progress = (i + 1) / total_paragraphs * 100 | 580 | progress = (i + 1) / total_paragraphs * 100 |
| 442 | logger.info(f"段落处理完成 ({progress:.1f}%)") | 581 | logger.info(f"段落处理完成 ({progress:.1f}%)") |
| 443 | - | 582 | + |
    def _initial_search_and_summary(self, paragraph_index: int) -> None:
        """Run the initial search for one paragraph and produce its first summary.

        Flow: ask the first-search node for a query and tool choice, validate
        and normalize tool parameters (falling back to a global search when a
        required date/platform parameter is missing or malformed), execute the
        search, convert hits to plain dicts, record them in the paragraph's
        research history, then let the first-summary node update ``self.state``.

        Args:
            paragraph_index: Index into ``self.state.paragraphs``.
        """
        paragraph = self.state.paragraphs[paragraph_index]

        # Input for the search-planning node: the paragraph title and intent.
        search_input = {"title": paragraph.title, "content": paragraph.content}

        # Generate the search query and tool selection.
        logger.info(" - 生成搜索查询...")
        search_output = self.first_search_node.run(search_input)
        search_query = search_output["search_query"]
        search_tool = search_output.get(
            "search_tool", "search_topic_globally"
        )  # default tool when the node does not pick one
        reasoning = search_output["reasoning"]

        logger.info(f" - 搜索查询: {search_query}")
        logger.info(f" - 选择的工具: {search_tool}")
        logger.info(f" - 推理: {reasoning}")

        # Execute the database search.
        logger.info(" - 执行数据库查询...")

        # Tool-specific keyword arguments built up below.
        search_kwargs = {}

        # Tools that accept a date range.
        if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
            start_date = search_output.get("start_date")
            end_date = search_output.get("end_date")

            if start_date and end_date:
                # Validate the date format before trusting the node's output.
                if self._validate_date_format(
                    start_date
                ) and self._validate_date_format(end_date):
                    search_kwargs["start_date"] = start_date
                    search_kwargs["end_date"] = end_date
                    logger.info(f" - 时间范围: {start_date} 到 {end_date}")
                else:
                    # Malformed dates: fall back to a global search.
                    logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索")
                    logger.info(
                        f" 提供的日期: start_date={start_date}, end_date={end_date}"
                    )
                    search_tool = "search_topic_globally"
            elif search_tool == "search_topic_by_date":
                # Date search without dates makes no sense: fall back.
                logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索")
                search_tool = "search_topic_globally"

        # Tools that require a platform argument.
        if search_tool == "search_topic_on_platform":
            platform = search_output.get("platform")
            # NOTE(review): the `if platform:` guard line is hidden inside a
            # diff hunk boundary here — reconstructed; confirm against full source.
            if platform:
                search_kwargs["platform"] = platform
                logger.info(f" - 指定平台: {platform}")
            else:
                logger.warning(
                    f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
                )
                search_tool = "search_topic_globally"

        # Limit parameters always come from config defaults, never from the
        # agent's own output.
        if search_tool == "search_hot_content":
            time_period = search_output.get("time_period", "week")
            # NOTE(review): the time_period/limit assignment lines are hidden
            # inside a diff hunk boundary — reconstructed; confirm against
            # full source.
            search_kwargs["time_period"] = time_period
            limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
            search_kwargs["limit"] = limit
        elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
            if search_tool == "search_topic_globally":
                limit_per_table = (
                    self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
                )
            else:  # search_topic_by_date
                limit_per_table = (
                    self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
                )
            search_kwargs["limit_per_table"] = limit_per_table
        elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
            if search_tool == "get_comments_for_topic":
                # NOTE(review): this assignment is hidden at a diff hunk
                # boundary — reconstructed; confirm against full source.
                limit = self.config.DEFAULT_GET_COMMENTS_LIMIT
            else:  # search_topic_on_platform
                limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
            search_kwargs["limit"] = limit

        search_response = self.execute_search_tool(
            search_tool, search_query, **search_kwargs
        )

        # Convert DB results into the plain-dict format the prompt expects.
        search_results = []
        if search_response and search_response.results:
            # Config caps how many results reach the LLM; 0 means unlimited.
            if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
                max_results = min(
                    len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM
                )
            else:
                max_results = len(search_response.results)  # unlimited: pass all
            for result in search_response.results[:max_results]:
                search_results.append(
                    {
                        "title": result.title_or_content,
                        "url": result.url or "",
                        "content": result.title_or_content,
                        "score": result.hotness_score,
                        "raw_content": result.title_or_content,
                        "published_date": result.publish_time.isoformat()
                        if result.publish_time
                        else None,
                        "platform": result.platform,
                        "content_type": result.content_type,
                        "author": result.author_nickname,
                        "engagement": result.engagement,
                    }
                )

        if search_results:
            _message = f" - 找到 {len(search_results)} 个搜索结果"
            for j, result in enumerate(search_results, 1):
                date_info = (
                    f" (发布于: {result.get('published_date', 'N/A')})"
                    if result.get("published_date")
                    else ""
                )
                _message += f"\n {j}. {result['title'][:50]}...{date_info}"
            logger.info(_message)
        else:
            logger.info(" - 未找到搜索结果")

        # Record this round in the paragraph's search history.
        paragraph.research.add_search_results(search_query, search_results)

        # Produce the initial summary from the gathered results.
        logger.info(" - 生成初始总结...")
        summary_input = {
            # NOTE(review): the "title"/"content" entries are hidden at a diff
            # hunk boundary — reconstructed from the reflection-summary input;
            # confirm against full source.
            "title": paragraph.title,
            "content": paragraph.content,
            "search_query": search_query,
            "search_results": format_search_results_for_prompt(
                search_results, self.config.MAX_CONTENT_LENGTH
            ),
        }

        # The summary node returns an updated state.
        self.state = self.first_summary_node.mutate_state(
            summary_input, self.state, paragraph_index
        )

        logger.info(" - 初始总结完成")
    def _reflection_loop(self, paragraph_index: int) -> None:
        """Iteratively refine one paragraph's summary (up to MAX_REFLECTIONS rounds).

        Each round: the reflection node critiques the current summary and
        proposes a follow-up query + tool; tool parameters are validated with
        the same fallbacks as the initial search; results are recorded and
        the reflection-summary node updates ``self.state``.

        Args:
            paragraph_index: Index into ``self.state.paragraphs``.
        """
        paragraph = self.state.paragraphs[paragraph_index]
        # NOTE(review): ``paragraph`` is captured once, but mutate_state below
        # reassigns ``self.state`` each round — this is only correct if
        # mutate_state updates paragraph objects in place; confirm.

        for reflection_i in range(self.config.MAX_REFLECTIONS):
            logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...")

            # Input for the reflection node: paragraph plus its latest summary.
            reflection_input = {
                "title": paragraph.title,
                "content": paragraph.content,
                "paragraph_latest_state": paragraph.research.latest_summary,
            }

            # Generate the follow-up query and tool choice.
            reflection_output = self.reflection_node.run(reflection_input)
            search_query = reflection_output["search_query"]
            search_tool = reflection_output.get(
                "search_tool", "search_topic_globally"
            )  # default tool when the node does not pick one
            reasoning = reflection_output["reasoning"]

            logger.info(f" 反思查询: {search_query}")
            logger.info(f" 选择的工具: {search_tool}")
            logger.info(f" 反思推理: {reasoning}")

            # Tool-specific keyword arguments built up below.
            search_kwargs = {}

            # Tools that accept a date range.
            if search_tool in ["search_topic_by_date", "search_topic_on_platform"]:
                start_date = reflection_output.get("start_date")
                end_date = reflection_output.get("end_date")

                if start_date and end_date:
                    # Validate the date format before trusting the node.
                    if self._validate_date_format(
                        start_date
                    ) and self._validate_date_format(end_date):
                        search_kwargs["start_date"] = start_date
                        search_kwargs["end_date"] = end_date
                        logger.info(f" 时间范围: {start_date} 到 {end_date}")
                    else:
                        logger.info(
                            f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索"
                        )
                        logger.info(
                            f" 提供的日期: start_date={start_date}, end_date={end_date}"
                        )
                        search_tool = "search_topic_globally"
                elif search_tool == "search_topic_by_date":
                    logger.warning(
                        f" search_topic_by_date工具缺少时间参数,改用全局搜索"
                    )
                    search_tool = "search_topic_globally"

            # Tools that require a platform argument.
            if search_tool == "search_topic_on_platform":
                platform = reflection_output.get("platform")
                # NOTE(review): the `if platform:` guard line is hidden at a
                # diff hunk boundary — reconstructed; confirm against full source.
                if platform:
                    search_kwargs["platform"] = platform
                    logger.info(f" 指定平台: {platform}")
                else:
                    logger.warning(
                        f" search_topic_on_platform工具缺少平台参数,改用全局搜索"
                    )
                    search_tool = "search_topic_globally"

            # Limit parameters always come from config defaults.
            if search_tool == "search_hot_content":
                time_period = reflection_output.get("time_period", "week")
                # NOTE(review): assignments hidden at a diff hunk boundary —
                # reconstructed; confirm against full source.
                search_kwargs["time_period"] = time_period
                limit = self.config.DEFAULT_SEARCH_HOT_CONTENT_LIMIT
                search_kwargs["limit"] = limit
            elif search_tool in ["search_topic_globally", "search_topic_by_date"]:
                # Config defaults only; the agent may not set limit_per_table.
                if search_tool == "search_topic_globally":
                    limit_per_table = (
                        self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE
                    )
                else:  # search_topic_by_date
                    limit_per_table = (
                        self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE
                    )
                search_kwargs["limit_per_table"] = limit_per_table
            elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]:
                # Config defaults only; the agent may not set limit.
                if search_tool == "get_comments_for_topic":
                    # NOTE(review): hidden at a diff hunk boundary —
                    # reconstructed; confirm against full source.
                    limit = self.config.DEFAULT_GET_COMMENTS_LIMIT
                else:  # search_topic_on_platform
                    limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT
                search_kwargs["limit"] = limit

            search_response = self.execute_search_tool(
                search_tool, search_query, **search_kwargs
            )

            # Convert DB results into the plain-dict format the prompt expects.
            search_results = []
            if search_response and search_response.results:
                # Config caps how many results reach the LLM; 0 means unlimited.
                if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0:
                    max_results = min(
                        len(search_response.results),
                        self.config.MAX_SEARCH_RESULTS_FOR_LLM,
                    )
                else:
                    max_results = len(search_response.results)  # unlimited: pass all
                for result in search_response.results[:max_results]:
                    search_results.append(
                        {
                            "title": result.title_or_content,
                            "url": result.url or "",
                            "content": result.title_or_content,
                            "score": result.hotness_score,
                            "raw_content": result.title_or_content,
                            "published_date": result.publish_time.isoformat()
                            if result.publish_time
                            else None,
                            "platform": result.platform,
                            "content_type": result.content_type,
                            "author": result.author_nickname,
                            "engagement": result.engagement,
                        }
                    )

            if search_results:
                _message = f" 找到 {len(search_results)} 个反思搜索结果"
                for j, result in enumerate(search_results, 1):
                    date_info = (
                        f" (发布于: {result.get('published_date', 'N/A')})"
                        if result.get("published_date")
                        else ""
                    )
                    _message += f"\n {j}. {result['title'][:50]}...{date_info}"
                logger.info(_message)
            else:
                logger.info(" 未找到反思搜索结果")

            # Record this round in the paragraph's search history.
            paragraph.research.add_search_results(search_query, search_results)

            # Produce the refined summary for this round.
            reflection_summary_input = {
                "title": paragraph.title,
                # NOTE(review): the "content"/"search_query" entries are hidden
                # at a diff hunk boundary — reconstructed from the initial-summary
                # input; confirm against full source.
                "content": paragraph.content,
                "search_query": search_query,
                "search_results": format_search_results_for_prompt(
                    search_results, self.config.MAX_CONTENT_LENGTH
                ),
                "paragraph_latest_state": paragraph.research.latest_summary,
            }

            # The reflection-summary node returns an updated state.
            self.state = self.reflection_summary_node.mutate_state(
                reflection_summary_input, self.state, paragraph_index
            )

            logger.info(f" 反思 {reflection_i + 1} 完成")
    def _generate_final_report(self) -> str:
        """Assemble the final report (Step 3) from the paragraph summaries.

        Returns:
            The formatted report text; also stored on ``self.state.final_report``.
        """
        logger.info(f"\n[步骤 3] 生成最终报告...")

        # Collect each paragraph's title and latest research summary.
        report_data = []
        for paragraph in self.state.paragraphs:
            report_data.append(
                {
                    "title": paragraph.title,
                    "paragraph_latest_state": paragraph.research.latest_summary,
                }
            )

        # Format via the LLM node; fall back to manual formatting on failure.
        try:
            final_report = self.report_formatting_node.run(report_data)
        # NOTE(review): the original except clause (and any logging in it) is
        # hidden at a diff hunk boundary here — reconstructed; confirm against
        # the full source.
        except Exception:
            final_report = self.report_formatting_node.format_report_manually(
                report_data, self.state.report_title
            )

        # Persist the result on the shared state and mark the run finished.
        self.state.final_report = final_report
        self.state.mark_completed()

        logger.info("最终报告生成完成")
        return final_report
| 734 | def _save_report(self, report_content: str): | 925 | def _save_report(self, report_content: str): |
| 735 | """保存报告到文件""" | 926 | """保存报告到文件""" |
| 736 | # 生成文件名 | 927 | # 生成文件名 |
| 737 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | 928 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| 738 | - query_safe = "".join(c for c in self.state.query if c.isalnum() or c in (' ', '-', '_')).rstrip() | ||
| 739 | - query_safe = query_safe.replace(' ', '_')[:30] | ||
| 740 | - | 929 | + query_safe = "".join( |
| 930 | + c for c in self.state.query if c.isalnum() or c in (" ", "-", "_") | ||
| 931 | + ).rstrip() | ||
| 932 | + query_safe = query_safe.replace(" ", "_")[:30] | ||
| 933 | + | ||
| 741 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" | 934 | filename = f"deep_search_report_{query_safe}_{timestamp}.md" |
| 742 | filepath = os.path.join(self.config.OUTPUT_DIR, filename) | 935 | filepath = os.path.join(self.config.OUTPUT_DIR, filename) |
| 743 | - | 936 | + |
| 744 | # 保存报告 | 937 | # 保存报告 |
| 745 | - with open(filepath, 'w', encoding='utf-8') as f: | 938 | + with open(filepath, "w", encoding="utf-8") as f: |
| 746 | f.write(report_content) | 939 | f.write(report_content) |
| 747 | - | 940 | + |
| 748 | logger.info(f"报告已保存到: {filepath}") | 941 | logger.info(f"报告已保存到: {filepath}") |
| 749 | - | 942 | + |
| 750 | # 保存状态(如果配置允许) | 943 | # 保存状态(如果配置允许) |
| 751 | if self.config.SAVE_INTERMEDIATE_STATES: | 944 | if self.config.SAVE_INTERMEDIATE_STATES: |
| 752 | state_filename = f"state_{query_safe}_{timestamp}.json" | 945 | state_filename = f"state_{query_safe}_{timestamp}.json" |
| 753 | state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) | 946 | state_filepath = os.path.join(self.config.OUTPUT_DIR, state_filename) |
| 754 | self.state.save_to_file(state_filepath) | 947 | self.state.save_to_file(state_filepath) |
| 755 | logger.info(f"状态已保存到: {state_filepath}") | 948 | logger.info(f"状态已保存到: {state_filepath}") |
| 756 | - | 949 | + |
    def get_progress_summary(self) -> Dict[str, Any]:
        """Return a progress summary, delegated to the underlying ``State``."""
        return self.state.get_progress_summary()
    def load_state(self, filepath: str) -> None:
        """Load agent state from *filepath*, replacing the current state.

        Args:
            filepath: Path to a previously saved state JSON file.
        """
        self.state = State.load_from_file(filepath)
        logger.info(f"状态已从 {filepath} 加载")
| 766 | def save_state(self, filepath: str): | 959 | def save_state(self, filepath: str): |
| 767 | """保存状态到文件""" | 960 | """保存状态到文件""" |
| 768 | self.state.save_to_file(filepath) | 961 | self.state.save_to_file(filepath) |
| @@ -772,12 +965,12 @@ class DeepSearchAgent: | @@ -772,12 +965,12 @@ class DeepSearchAgent: | ||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
    """Convenience factory for a :class:`DeepSearchAgent`.

    Args:
        config_file: Path to a configuration file. The current ``Settings``
            implementation reads configuration from the environment only, so
            a supplied path is not consumed — it is surfaced with a warning
            instead of being silently dropped (the previous behavior).

    Returns:
        A new ``DeepSearchAgent`` instance.
    """
    if config_file:
        # FIX: this parameter used to be accepted and then ignored with no
        # indication to the caller.
        logger.warning(f"config_file 参数当前未被支持,已忽略: {config_file}")

    # Initialize with defaults; Settings pulls values from environment variables.
    config = Settings()
    return DeepSearchAgent(config)
-
Please register or login to post a comment