Committed by GitHub
Merge pull request #350 from 666ghj/feature/insight_agent_cluster
feat(insight_agent): search results cluster
Showing 2 changed files with 426 additions and 229 deletions
| @@ -7,60 +7,75 @@ import json | @@ -7,60 +7,75 @@ import json | ||
| 7 | import os | 7 | import os |
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | -from typing import Optional, Dict, Any, List, Union | 10 | +from typing import Any, Dict, List, Optional, Union |
| 11 | + | ||
| 12 | +import numpy as np | ||
| 11 | from loguru import logger | 13 | from loguru import logger |
| 14 | +from sentence_transformers import SentenceTransformer | ||
| 15 | +from sklearn.cluster import KMeans | ||
| 12 | 16 | ||
| 13 | from .llms import LLMClient | 17 | from .llms import LLMClient |
| 14 | from .nodes import ( | 18 | from .nodes import ( |
| 15 | - ReportStructureNode, | ||
| 16 | - FirstSearchNode, | ||
| 17 | - ReflectionNode, | 19 | + FirstSearchNode, |
| 18 | FirstSummaryNode, | 20 | FirstSummaryNode, |
| 21 | + ReflectionNode, | ||
| 19 | ReflectionSummaryNode, | 22 | ReflectionSummaryNode, |
| 20 | - ReportFormattingNode | 23 | + ReportFormattingNode, |
| 24 | + ReportStructureNode, | ||
| 21 | ) | 25 | ) |
| 22 | from .state import State | 26 | from .state import State |
| 23 | -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer | ||
| 24 | -from .utils.config import settings, Settings | 27 | +from .tools import ( |
| 28 | + DBResponse, | ||
| 29 | + MediaCrawlerDB, | ||
| 30 | + keyword_optimizer, | ||
| 31 | + multilingual_sentiment_analyzer, | ||
| 32 | +) | ||
| 25 | from .utils import format_search_results_for_prompt | 33 | from .utils import format_search_results_for_prompt |
| 34 | +from .utils.config import Settings, settings | ||
| 35 | + | ||
| 36 | +ENABLE_CLUSTERING: bool = True # 是否启用聚类采样 | ||
| 37 | +MAX_CLUSTERED_RESULTS: int = 50 # 聚类后最大返回结果数 | ||
| 38 | +RESULTS_PER_CLUSTER: int = 5 # 每个聚类返回的结果数 | ||
| 26 | 39 | ||
| 27 | 40 | ||
| 28 | class DeepSearchAgent: | 41 | class DeepSearchAgent: |
| 29 | """Deep Search Agent主类""" | 42 | """Deep Search Agent主类""" |
| 30 | - | 43 | + |
| 31 | def __init__(self, config: Optional[Settings] = None): | 44 | def __init__(self, config: Optional[Settings] = None): |
| 32 | """ | 45 | """ |
| 33 | 初始化Deep Search Agent | 46 | 初始化Deep Search Agent |
| 34 | - | 47 | + |
| 35 | Args: | 48 | Args: |
| 36 | config: 可选配置对象(不填则用全局settings) | 49 | config: 可选配置对象(不填则用全局settings) |
| 37 | """ | 50 | """ |
| 38 | self.config = config or settings | 51 | self.config = config or settings |
| 39 | - | 52 | + |
| 40 | # 初始化LLM客户端 | 53 | # 初始化LLM客户端 |
| 41 | self.llm_client = self._initialize_llm() | 54 | self.llm_client = self._initialize_llm() |
| 42 | - | ||
| 43 | - | 55 | + |
| 44 | # 初始化搜索工具集 | 56 | # 初始化搜索工具集 |
| 45 | self.search_agency = MediaCrawlerDB() | 57 | self.search_agency = MediaCrawlerDB() |
| 46 | - | 58 | + |
| 59 | + # 初始化聚类小模型(懒加载) | ||
| 60 | + self._clustering_model = None | ||
| 61 | + | ||
| 47 | # 初始化情感分析器 | 62 | # 初始化情感分析器 |
| 48 | self.sentiment_analyzer = multilingual_sentiment_analyzer | 63 | self.sentiment_analyzer = multilingual_sentiment_analyzer |
| 49 | - | 64 | + |
| 50 | # 初始化节点 | 65 | # 初始化节点 |
| 51 | self._initialize_nodes() | 66 | self._initialize_nodes() |
| 52 | - | 67 | + |
| 53 | # 状态 | 68 | # 状态 |
| 54 | self.state = State() | 69 | self.state = State() |
| 55 | - | 70 | + |
| 56 | # 确保输出目录存在 | 71 | # 确保输出目录存在 |
| 57 | os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) | 72 | os.makedirs(self.config.OUTPUT_DIR, exist_ok=True) |
| 58 | - | 73 | + |
| 59 | logger.info(f"Insight Agent已初始化") | 74 | logger.info(f"Insight Agent已初始化") |
| 60 | logger.info(f"使用LLM: {self.llm_client.get_model_info()}") | 75 | logger.info(f"使用LLM: {self.llm_client.get_model_info()}") |
| 61 | logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") | 76 | logger.info(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") |
| 62 | logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") | 77 | logger.info(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") |
| 63 | - | 78 | + |
| 64 | def _initialize_llm(self) -> LLMClient: | 79 | def _initialize_llm(self) -> LLMClient: |
| 65 | """初始化LLM客户端""" | 80 | """初始化LLM客户端""" |
| 66 | return LLMClient( | 81 | return LLMClient( |
| @@ -68,7 +83,7 @@ class DeepSearchAgent: | @@ -68,7 +83,7 @@ class DeepSearchAgent: | ||
| 68 | model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, | 83 | model_name=self.config.INSIGHT_ENGINE_MODEL_NAME, |
| 69 | base_url=self.config.INSIGHT_ENGINE_BASE_URL, | 84 | base_url=self.config.INSIGHT_ENGINE_BASE_URL, |
| 70 | ) | 85 | ) |
| 71 | - | 86 | + |
| 72 | def _initialize_nodes(self): | 87 | def _initialize_nodes(self): |
| 73 | """初始化处理节点""" | 88 | """初始化处理节点""" |
| 74 | self.first_search_node = FirstSearchNode(self.llm_client) | 89 | self.first_search_node = FirstSearchNode(self.llm_client) |
| @@ -76,36 +91,106 @@ class DeepSearchAgent: | @@ -76,36 +91,106 @@ class DeepSearchAgent: | ||
| 76 | self.first_summary_node = FirstSummaryNode(self.llm_client) | 91 | self.first_summary_node = FirstSummaryNode(self.llm_client) |
| 77 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) | 92 | self.reflection_summary_node = ReflectionSummaryNode(self.llm_client) |
| 78 | self.report_formatting_node = ReportFormattingNode(self.llm_client) | 93 | self.report_formatting_node = ReportFormattingNode(self.llm_client) |
| 79 | - | 94 | + |
| 95 | + def _get_clustering_model(self): | ||
| 96 | + """懒加载聚类模型""" | ||
| 97 | + if self._clustering_model is None: | ||
| 98 | + logger.info(" 加载聚类模型 (paraphrase-multilingual-MiniLM-L12-v2)...") | ||
| 99 | + self._clustering_model = SentenceTransformer( | ||
| 100 | + "paraphrase-multilingual-MiniLM-L12-v2" | ||
| 101 | + ) | ||
| 102 | + return self._clustering_model | ||
| 103 | + | ||
| 80 | def _validate_date_format(self, date_str: str) -> bool: | 104 | def _validate_date_format(self, date_str: str) -> bool: |
| 81 | """ | 105 | """ |
| 82 | 验证日期格式是否为YYYY-MM-DD | 106 | 验证日期格式是否为YYYY-MM-DD |
| 83 | - | 107 | + |
| 84 | Args: | 108 | Args: |
| 85 | date_str: 日期字符串 | 109 | date_str: 日期字符串 |
| 86 | - | 110 | + |
| 87 | Returns: | 111 | Returns: |
| 88 | 是否为有效格式 | 112 | 是否为有效格式 |
| 89 | """ | 113 | """ |
| 90 | if not date_str: | 114 | if not date_str: |
| 91 | return False | 115 | return False |
| 92 | - | 116 | + |
| 93 | # 检查格式 | 117 | # 检查格式 |
| 94 | - pattern = r'^\d{4}-\d{2}-\d{2}$' | 118 | + pattern = r"^\d{4}-\d{2}-\d{2}$" |
| 95 | if not re.match(pattern, date_str): | 119 | if not re.match(pattern, date_str): |
| 96 | return False | 120 | return False |
| 97 | - | 121 | + |
| 98 | # 检查日期是否有效 | 122 | # 检查日期是否有效 |
| 99 | try: | 123 | try: |
| 100 | - datetime.strptime(date_str, '%Y-%m-%d') | 124 | + datetime.strptime(date_str, "%Y-%m-%d") |
| 101 | return True | 125 | return True |
| 102 | except ValueError: | 126 | except ValueError: |
| 103 | return False | 127 | return False |
| 104 | - | 128 | + |
| 129 | + def _cluster_and_sample_results( | ||
| 130 | + self, | ||
| 131 | + results: List, | ||
| 132 | + max_results: int = MAX_CLUSTERED_RESULTS, | ||
| 133 | + results_per_cluster: int = RESULTS_PER_CLUSTER, | ||
| 134 | + ) -> List: | ||
| 135 | + """ | ||
| 136 | + 对搜索结果进行聚类并采样 | ||
| 137 | + | ||
| 138 | + Args: | ||
| 139 | + results: 搜索结果列表 | ||
| 140 | + max_results: 最大返回结果数 | ||
| 141 | + results_per_cluster: 每个聚类返回的结果数 | ||
| 142 | + | ||
| 143 | + Returns: | ||
| 144 | + 采样后的结果列表 | ||
| 145 | + """ | ||
| 146 | + if len(results) <= max_results: | ||
| 147 | + return results | ||
| 148 | + | ||
| 149 | + try: | ||
| 150 | + # 提取文本 | ||
| 151 | + texts = [r.title_or_content[:500] for r in results] | ||
| 152 | + | ||
| 153 | + # 获取模型并编码 | ||
| 154 | + model = self._get_clustering_model() | ||
| 155 | + embeddings = model.encode(texts, show_progress_bar=False) | ||
| 156 | + | ||
| 157 | + # 计算聚类数 | ||
| 158 | + n_clusters = min(max(2, max_results // results_per_cluster), len(results)) | ||
| 159 | + | ||
| 160 | + # KMeans聚类 | ||
| 161 | + kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) | ||
| 162 | + labels = kmeans.fit_predict(embeddings) | ||
| 163 | + | ||
| 164 | + # 从每个聚类采样 | ||
| 165 | + sampled_results = [] | ||
| 166 | + for cluster_id in range(n_clusters): | ||
| 167 | + cluster_indices = np.flatnonzero(labels == cluster_id) | ||
| 168 | + cluster_results = [(results[i], i) for i in cluster_indices] | ||
| 169 | + cluster_results.sort( | ||
| 170 | + key=lambda x: x[0].hotness_score or 0, reverse=True | ||
| 171 | + ) | ||
| 172 | + | ||
| 173 | + for result, _ in cluster_results[:results_per_cluster]: | ||
| 174 | + sampled_results.append(result) | ||
| 175 | + if len(sampled_results) >= max_results: | ||
| 176 | + break | ||
| 177 | + | ||
| 178 | + if len(sampled_results) >= max_results: | ||
| 179 | + break | ||
| 180 | + | ||
| 181 | + logger.info( | ||
| 182 | + f" 聚类完成: {len(results)} 条 -> {n_clusters} 个主题 -> {len(sampled_results)} 条代表性结果" | ||
| 183 | + ) | ||
| 184 | + return sampled_results | ||
| 185 | + | ||
| 186 | + except Exception as e: | ||
| 187 | + logger.warning(f" 聚类失败,返回前{max_results}条: {str(e)}") | ||
| 188 | + return results[:max_results] | ||
| 189 | + | ||
| 105 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: | 190 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: |
| 106 | """ | 191 | """ |
| 107 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) | 192 | 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) |
| 108 | - | 193 | + |
| 109 | Args: | 194 | Args: |
| 110 | tool_name: 工具名称,可选值: | 195 | tool_name: 工具名称,可选值: |
| 111 | - "search_hot_content": 查找热点内容 | 196 | - "search_hot_content": 查找热点内容 |
| @@ -117,18 +202,20 @@ class DeepSearchAgent: | @@ -117,18 +202,20 @@ class DeepSearchAgent: | ||
| 117 | query: 搜索关键词/话题 | 202 | query: 搜索关键词/话题 |
| 118 | **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) | 203 | **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) |
| 119 | enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) | 204 | enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) |
| 120 | - | 205 | + |
| 121 | Returns: | 206 | Returns: |
| 122 | DBResponse对象(可能包含情感分析结果) | 207 | DBResponse对象(可能包含情感分析结果) |
| 123 | """ | 208 | """ |
| 124 | logger.info(f" → 执行数据库查询工具: {tool_name}") | 209 | logger.info(f" → 执行数据库查询工具: {tool_name}") |
| 125 | - | 210 | + |
| 126 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) | 211 | # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) |
| 127 | if tool_name == "search_hot_content": | 212 | if tool_name == "search_hot_content": |
| 128 | time_period = kwargs.get("time_period", "week") | 213 | time_period = kwargs.get("time_period", "week") |
| 129 | limit = kwargs.get("limit", 100) | 214 | limit = kwargs.get("limit", 100) |
| 130 | - response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) | ||
| 131 | - | 215 | + response = self.search_agency.search_hot_content( |
| 216 | + time_period=time_period, limit=limit | ||
| 217 | + ) | ||
| 218 | + | ||
| 132 | # 检查是否需要进行情感分析 | 219 | # 检查是否需要进行情感分析 |
| 133 | enable_sentiment = kwargs.get("enable_sentiment", True) | 220 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 134 | if enable_sentiment and response.results and len(response.results) > 0: | 221 | if enable_sentiment and response.results and len(response.results) > 0: |
| @@ -138,74 +225,101 @@ class DeepSearchAgent: | @@ -138,74 +225,101 @@ class DeepSearchAgent: | ||
| 138 | # 将情感分析结果添加到响应的parameters中 | 225 | # 将情感分析结果添加到响应的parameters中 |
| 139 | response.parameters["sentiment_analysis"] = sentiment_analysis | 226 | response.parameters["sentiment_analysis"] = sentiment_analysis |
| 140 | logger.info(f" ✅ 情感分析完成") | 227 | logger.info(f" ✅ 情感分析完成") |
| 141 | - | 228 | + |
| 142 | return response | 229 | return response |
| 143 | - | 230 | + |
| 144 | # 独立情感分析工具 | 231 | # 独立情感分析工具 |
| 145 | if tool_name == "analyze_sentiment": | 232 | if tool_name == "analyze_sentiment": |
| 146 | texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query | 233 | texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query |
| 147 | sentiment_result = self.analyze_sentiment_only(texts) | 234 | sentiment_result = self.analyze_sentiment_only(texts) |
| 148 | - | 235 | + |
| 149 | # 构建DBResponse格式的响应 | 236 | # 构建DBResponse格式的响应 |
| 150 | return DBResponse( | 237 | return DBResponse( |
| 151 | tool_name="analyze_sentiment", | 238 | tool_name="analyze_sentiment", |
| 152 | parameters={ | 239 | parameters={ |
| 153 | "texts": texts if isinstance(texts, list) else [texts], | 240 | "texts": texts if isinstance(texts, list) else [texts], |
| 154 | - **kwargs | 241 | + **kwargs, |
| 155 | }, | 242 | }, |
| 156 | results=[], # 情感分析不返回搜索结果 | 243 | results=[], # 情感分析不返回搜索结果 |
| 157 | results_count=0, | 244 | results_count=0, |
| 158 | - metadata=sentiment_result | 245 | + metadata=sentiment_result, |
| 159 | ) | 246 | ) |
| 160 | - | 247 | + |
| 161 | # 对于需要搜索词的工具,使用关键词优化中间件 | 248 | # 对于需要搜索词的工具,使用关键词优化中间件 |
| 162 | optimized_response = keyword_optimizer.optimize_keywords( | 249 | optimized_response = keyword_optimizer.optimize_keywords( |
| 163 | - original_query=query, | ||
| 164 | - context=f"使用{tool_name}工具进行查询" | 250 | + original_query=query, context=f"使用{tool_name}工具进行查询" |
| 165 | ) | 251 | ) |
| 166 | - | 252 | + |
| 167 | logger.info(f" 🔍 原始查询: '{query}'") | 253 | logger.info(f" 🔍 原始查询: '{query}'") |
| 168 | logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") | 254 | logger.info(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") |
| 169 | - | 255 | + |
| 170 | # 使用优化后的关键词进行多次查询并整合结果 | 256 | # 使用优化后的关键词进行多次查询并整合结果 |
| 171 | all_results = [] | 257 | all_results = [] |
| 172 | total_count = 0 | 258 | total_count = 0 |
| 173 | - | 259 | + |
| 174 | for keyword in optimized_response.optimized_keywords: | 260 | for keyword in optimized_response.optimized_keywords: |
| 175 | logger.info(f" 查询关键词: '{keyword}'") | 261 | logger.info(f" 查询关键词: '{keyword}'") |
| 176 | - | 262 | + |
| 177 | try: | 263 | try: |
| 178 | if tool_name == "search_topic_globally": | 264 | if tool_name == "search_topic_globally": |
| 179 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 265 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 180 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 181 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | 266 | + limit_per_table = ( |
| 267 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 268 | + ) | ||
| 269 | + response = self.search_agency.search_topic_globally( | ||
| 270 | + topic=keyword, limit_per_table=limit_per_table | ||
| 271 | + ) | ||
| 182 | elif tool_name == "search_topic_by_date": | 272 | elif tool_name == "search_topic_by_date": |
| 183 | start_date = kwargs.get("start_date") | 273 | start_date = kwargs.get("start_date") |
| 184 | end_date = kwargs.get("end_date") | 274 | end_date = kwargs.get("end_date") |
| 185 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 | 275 | # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 186 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 276 | + limit_per_table = ( |
| 277 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 278 | + ) | ||
| 187 | if not start_date or not end_date: | 279 | if not start_date or not end_date: |
| 188 | - raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | ||
| 189 | - response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | 280 | + raise ValueError( |
| 281 | + "search_topic_by_date工具需要start_date和end_date参数" | ||
| 282 | + ) | ||
| 283 | + response = self.search_agency.search_topic_by_date( | ||
| 284 | + topic=keyword, | ||
| 285 | + start_date=start_date, | ||
| 286 | + end_date=end_date, | ||
| 287 | + limit_per_table=limit_per_table, | ||
| 288 | + ) | ||
| 190 | elif tool_name == "get_comments_for_topic": | 289 | elif tool_name == "get_comments_for_topic": |
| 191 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 290 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 192 | - limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len(optimized_response.optimized_keywords) | 291 | + limit = self.config.DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT // len( |
| 292 | + optimized_response.optimized_keywords | ||
| 293 | + ) | ||
| 193 | limit = max(limit, 50) | 294 | limit = max(limit, 50) |
| 194 | - response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | 295 | + response = self.search_agency.get_comments_for_topic( |
| 296 | + topic=keyword, limit=limit | ||
| 297 | + ) | ||
| 195 | elif tool_name == "search_topic_on_platform": | 298 | elif tool_name == "search_topic_on_platform": |
| 196 | platform = kwargs.get("platform") | 299 | platform = kwargs.get("platform") |
| 197 | start_date = kwargs.get("start_date") | 300 | start_date = kwargs.get("start_date") |
| 198 | end_date = kwargs.get("end_date") | 301 | end_date = kwargs.get("end_date") |
| 199 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 | 302 | # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 200 | - limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len(optimized_response.optimized_keywords) | 303 | + limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT // len( |
| 304 | + optimized_response.optimized_keywords | ||
| 305 | + ) | ||
| 201 | limit = max(limit, 30) | 306 | limit = max(limit, 30) |
| 202 | if not platform: | 307 | if not platform: |
| 203 | raise ValueError("search_topic_on_platform工具需要platform参数") | 308 | raise ValueError("search_topic_on_platform工具需要platform参数") |
| 204 | - response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | 309 | + response = self.search_agency.search_topic_on_platform( |
| 310 | + platform=platform, | ||
| 311 | + topic=keyword, | ||
| 312 | + start_date=start_date, | ||
| 313 | + end_date=end_date, | ||
| 314 | + limit=limit, | ||
| 315 | + ) | ||
| 205 | else: | 316 | else: |
| 206 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | 317 | logger.info(f" 未知的搜索工具: {tool_name},使用默认全局搜索") |
| 207 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE) | ||
| 208 | - | 318 | + response = self.search_agency.search_topic_globally( |
| 319 | + topic=keyword, | ||
| 320 | + limit_per_table=self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE, | ||
| 321 | + ) | ||
| 322 | + | ||
| 209 | # 收集结果 | 323 | # 收集结果 |
| 210 | if response.results: | 324 | if response.results: |
| 211 | logger.info(f" 找到 {len(response.results)} 条结果") | 325 | logger.info(f" 找到 {len(response.results)} 条结果") |
| @@ -213,15 +327,22 @@ class DeepSearchAgent: | @@ -213,15 +327,22 @@ class DeepSearchAgent: | ||
| 213 | total_count += len(response.results) | 327 | total_count += len(response.results) |
| 214 | else: | 328 | else: |
| 215 | logger.info(f" 未找到结果") | 329 | logger.info(f" 未找到结果") |
| 216 | - | 330 | + |
| 217 | except Exception as e: | 331 | except Exception as e: |
| 218 | logger.error(f" 查询'{keyword}'时出错: {str(e)}") | 332 | logger.error(f" 查询'{keyword}'时出错: {str(e)}") |
| 219 | continue | 333 | continue |
| 220 | - | 334 | + |
| 221 | # 去重和整合结果 | 335 | # 去重和整合结果 |
| 222 | unique_results = self._deduplicate_results(all_results) | 336 | unique_results = self._deduplicate_results(all_results) |
| 223 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") | 337 | logger.info(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") |
| 224 | - | 338 | + |
| 339 | + if ENABLE_CLUSTERING: | ||
| 340 | + unique_results = self._cluster_and_sample_results( | ||
| 341 | + unique_results, | ||
| 342 | + max_results=MAX_CLUSTERED_RESULTS, | ||
| 343 | + results_per_cluster=RESULTS_PER_CLUSTER, | ||
| 344 | + ) | ||
| 345 | + | ||
| 225 | # 构建整合后的响应 | 346 | # 构建整合后的响应 |
| 226 | integrated_response = DBResponse( | 347 | integrated_response = DBResponse( |
| 227 | tool_name=f"{tool_name}_optimized", | 348 | tool_name=f"{tool_name}_optimized", |
| @@ -229,12 +350,12 @@ class DeepSearchAgent: | @@ -229,12 +350,12 @@ class DeepSearchAgent: | ||
| 229 | "original_query": query, | 350 | "original_query": query, |
| 230 | "optimized_keywords": optimized_response.optimized_keywords, | 351 | "optimized_keywords": optimized_response.optimized_keywords, |
| 231 | "optimization_reasoning": optimized_response.reasoning, | 352 | "optimization_reasoning": optimized_response.reasoning, |
| 232 | - **kwargs | 353 | + **kwargs, |
| 233 | }, | 354 | }, |
| 234 | results=unique_results, | 355 | results=unique_results, |
| 235 | - results_count=len(unique_results) | 356 | + results_count=len(unique_results), |
| 236 | ) | 357 | ) |
| 237 | - | 358 | + |
| 238 | # 检查是否需要进行情感分析 | 359 | # 检查是否需要进行情感分析 |
| 239 | enable_sentiment = kwargs.get("enable_sentiment", True) | 360 | enable_sentiment = kwargs.get("enable_sentiment", True) |
| 240 | if enable_sentiment and unique_results and len(unique_results) > 0: | 361 | if enable_sentiment and unique_results and len(unique_results) > 0: |
| @@ -242,40 +363,45 @@ class DeepSearchAgent: | @@ -242,40 +363,45 @@ class DeepSearchAgent: | ||
| 242 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) | 363 | sentiment_analysis = self._perform_sentiment_analysis(unique_results) |
| 243 | if sentiment_analysis: | 364 | if sentiment_analysis: |
| 244 | # 将情感分析结果添加到响应的parameters中 | 365 | # 将情感分析结果添加到响应的parameters中 |
| 245 | - integrated_response.parameters["sentiment_analysis"] = sentiment_analysis | 366 | + integrated_response.parameters["sentiment_analysis"] = ( |
| 367 | + sentiment_analysis | ||
| 368 | + ) | ||
| 246 | logger.info(f" ✅ 情感分析完成") | 369 | logger.info(f" ✅ 情感分析完成") |
| 247 | - | 370 | + |
| 248 | return integrated_response | 371 | return integrated_response |
| 249 | - | 372 | + |
| 250 | def _deduplicate_results(self, results: List) -> List: | 373 | def _deduplicate_results(self, results: List) -> List: |
| 251 | """ | 374 | """ |
| 252 | 去重搜索结果 | 375 | 去重搜索结果 |
| 253 | """ | 376 | """ |
| 254 | seen = set() | 377 | seen = set() |
| 255 | unique_results = [] | 378 | unique_results = [] |
| 256 | - | 379 | + |
| 257 | for result in results: | 380 | for result in results: |
| 258 | # 使用URL或内容作为去重标识 | 381 | # 使用URL或内容作为去重标识 |
| 259 | identifier = result.url if result.url else result.title_or_content[:100] | 382 | identifier = result.url if result.url else result.title_or_content[:100] |
| 260 | if identifier not in seen: | 383 | if identifier not in seen: |
| 261 | seen.add(identifier) | 384 | seen.add(identifier) |
| 262 | unique_results.append(result) | 385 | unique_results.append(result) |
| 263 | - | 386 | + |
| 264 | return unique_results | 387 | return unique_results |
| 265 | - | 388 | + |
| 266 | def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: | 389 | def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: |
| 267 | """ | 390 | """ |
| 268 | 对搜索结果执行情感分析 | 391 | 对搜索结果执行情感分析 |
| 269 | - | 392 | + |
| 270 | Args: | 393 | Args: |
| 271 | results: 搜索结果列表 | 394 | results: 搜索结果列表 |
| 272 | - | 395 | + |
| 273 | Returns: | 396 | Returns: |
| 274 | 情感分析结果字典,如果失败则返回None | 397 | 情感分析结果字典,如果失败则返回None |
| 275 | """ | 398 | """ |
| 276 | try: | 399 | try: |
| 277 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 400 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 278 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 401 | + if ( |
| 402 | + not self.sentiment_analyzer.is_initialized | ||
| 403 | + and not self.sentiment_analyzer.is_disabled | ||
| 404 | + ): | ||
| 279 | logger.info(" 初始化情感分析模型...") | 405 | logger.info(" 初始化情感分析模型...") |
| 280 | if not self.sentiment_analyzer.initialize(): | 406 | if not self.sentiment_analyzer.initialize(): |
| 281 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 407 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| @@ -290,203 +416,222 @@ class DeepSearchAgent: | @@ -290,203 +416,222 @@ class DeepSearchAgent: | ||
| 290 | "platform": result.platform, | 416 | "platform": result.platform, |
| 291 | "author": result.author_nickname, | 417 | "author": result.author_nickname, |
| 292 | "url": result.url, | 418 | "url": result.url, |
| 293 | - "publish_time": str(result.publish_time) if result.publish_time else None | 419 | + "publish_time": str(result.publish_time) |
| 420 | + if result.publish_time | ||
| 421 | + else None, | ||
| 294 | } | 422 | } |
| 295 | results_dict.append(result_dict) | 423 | results_dict.append(result_dict) |
| 296 | - | 424 | + |
| 297 | # 执行情感分析 | 425 | # 执行情感分析 |
| 298 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( | 426 | sentiment_analysis = self.sentiment_analyzer.analyze_query_results( |
| 299 | - query_results=results_dict, | ||
| 300 | - text_field="content", | ||
| 301 | - min_confidence=0.5 | 427 | + query_results=results_dict, text_field="content", min_confidence=0.5 |
| 302 | ) | 428 | ) |
| 303 | - | 429 | + |
| 304 | return sentiment_analysis.get("sentiment_analysis") | 430 | return sentiment_analysis.get("sentiment_analysis") |
| 305 | - | 431 | + |
| 306 | except Exception as e: | 432 | except Exception as e: |
| 307 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") | 433 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 308 | return None | 434 | return None |
| 309 | - | 435 | + |
| 310 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: | 436 | def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: |
| 311 | """ | 437 | """ |
| 312 | 独立的情感分析工具 | 438 | 独立的情感分析工具 |
| 313 | - | 439 | + |
| 314 | Args: | 440 | Args: |
| 315 | texts: 单个文本或文本列表 | 441 | texts: 单个文本或文本列表 |
| 316 | - | 442 | + |
| 317 | Returns: | 443 | Returns: |
| 318 | 情感分析结果 | 444 | 情感分析结果 |
| 319 | """ | 445 | """ |
| 320 | logger.info(f" → 执行独立情感分析") | 446 | logger.info(f" → 执行独立情感分析") |
| 321 | - | 447 | + |
| 322 | try: | 448 | try: |
| 323 | # 初始化情感分析器(如果尚未初始化且未被禁用) | 449 | # 初始化情感分析器(如果尚未初始化且未被禁用) |
| 324 | - if not self.sentiment_analyzer.is_initialized and not self.sentiment_analyzer.is_disabled: | 450 | + if ( |
| 451 | + not self.sentiment_analyzer.is_initialized | ||
| 452 | + and not self.sentiment_analyzer.is_disabled | ||
| 453 | + ): | ||
| 325 | logger.info(" 初始化情感分析模型...") | 454 | logger.info(" 初始化情感分析模型...") |
| 326 | if not self.sentiment_analyzer.initialize(): | 455 | if not self.sentiment_analyzer.initialize(): |
| 327 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") | 456 | logger.info(" 情感分析模型初始化失败,将直接透传原始文本") |
| 328 | elif self.sentiment_analyzer.is_disabled: | 457 | elif self.sentiment_analyzer.is_disabled: |
| 329 | logger.warning(" 情感分析功能已禁用,直接透传原始文本") | 458 | logger.warning(" 情感分析功能已禁用,直接透传原始文本") |
| 330 | - | 459 | + |
| 331 | # 执行分析 | 460 | # 执行分析 |
| 332 | if isinstance(texts, str): | 461 | if isinstance(texts, str): |
| 333 | result = self.sentiment_analyzer.analyze_single_text(texts) | 462 | result = self.sentiment_analyzer.analyze_single_text(texts) |
| 334 | result_dict = result.__dict__ | 463 | result_dict = result.__dict__ |
| 335 | response = { | 464 | response = { |
| 336 | "success": result.success and result.analysis_performed, | 465 | "success": result.success and result.analysis_performed, |
| 337 | - "total_analyzed": 1 if result.analysis_performed and result.success else 0, | ||
| 338 | - "results": [result_dict] | 466 | + "total_analyzed": 1 |
| 467 | + if result.analysis_performed and result.success | ||
| 468 | + else 0, | ||
| 469 | + "results": [result_dict], | ||
| 339 | } | 470 | } |
| 340 | if not result.analysis_performed: | 471 | if not result.analysis_performed: |
| 341 | response["success"] = False | 472 | response["success"] = False |
| 342 | - response["warning"] = result.error_message or "情感分析功能不可用,已直接返回原始文本" | 473 | + response["warning"] = ( |
| 474 | + result.error_message or "情感分析功能不可用,已直接返回原始文本" | ||
| 475 | + ) | ||
| 343 | return response | 476 | return response |
| 344 | else: | 477 | else: |
| 345 | texts_list = list(texts) | 478 | texts_list = list(texts) |
| 346 | - batch_result = self.sentiment_analyzer.analyze_batch(texts_list, show_progress=True) | 479 | + batch_result = self.sentiment_analyzer.analyze_batch( |
| 480 | + texts_list, show_progress=True | ||
| 481 | + ) | ||
| 347 | response = { | 482 | response = { |
| 348 | - "success": batch_result.analysis_performed and batch_result.success_count > 0, | ||
| 349 | - "total_analyzed": batch_result.total_processed if batch_result.analysis_performed else 0, | 483 | + "success": batch_result.analysis_performed |
| 484 | + and batch_result.success_count > 0, | ||
| 485 | + "total_analyzed": batch_result.total_processed | ||
| 486 | + if batch_result.analysis_performed | ||
| 487 | + else 0, | ||
| 350 | "success_count": batch_result.success_count, | 488 | "success_count": batch_result.success_count, |
| 351 | "failed_count": batch_result.failed_count, | 489 | "failed_count": batch_result.failed_count, |
| 352 | - "average_confidence": batch_result.average_confidence if batch_result.analysis_performed else 0.0, | ||
| 353 | - "results": [result.__dict__ for result in batch_result.results] | 490 | + "average_confidence": batch_result.average_confidence |
| 491 | + if batch_result.analysis_performed | ||
| 492 | + else 0.0, | ||
| 493 | + "results": [result.__dict__ for result in batch_result.results], | ||
| 354 | } | 494 | } |
| 355 | if not batch_result.analysis_performed: | 495 | if not batch_result.analysis_performed: |
| 356 | warning = next( | 496 | warning = next( |
| 357 | - (r.error_message for r in batch_result.results if r.error_message), | ||
| 358 | - "情感分析功能不可用,已直接返回原始文本" | 497 | + ( |
| 498 | + r.error_message | ||
| 499 | + for r in batch_result.results | ||
| 500 | + if r.error_message | ||
| 501 | + ), | ||
| 502 | + "情感分析功能不可用,已直接返回原始文本", | ||
| 359 | ) | 503 | ) |
| 360 | response["success"] = False | 504 | response["success"] = False |
| 361 | response["warning"] = warning | 505 | response["warning"] = warning |
| 362 | return response | 506 | return response |
| 363 | - | 507 | + |
| 364 | except Exception as e: | 508 | except Exception as e: |
| 365 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") | 509 | logger.exception(f" ❌ 情感分析过程中发生错误: {str(e)}") |
| 366 | - return { | ||
| 367 | - "success": False, | ||
| 368 | - "error": str(e), | ||
| 369 | - "results": [] | ||
| 370 | - } | ||
| 371 | - | 510 | + return {"success": False, "error": str(e), "results": []} |
| 511 | + | ||
def research(self, query: str, save_report: bool = True) -> str:
    """
    Run the full deep-research pipeline for a query.

    The steps run in a fixed order: generate the report structure, process
    every paragraph, generate the final report and, when requested, save it.

    Args:
        query: The research query.
        save_report: Whether to pass the final report to ``_save_report``.

    Returns:
        The final report content.

    Raises:
        Exception: Any error raised by one of the steps is logged and then
            re-raised unchanged.
    """
    banner = "=" * 60
    logger.info(f"\n{banner}")
    logger.info(f"开始深度研究: {query}")
    logger.info(banner)

    try:
        # Step 1: generate the report structure
        self._generate_report_structure(query)

        # Step 2: process every paragraph
        self._process_paragraphs()

        # Step 3: generate the final report
        final_report = self._generate_final_report()

        # Step 4: save the report
        if save_report:
            self._save_report(final_report)

        logger.info("深度研究完成!")

        return final_report

    except Exception as e:
        logger.exception(f"研究过程中发生错误: {str(e)}")
        # A bare ``raise`` re-raises the active exception as-is instead of
        # raising the bound name again from this line.
        raise
| 408 | - | 548 | + |
def _generate_report_structure(self, query: str):
    """
    Generate the report structure for a query and store it on the state.

    Args:
        query: The research query handed to ``ReportStructureNode``.
    """
    logger.info("\n[步骤 1] 生成报告结构...")

    # Build the structure node and let it produce the updated state.
    structure_node = ReportStructureNode(self.llm_client, query)
    self.state = structure_node.mutate_state(state=self.state)

    # Log a header line followed by one numbered line per paragraph title.
    paragraphs = self.state.paragraphs
    lines = [f"报告结构已生成,共 {len(paragraphs)} 个段落:"]
    lines.extend(
        f" {number}. {entry.title}" for number, entry in enumerate(paragraphs, 1)
    )
    logger.info("\n".join(lines))
| 423 | - | 563 | + |
def _process_paragraphs(self):
    """
    Process every paragraph of the current state in order.

    For each paragraph this runs the initial search and summary, then the
    reflection loop, marks the paragraph's research as completed and logs
    the overall progress as a percentage.
    """
    paragraph_count = len(self.state.paragraphs)

    for index in range(paragraph_count):
        step = index + 1
        heading = self.state.paragraphs[index].title
        logger.info(f"\n[步骤 2.{step}] 处理段落: {heading}")
        logger.info("-" * 50)

        # Initial search and summary, then the reflection loop.
        self._initial_search_and_summary(index)
        self._reflection_loop(index)

        # Look the paragraph up again by index: the steps above assign a
        # new value to self.state, so an earlier reference may be stale.
        self.state.paragraphs[index].research.mark_completed()

        percent_done = step / paragraph_count * 100
        logger.info(f"段落处理完成 ({percent_done:.1f}%)")
| 443 | - | 585 | + |
| 444 | def _initial_search_and_summary(self, paragraph_index: int): | 586 | def _initial_search_and_summary(self, paragraph_index: int): |
| 445 | """执行初始搜索和总结""" | 587 | """执行初始搜索和总结""" |
| 446 | paragraph = self.state.paragraphs[paragraph_index] | 588 | paragraph = self.state.paragraphs[paragraph_index] |
| 447 | - | 589 | + |
| 448 | # 准备搜索输入 | 590 | # 准备搜索输入 |
| 449 | - search_input = { | ||
| 450 | - "title": paragraph.title, | ||
| 451 | - "content": paragraph.content | ||
| 452 | - } | ||
| 453 | - | 591 | + search_input = {"title": paragraph.title, "content": paragraph.content} |
| 592 | + | ||
| 454 | # 生成搜索查询和工具选择 | 593 | # 生成搜索查询和工具选择 |
| 455 | logger.info(" - 生成搜索查询...") | 594 | logger.info(" - 生成搜索查询...") |
| 456 | search_output = self.first_search_node.run(search_input) | 595 | search_output = self.first_search_node.run(search_input) |
| 457 | search_query = search_output["search_query"] | 596 | search_query = search_output["search_query"] |
| 458 | - search_tool = search_output.get("search_tool", "search_topic_globally") # 默认工具 | 597 | + search_tool = search_output.get( |
| 598 | + "search_tool", "search_topic_globally" | ||
| 599 | + ) # 默认工具 | ||
| 459 | reasoning = search_output["reasoning"] | 600 | reasoning = search_output["reasoning"] |
| 460 | - | 601 | + |
| 461 | logger.info(f" - 搜索查询: {search_query}") | 602 | logger.info(f" - 搜索查询: {search_query}") |
| 462 | logger.info(f" - 选择的工具: {search_tool}") | 603 | logger.info(f" - 选择的工具: {search_tool}") |
| 463 | logger.info(f" - 推理: {reasoning}") | 604 | logger.info(f" - 推理: {reasoning}") |
| 464 | - | 605 | + |
| 465 | # 执行搜索 | 606 | # 执行搜索 |
| 466 | logger.info(" - 执行数据库查询...") | 607 | logger.info(" - 执行数据库查询...") |
| 467 | - | 608 | + |
| 468 | # 处理特殊参数 | 609 | # 处理特殊参数 |
| 469 | search_kwargs = {} | 610 | search_kwargs = {} |
| 470 | - | 611 | + |
| 471 | # 处理需要日期的工具 | 612 | # 处理需要日期的工具 |
| 472 | if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: | 613 | if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: |
| 473 | start_date = search_output.get("start_date") | 614 | start_date = search_output.get("start_date") |
| 474 | end_date = search_output.get("end_date") | 615 | end_date = search_output.get("end_date") |
| 475 | - | 616 | + |
| 476 | if start_date and end_date: | 617 | if start_date and end_date: |
| 477 | # 验证日期格式 | 618 | # 验证日期格式 |
| 478 | - if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 619 | + if self._validate_date_format( |
| 620 | + start_date | ||
| 621 | + ) and self._validate_date_format(end_date): | ||
| 479 | search_kwargs["start_date"] = start_date | 622 | search_kwargs["start_date"] = start_date |
| 480 | search_kwargs["end_date"] = end_date | 623 | search_kwargs["end_date"] = end_date |
| 481 | logger.info(f" - 时间范围: {start_date} 到 {end_date}") | 624 | logger.info(f" - 时间范围: {start_date} 到 {end_date}") |
| 482 | else: | 625 | else: |
| 483 | logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | 626 | logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") |
| 484 | - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 627 | + logger.info( |
| 628 | + f" 提供的日期: start_date={start_date}, end_date={end_date}" | ||
| 629 | + ) | ||
| 485 | search_tool = "search_topic_globally" | 630 | search_tool = "search_topic_globally" |
| 486 | elif search_tool == "search_topic_by_date": | 631 | elif search_tool == "search_topic_by_date": |
| 487 | logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 632 | logger.info(f" search_topic_by_date工具缺少时间参数,改用全局搜索") |
| 488 | search_tool = "search_topic_globally" | 633 | search_tool = "search_topic_globally" |
| 489 | - | 634 | + |
| 490 | # 处理需要平台参数的工具 | 635 | # 处理需要平台参数的工具 |
| 491 | if search_tool == "search_topic_on_platform": | 636 | if search_tool == "search_topic_on_platform": |
| 492 | platform = search_output.get("platform") | 637 | platform = search_output.get("platform") |
| @@ -494,9 +639,11 @@ class DeepSearchAgent: | @@ -494,9 +639,11 @@ class DeepSearchAgent: | ||
| 494 | search_kwargs["platform"] = platform | 639 | search_kwargs["platform"] = platform |
| 495 | logger.info(f" - 指定平台: {platform}") | 640 | logger.info(f" - 指定平台: {platform}") |
| 496 | else: | 641 | else: |
| 497 | - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 642 | + logger.warning( |
| 643 | + f" search_topic_on_platform工具缺少平台参数,改用全局搜索" | ||
| 644 | + ) | ||
| 498 | search_tool = "search_topic_globally" | 645 | search_tool = "search_topic_globally" |
| 499 | - | 646 | + |
| 500 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 | 647 | # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 |
| 501 | if search_tool == "search_hot_content": | 648 | if search_tool == "search_hot_content": |
| 502 | time_period = search_output.get("time_period", "week") | 649 | time_period = search_output.get("time_period", "week") |
| @@ -505,9 +652,13 @@ class DeepSearchAgent: | @@ -505,9 +652,13 @@ class DeepSearchAgent: | ||
| 505 | search_kwargs["limit"] = limit | 652 | search_kwargs["limit"] = limit |
| 506 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 653 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 507 | if search_tool == "search_topic_globally": | 654 | if search_tool == "search_topic_globally": |
| 508 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | 655 | + limit_per_table = ( |
| 656 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 657 | + ) | ||
| 509 | else: # search_topic_by_date | 658 | else: # search_topic_by_date |
| 510 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 659 | + limit_per_table = ( |
| 660 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 661 | + ) | ||
| 511 | search_kwargs["limit_per_table"] = limit_per_table | 662 | search_kwargs["limit_per_table"] = limit_per_table |
| 512 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 663 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 513 | if search_tool == "get_comments_for_topic": | 664 | if search_tool == "get_comments_for_topic": |
| @@ -515,43 +666,55 @@ class DeepSearchAgent: | @@ -515,43 +666,55 @@ class DeepSearchAgent: | ||
| 515 | else: # search_topic_on_platform | 666 | else: # search_topic_on_platform |
| 516 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT | 667 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 517 | search_kwargs["limit"] = limit | 668 | search_kwargs["limit"] = limit |
| 518 | - | ||
| 519 | - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | ||
| 520 | - | 669 | + |
| 670 | + search_response = self.execute_search_tool( | ||
| 671 | + search_tool, search_query, **search_kwargs | ||
| 672 | + ) | ||
| 673 | + | ||
| 521 | # 转换为兼容格式 | 674 | # 转换为兼容格式 |
| 522 | search_results = [] | 675 | search_results = [] |
| 523 | if search_response and search_response.results: | 676 | if search_response and search_response.results: |
| 524 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 677 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 525 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: | 678 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 526 | - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | 679 | + max_results = min( |
| 680 | + len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM | ||
| 681 | + ) | ||
| 527 | else: | 682 | else: |
| 528 | max_results = len(search_response.results) # 不限制,传递所有结果 | 683 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 529 | for result in search_response.results[:max_results]: | 684 | for result in search_response.results[:max_results]: |
| 530 | - search_results.append({ | ||
| 531 | - 'title': result.title_or_content, | ||
| 532 | - 'url': result.url or "", | ||
| 533 | - 'content': result.title_or_content, | ||
| 534 | - 'score': result.hotness_score, | ||
| 535 | - 'raw_content': result.title_or_content, | ||
| 536 | - 'published_date': result.publish_time.isoformat() if result.publish_time else None, | ||
| 537 | - 'platform': result.platform, | ||
| 538 | - 'content_type': result.content_type, | ||
| 539 | - 'author': result.author_nickname, | ||
| 540 | - 'engagement': result.engagement | ||
| 541 | - }) | ||
| 542 | - | 685 | + search_results.append( |
| 686 | + { | ||
| 687 | + "title": result.title_or_content, | ||
| 688 | + "url": result.url or "", | ||
| 689 | + "content": result.title_or_content, | ||
| 690 | + "score": result.hotness_score, | ||
| 691 | + "raw_content": result.title_or_content, | ||
| 692 | + "published_date": result.publish_time.isoformat() | ||
| 693 | + if result.publish_time | ||
| 694 | + else None, | ||
| 695 | + "platform": result.platform, | ||
| 696 | + "content_type": result.content_type, | ||
| 697 | + "author": result.author_nickname, | ||
| 698 | + "engagement": result.engagement, | ||
| 699 | + } | ||
| 700 | + ) | ||
| 701 | + | ||
| 543 | if search_results: | 702 | if search_results: |
| 544 | _message = f" - 找到 {len(search_results)} 个搜索结果" | 703 | _message = f" - 找到 {len(search_results)} 个搜索结果" |
| 545 | for j, result in enumerate(search_results, 1): | 704 | for j, result in enumerate(search_results, 1): |
| 546 | - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 705 | + date_info = ( |
| 706 | + f" (发布于: {result.get('published_date', 'N/A')})" | ||
| 707 | + if result.get("published_date") | ||
| 708 | + else "" | ||
| 709 | + ) | ||
| 547 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" | 710 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 548 | logger.info(_message) | 711 | logger.info(_message) |
| 549 | else: | 712 | else: |
| 550 | logger.info(" - 未找到搜索结果") | 713 | logger.info(" - 未找到搜索结果") |
| 551 | - | 714 | + |
| 552 | # 更新状态中的搜索历史 | 715 | # 更新状态中的搜索历史 |
| 553 | paragraph.research.add_search_results(search_query, search_results) | 716 | paragraph.research.add_search_results(search_query, search_results) |
| 554 | - | 717 | + |
| 555 | # 生成初始总结 | 718 | # 生成初始总结 |
| 556 | logger.info(" - 生成初始总结...") | 719 | logger.info(" - 生成初始总结...") |
| 557 | summary_input = { | 720 | summary_input = { |
| @@ -560,63 +723,73 @@ class DeepSearchAgent: | @@ -560,63 +723,73 @@ class DeepSearchAgent: | ||
| 560 | "search_query": search_query, | 723 | "search_query": search_query, |
| 561 | "search_results": format_search_results_for_prompt( | 724 | "search_results": format_search_results_for_prompt( |
| 562 | search_results, self.config.MAX_CONTENT_LENGTH | 725 | search_results, self.config.MAX_CONTENT_LENGTH |
| 563 | - ) | 726 | + ), |
| 564 | } | 727 | } |
| 565 | - | 728 | + |
| 566 | # 更新状态 | 729 | # 更新状态 |
| 567 | self.state = self.first_summary_node.mutate_state( | 730 | self.state = self.first_summary_node.mutate_state( |
| 568 | summary_input, self.state, paragraph_index | 731 | summary_input, self.state, paragraph_index |
| 569 | ) | 732 | ) |
| 570 | - | 733 | + |
| 571 | logger.info(" - 初始总结完成") | 734 | logger.info(" - 初始总结完成") |
| 572 | - | 735 | + |
| 573 | def _reflection_loop(self, paragraph_index: int): | 736 | def _reflection_loop(self, paragraph_index: int): |
| 574 | """执行反思循环""" | 737 | """执行反思循环""" |
| 575 | paragraph = self.state.paragraphs[paragraph_index] | 738 | paragraph = self.state.paragraphs[paragraph_index] |
| 576 | - | 739 | + |
| 577 | for reflection_i in range(self.config.MAX_REFLECTIONS): | 740 | for reflection_i in range(self.config.MAX_REFLECTIONS): |
| 578 | logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") | 741 | logger.info(f" - 反思 {reflection_i + 1}/{self.config.MAX_REFLECTIONS}...") |
| 579 | - | 742 | + |
| 580 | # 准备反思输入 | 743 | # 准备反思输入 |
| 581 | reflection_input = { | 744 | reflection_input = { |
| 582 | "title": paragraph.title, | 745 | "title": paragraph.title, |
| 583 | "content": paragraph.content, | 746 | "content": paragraph.content, |
| 584 | - "paragraph_latest_state": paragraph.research.latest_summary | 747 | + "paragraph_latest_state": paragraph.research.latest_summary, |
| 585 | } | 748 | } |
| 586 | - | 749 | + |
| 587 | # 生成反思搜索查询 | 750 | # 生成反思搜索查询 |
| 588 | reflection_output = self.reflection_node.run(reflection_input) | 751 | reflection_output = self.reflection_node.run(reflection_input) |
| 589 | search_query = reflection_output["search_query"] | 752 | search_query = reflection_output["search_query"] |
| 590 | - search_tool = reflection_output.get("search_tool", "search_topic_globally") # 默认工具 | 753 | + search_tool = reflection_output.get( |
| 754 | + "search_tool", "search_topic_globally" | ||
| 755 | + ) # 默认工具 | ||
| 591 | reasoning = reflection_output["reasoning"] | 756 | reasoning = reflection_output["reasoning"] |
| 592 | - | 757 | + |
| 593 | logger.info(f" 反思查询: {search_query}") | 758 | logger.info(f" 反思查询: {search_query}") |
| 594 | logger.info(f" 选择的工具: {search_tool}") | 759 | logger.info(f" 选择的工具: {search_tool}") |
| 595 | logger.info(f" 反思推理: {reasoning}") | 760 | logger.info(f" 反思推理: {reasoning}") |
| 596 | - | 761 | + |
| 597 | # 执行反思搜索 | 762 | # 执行反思搜索 |
| 598 | # 处理特殊参数 | 763 | # 处理特殊参数 |
| 599 | search_kwargs = {} | 764 | search_kwargs = {} |
| 600 | - | 765 | + |
| 601 | # 处理需要日期的工具 | 766 | # 处理需要日期的工具 |
| 602 | if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: | 767 | if search_tool in ["search_topic_by_date", "search_topic_on_platform"]: |
| 603 | start_date = reflection_output.get("start_date") | 768 | start_date = reflection_output.get("start_date") |
| 604 | end_date = reflection_output.get("end_date") | 769 | end_date = reflection_output.get("end_date") |
| 605 | - | 770 | + |
| 606 | if start_date and end_date: | 771 | if start_date and end_date: |
| 607 | # 验证日期格式 | 772 | # 验证日期格式 |
| 608 | - if self._validate_date_format(start_date) and self._validate_date_format(end_date): | 773 | + if self._validate_date_format( |
| 774 | + start_date | ||
| 775 | + ) and self._validate_date_format(end_date): | ||
| 609 | search_kwargs["start_date"] = start_date | 776 | search_kwargs["start_date"] = start_date |
| 610 | search_kwargs["end_date"] = end_date | 777 | search_kwargs["end_date"] = end_date |
| 611 | logger.info(f" 时间范围: {start_date} 到 {end_date}") | 778 | logger.info(f" 时间范围: {start_date} 到 {end_date}") |
| 612 | else: | 779 | else: |
| 613 | - logger.info(f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索") | ||
| 614 | - logger.info(f" 提供的日期: start_date={start_date}, end_date={end_date}") | 780 | + logger.info( |
| 781 | + f" 日期格式错误(应为YYYY-MM-DD),改用全局搜索" | ||
| 782 | + ) | ||
| 783 | + logger.info( | ||
| 784 | + f" 提供的日期: start_date={start_date}, end_date={end_date}" | ||
| 785 | + ) | ||
| 615 | search_tool = "search_topic_globally" | 786 | search_tool = "search_topic_globally" |
| 616 | elif search_tool == "search_topic_by_date": | 787 | elif search_tool == "search_topic_by_date": |
| 617 | - logger.warning(f" search_topic_by_date工具缺少时间参数,改用全局搜索") | 788 | + logger.warning( |
| 789 | + f" search_topic_by_date工具缺少时间参数,改用全局搜索" | ||
| 790 | + ) | ||
| 618 | search_tool = "search_topic_globally" | 791 | search_tool = "search_topic_globally" |
| 619 | - | 792 | + |
| 620 | # 处理需要平台参数的工具 | 793 | # 处理需要平台参数的工具 |
| 621 | if search_tool == "search_topic_on_platform": | 794 | if search_tool == "search_topic_on_platform": |
| 622 | platform = reflection_output.get("platform") | 795 | platform = reflection_output.get("platform") |
| @@ -624,9 +797,11 @@ class DeepSearchAgent: | @@ -624,9 +797,11 @@ class DeepSearchAgent: | ||
| 624 | search_kwargs["platform"] = platform | 797 | search_kwargs["platform"] = platform |
| 625 | logger.info(f" 指定平台: {platform}") | 798 | logger.info(f" 指定平台: {platform}") |
| 626 | else: | 799 | else: |
| 627 | - logger.warning(f" search_topic_on_platform工具缺少平台参数,改用全局搜索") | 800 | + logger.warning( |
| 801 | + f" search_topic_on_platform工具缺少平台参数,改用全局搜索" | ||
| 802 | + ) | ||
| 628 | search_tool = "search_topic_globally" | 803 | search_tool = "search_topic_globally" |
| 629 | - | 804 | + |
| 630 | # 处理限制参数 | 805 | # 处理限制参数 |
| 631 | if search_tool == "search_hot_content": | 806 | if search_tool == "search_hot_content": |
| 632 | time_period = reflection_output.get("time_period", "week") | 807 | time_period = reflection_output.get("time_period", "week") |
| @@ -637,9 +812,13 @@ class DeepSearchAgent: | @@ -637,9 +812,13 @@ class DeepSearchAgent: | ||
| 637 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 812 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 638 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 | 813 | # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 |
| 639 | if search_tool == "search_topic_globally": | 814 | if search_tool == "search_topic_globally": |
| 640 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | 815 | + limit_per_table = ( |
| 816 | + self.config.DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE | ||
| 817 | + ) | ||
| 641 | else: # search_topic_by_date | 818 | else: # search_topic_by_date |
| 642 | - limit_per_table = self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | 819 | + limit_per_table = ( |
| 820 | + self.config.DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE | ||
| 821 | + ) | ||
| 643 | search_kwargs["limit_per_table"] = limit_per_table | 822 | search_kwargs["limit_per_table"] = limit_per_table |
| 644 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 823 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 645 | # 使用配置文件中的默认值,不允许agent控制limit参数 | 824 | # 使用配置文件中的默认值,不允许agent控制limit参数 |
| @@ -648,43 +827,56 @@ class DeepSearchAgent: | @@ -648,43 +827,56 @@ class DeepSearchAgent: | ||
| 648 | else: # search_topic_on_platform | 827 | else: # search_topic_on_platform |
| 649 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT | 828 | limit = self.config.DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT |
| 650 | search_kwargs["limit"] = limit | 829 | search_kwargs["limit"] = limit |
| 651 | - | ||
| 652 | - search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | ||
| 653 | - | 830 | + |
| 831 | + search_response = self.execute_search_tool( | ||
| 832 | + search_tool, search_query, **search_kwargs | ||
| 833 | + ) | ||
| 834 | + | ||
| 654 | # 转换为兼容格式 | 835 | # 转换为兼容格式 |
| 655 | search_results = [] | 836 | search_results = [] |
| 656 | if search_response and search_response.results: | 837 | if search_response and search_response.results: |
| 657 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 | 838 | # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 658 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: | 839 | if self.config.MAX_SEARCH_RESULTS_FOR_LLM > 0: |
| 659 | - max_results = min(len(search_response.results), self.config.MAX_SEARCH_RESULTS_FOR_LLM) | 840 | + max_results = min( |
| 841 | + len(search_response.results), | ||
| 842 | + self.config.MAX_SEARCH_RESULTS_FOR_LLM, | ||
| 843 | + ) | ||
| 660 | else: | 844 | else: |
| 661 | max_results = len(search_response.results) # 不限制,传递所有结果 | 845 | max_results = len(search_response.results) # 不限制,传递所有结果 |
| 662 | for result in search_response.results[:max_results]: | 846 | for result in search_response.results[:max_results]: |
| 663 | - search_results.append({ | ||
| 664 | - 'title': result.title_or_content, | ||
| 665 | - 'url': result.url or "", | ||
| 666 | - 'content': result.title_or_content, | ||
| 667 | - 'score': result.hotness_score, | ||
| 668 | - 'raw_content': result.title_or_content, | ||
| 669 | - 'published_date': result.publish_time.isoformat() if result.publish_time else None, | ||
| 670 | - 'platform': result.platform, | ||
| 671 | - 'content_type': result.content_type, | ||
| 672 | - 'author': result.author_nickname, | ||
| 673 | - 'engagement': result.engagement | ||
| 674 | - }) | ||
| 675 | - | 847 | + search_results.append( |
| 848 | + { | ||
| 849 | + "title": result.title_or_content, | ||
| 850 | + "url": result.url or "", | ||
| 851 | + "content": result.title_or_content, | ||
| 852 | + "score": result.hotness_score, | ||
| 853 | + "raw_content": result.title_or_content, | ||
| 854 | + "published_date": result.publish_time.isoformat() | ||
| 855 | + if result.publish_time | ||
| 856 | + else None, | ||
| 857 | + "platform": result.platform, | ||
| 858 | + "content_type": result.content_type, | ||
| 859 | + "author": result.author_nickname, | ||
| 860 | + "engagement": result.engagement, | ||
| 861 | + } | ||
| 862 | + ) | ||
| 863 | + | ||
| 676 | if search_results: | 864 | if search_results: |
| 677 | _message = f" 找到 {len(search_results)} 个反思搜索结果" | 865 | _message = f" 找到 {len(search_results)} 个反思搜索结果" |
| 678 | for j, result in enumerate(search_results, 1): | 866 | for j, result in enumerate(search_results, 1): |
| 679 | - date_info = f" (发布于: {result.get('published_date', 'N/A')})" if result.get('published_date') else "" | 867 | + date_info = ( |
| 868 | + f" (发布于: {result.get('published_date', 'N/A')})" | ||
| 869 | + if result.get("published_date") | ||
| 870 | + else "" | ||
| 871 | + ) | ||
| 680 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" | 872 | _message += f"\n {j}. {result['title'][:50]}...{date_info}" |
| 681 | logger.info(_message) | 873 | logger.info(_message) |
| 682 | else: | 874 | else: |
| 683 | logger.info(" 未找到反思搜索结果") | 875 | logger.info(" 未找到反思搜索结果") |
| 684 | - | 876 | + |
| 685 | # 更新搜索历史 | 877 | # 更新搜索历史 |
| 686 | paragraph.research.add_search_results(search_query, search_results) | 878 | paragraph.research.add_search_results(search_query, search_results) |
| 687 | - | 879 | + |
| 688 | # 生成反思总结 | 880 | # 生成反思总结 |
| 689 | reflection_summary_input = { | 881 | reflection_summary_input = { |
| 690 | "title": paragraph.title, | 882 | "title": paragraph.title, |
| @@ -693,28 +885,30 @@ class DeepSearchAgent: | @@ -693,28 +885,30 @@ class DeepSearchAgent: | ||
| 693 | "search_results": format_search_results_for_prompt( | 885 | "search_results": format_search_results_for_prompt( |
| 694 | search_results, self.config.MAX_CONTENT_LENGTH | 886 | search_results, self.config.MAX_CONTENT_LENGTH |
| 695 | ), | 887 | ), |
| 696 | - "paragraph_latest_state": paragraph.research.latest_summary | 888 | + "paragraph_latest_state": paragraph.research.latest_summary, |
| 697 | } | 889 | } |
| 698 | - | 890 | + |
| 699 | # 更新状态 | 891 | # 更新状态 |
| 700 | self.state = self.reflection_summary_node.mutate_state( | 892 | self.state = self.reflection_summary_node.mutate_state( |
| 701 | reflection_summary_input, self.state, paragraph_index | 893 | reflection_summary_input, self.state, paragraph_index |
| 702 | ) | 894 | ) |
| 703 | - | 895 | + |
| 704 | logger.info(f" 反思 {reflection_i + 1} 完成") | 896 | logger.info(f" 反思 {reflection_i + 1} 完成") |
| 705 | - | 897 | + |
| 706 | def _generate_final_report(self) -> str: | 898 | def _generate_final_report(self) -> str: |
| 707 | """生成最终报告""" | 899 | """生成最终报告""" |
| 708 | logger.info(f"\n[步骤 3] 生成最终报告...") | 900 | logger.info(f"\n[步骤 3] 生成最终报告...") |
| 709 | - | 901 | + |
| 710 | # 准备报告数据 | 902 | # 准备报告数据 |
| 711 | report_data = [] | 903 | report_data = [] |
| 712 | for paragraph in self.state.paragraphs: | 904 | for paragraph in self.state.paragraphs: |
| 713 | - report_data.append({ | ||
| 714 | - "title": paragraph.title, | ||
| 715 | - "paragraph_latest_state": paragraph.research.latest_summary | ||
| 716 | - }) | ||
| 717 | - | 905 | + report_data.append( |
| 906 | + { | ||
| 907 | + "title": paragraph.title, | ||
| 908 | + "paragraph_latest_state": paragraph.research.latest_summary, | ||
| 909 | + } | ||
| 910 | + ) | ||
| 911 | + | ||
| 718 | # 格式化报告 | 912 | # 格式化报告 |
| 719 | try: | 913 | try: |
| 720 | final_report = self.report_formatting_node.run(report_data) | 914 | final_report = self.report_formatting_node.run(report_data) |
| @@ -723,46 +917,48 @@ class DeepSearchAgent: | @@ -723,46 +917,48 @@ class DeepSearchAgent: | ||
| 723 | final_report = self.report_formatting_node.format_report_manually( | 917 | final_report = self.report_formatting_node.format_report_manually( |
| 724 | report_data, self.state.report_title | 918 | report_data, self.state.report_title |
| 725 | ) | 919 | ) |
| 726 | - | 920 | + |
| 727 | # 更新状态 | 921 | # 更新状态 |
| 728 | self.state.final_report = final_report | 922 | self.state.final_report = final_report |
| 729 | self.state.mark_completed() | 923 | self.state.mark_completed() |
| 730 | - | 924 | + |
| 731 | logger.info("最终报告生成完成") | 925 | logger.info("最终报告生成完成") |
| 732 | return final_report | 926 | return final_report |
| 733 | - | 927 | + |
def _save_report(self, report_content: str):
    """
    Write the report to a timestamped Markdown file in the output directory.

    The file name is built from a sanitised form of ``self.state.query``
    (alphanumerics, spaces, ``-`` and ``_`` only; spaces become ``_``; at
    most 30 characters) and the current local time. When
    ``self.config.SAVE_INTERMEDIATE_STATES`` is truthy, the state is saved
    next to the report through ``self.state.save_to_file``.

    Args:
        report_content: The report text to write.
    """
    # Build the file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    query_safe = "".join(
        c for c in self.state.query if c.isalnum() or c in (" ", "-", "_")
    ).rstrip()
    query_safe = query_safe.replace(" ", "_")[:30]

    # Create the output directory when it is missing so open() below cannot
    # fail on a non-existent path. An empty OUTPUT_DIR means the current
    # directory, which needs no creation.
    output_dir = self.config.OUTPUT_DIR
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    filename = f"deep_search_report_{query_safe}_{timestamp}.md"
    filepath = os.path.join(output_dir, filename)

    # Save the report.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(report_content)

    logger.info(f"报告已保存到: {filepath}")

    # Save the state as well, if the configuration allows it.
    if self.config.SAVE_INTERMEDIATE_STATES:
        state_filename = f"state_{query_safe}_{timestamp}.json"
        state_filepath = os.path.join(output_dir, state_filename)
        self.state.save_to_file(state_filepath)
        logger.info(f"状态已保存到: {state_filepath}")
| 756 | - | 952 | + |
def get_progress_summary(self) -> Dict[str, Any]:
    """Return the progress summary reported by the current state."""
    summary = self.state.get_progress_summary()
    return summary
| 760 | - | 956 | + |
def load_state(self, filepath: str):
    """
    Replace the current state with the one loaded from a file.

    Args:
        filepath: Path passed to ``State.load_from_file``.
    """
    loaded_state = State.load_from_file(filepath)
    self.state = loaded_state
    logger.info(f"状态已从 {filepath} 加载")
| 765 | - | 961 | + |
| 766 | def save_state(self, filepath: str): | 962 | def save_state(self, filepath: str): |
| 767 | """保存状态到文件""" | 963 | """保存状态到文件""" |
| 768 | self.state.save_to_file(filepath) | 964 | self.state.save_to_file(filepath) |
| @@ -772,12 +968,12 @@ class DeepSearchAgent: | @@ -772,12 +968,12 @@ class DeepSearchAgent: | ||
def create_agent(config_file: Optional[str] = None) -> DeepSearchAgent:
    """
    Convenience factory that creates a Deep Search Agent instance.

    Args:
        config_file: Path to a configuration file.
            NOTE(review): the body never reads this argument; confirm
            whether it should be forwarded to ``Settings``.

    Returns:
        A ``DeepSearchAgent`` built from a default-constructed ``Settings``.
    """
    # Settings() is created without explicit arguments; the original note
    # says values are initialised from environment variables (confirm
    # against the Settings definition).
    return DeepSearchAgent(Settings())
| @@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13 | @@ -61,6 +61,7 @@ weasyprint>=60.0 # PDF导出,支持Python 3.9-3.13 | ||
| 61 | # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) ===== | 61 | # ===== 机器学习(可选,用于情感分析,不安装也没事写了容错程序) ===== |
| 62 | torch>=2.0.0 # CPU版本 | 62 | torch>=2.0.0 # CPU版本 |
| 63 | transformers>=4.30.0 | 63 | transformers>=4.30.0 |
| 64 | +sentence-transformers>=2.2.2 | ||
| 64 | scikit-learn>=1.3.0 | 65 | scikit-learn>=1.3.0 |
| 65 | xgboost>=2.0.0 | 66 | xgboost>=2.0.0 |
| 66 | # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126 | 67 | # NOTE:如果要安装GPU版本的torch,指令为pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126 |
-
Please register or login to post a comment