Showing
16 changed files
with
1072 additions
and
89 deletions
| @@ -7,9 +7,9 @@ import json | @@ -7,9 +7,9 @@ import json | ||
| 7 | import os | 7 | import os |
| 8 | import re | 8 | import re |
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | -from typing import Optional, Dict, Any, List | 10 | +from typing import Optional, Dict, Any, List, Union |
| 11 | 11 | ||
| 12 | -from .llms import DeepSeekLLM, OpenAILLM, BaseLLM | 12 | +from .llms import DeepSeekLLM, OpenAILLM, KimiLLM, BaseLLM |
| 13 | from .nodes import ( | 13 | from .nodes import ( |
| 14 | ReportStructureNode, | 14 | ReportStructureNode, |
| 15 | FirstSearchNode, | 15 | FirstSearchNode, |
| @@ -19,7 +19,7 @@ from .nodes import ( | @@ -19,7 +19,7 @@ from .nodes import ( | ||
| 19 | ReportFormattingNode | 19 | ReportFormattingNode |
| 20 | ) | 20 | ) |
| 21 | from .state import State | 21 | from .state import State |
| 22 | -from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer | 22 | +from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer, multilingual_sentiment_analyzer |
| 23 | from .utils import Config, load_config, format_search_results_for_prompt | 23 | from .utils import Config, load_config, format_search_results_for_prompt |
| 24 | 24 | ||
| 25 | 25 | ||
| @@ -50,6 +50,9 @@ class DeepSearchAgent: | @@ -50,6 +50,9 @@ class DeepSearchAgent: | ||
| 50 | # 初始化搜索工具集 | 50 | # 初始化搜索工具集 |
| 51 | self.search_agency = MediaCrawlerDB() | 51 | self.search_agency = MediaCrawlerDB() |
| 52 | 52 | ||
| 53 | + # 初始化情感分析器 | ||
| 54 | + self.sentiment_analyzer = multilingual_sentiment_analyzer | ||
| 55 | + | ||
| 53 | # 初始化节点 | 56 | # 初始化节点 |
| 54 | self._initialize_nodes() | 57 | self._initialize_nodes() |
| 55 | 58 | ||
| @@ -62,6 +65,7 @@ class DeepSearchAgent: | @@ -62,6 +65,7 @@ class DeepSearchAgent: | ||
| 62 | print(f"Deep Search Agent 已初始化") | 65 | print(f"Deep Search Agent 已初始化") |
| 63 | print(f"使用LLM: {self.llm_client.get_model_info()}") | 66 | print(f"使用LLM: {self.llm_client.get_model_info()}") |
| 64 | print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") | 67 | print(f"搜索工具集: MediaCrawlerDB (支持5种本地数据库查询工具)") |
| 68 | + print(f"情感分析: WeiboMultilingualSentiment (支持22种语言的情感分析)") | ||
| 65 | 69 | ||
| 66 | def _initialize_llm(self) -> BaseLLM: | 70 | def _initialize_llm(self) -> BaseLLM: |
| 67 | """初始化LLM客户端""" | 71 | """初始化LLM客户端""" |
| @@ -75,6 +79,11 @@ class DeepSearchAgent: | @@ -75,6 +79,11 @@ class DeepSearchAgent: | ||
| 75 | api_key=self.config.openai_api_key, | 79 | api_key=self.config.openai_api_key, |
| 76 | model_name=self.config.openai_model | 80 | model_name=self.config.openai_model |
| 77 | ) | 81 | ) |
| 82 | + elif self.config.default_llm_provider == "kimi": | ||
| 83 | + return KimiLLM( | ||
| 84 | + api_key=self.config.kimi_api_key, | ||
| 85 | + model_name=self.config.kimi_model | ||
| 86 | + ) | ||
| 78 | else: | 87 | else: |
| 79 | raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") | 88 | raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") |
| 80 | 89 | ||
| @@ -113,7 +122,7 @@ class DeepSearchAgent: | @@ -113,7 +122,7 @@ class DeepSearchAgent: | ||
| 113 | 122 | ||
| 114 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: | 123 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: |
| 115 | """ | 124 | """ |
| 116 | - 执行指定的数据库查询工具(集成关键词优化中间件) | 125 | + 执行指定的数据库查询工具(集成关键词优化中间件和情感分析) |
| 117 | 126 | ||
| 118 | Args: | 127 | Args: |
| 119 | tool_name: 工具名称,可选值: | 128 | tool_name: 工具名称,可选值: |
| @@ -122,11 +131,13 @@ class DeepSearchAgent: | @@ -122,11 +131,13 @@ class DeepSearchAgent: | ||
| 122 | - "search_topic_by_date": 按日期搜索话题 | 131 | - "search_topic_by_date": 按日期搜索话题 |
| 123 | - "get_comments_for_topic": 获取话题评论 | 132 | - "get_comments_for_topic": 获取话题评论 |
| 124 | - "search_topic_on_platform": 平台定向搜索 | 133 | - "search_topic_on_platform": 平台定向搜索 |
| 134 | + - "analyze_sentiment": 对查询结果进行情感分析 | ||
| 125 | query: 搜索关键词/话题 | 135 | query: 搜索关键词/话题 |
| 126 | - **kwargs: 额外参数(如start_date, end_date, platform, limit等) | 136 | + **kwargs: 额外参数(如start_date, end_date, platform, limit, enable_sentiment等) |
| 137 | + enable_sentiment: 是否自动对搜索结果进行情感分析(默认True) | ||
| 127 | 138 | ||
| 128 | Returns: | 139 | Returns: |
| 129 | - DBResponse对象 | 140 | + DBResponse对象(可能包含情感分析结果) |
| 130 | """ | 141 | """ |
| 131 | print(f" → 执行数据库查询工具: {tool_name}") | 142 | print(f" → 执行数据库查询工具: {tool_name}") |
| 132 | 143 | ||
| @@ -134,7 +145,36 @@ class DeepSearchAgent: | @@ -134,7 +145,36 @@ class DeepSearchAgent: | ||
| 134 | if tool_name == "search_hot_content": | 145 | if tool_name == "search_hot_content": |
| 135 | time_period = kwargs.get("time_period", "week") | 146 | time_period = kwargs.get("time_period", "week") |
| 136 | limit = kwargs.get("limit", 100) | 147 | limit = kwargs.get("limit", 100) |
| 137 | - return self.search_agency.search_hot_content(time_period=time_period, limit=limit) | 148 | + response = self.search_agency.search_hot_content(time_period=time_period, limit=limit) |
| 149 | + | ||
| 150 | + # 检查是否需要进行情感分析 | ||
| 151 | + enable_sentiment = kwargs.get("enable_sentiment", True) | ||
| 152 | + if enable_sentiment and response.results and len(response.results) > 0: | ||
| 153 | + print(f" 🎭 开始对热点内容进行情感分析...") | ||
| 154 | + sentiment_analysis = self._perform_sentiment_analysis(response.results) | ||
| 155 | + if sentiment_analysis: | ||
| 156 | + # 将情感分析结果添加到响应的parameters中 | ||
| 157 | + response.parameters["sentiment_analysis"] = sentiment_analysis | ||
| 158 | + print(f" ✅ 情感分析完成") | ||
| 159 | + | ||
| 160 | + return response | ||
| 161 | + | ||
| 162 | + # 独立情感分析工具 | ||
| 163 | + if tool_name == "analyze_sentiment": | ||
| 164 | + texts = kwargs.get("texts", query) # 可以通过texts参数传递,或使用query | ||
| 165 | + sentiment_result = self.analyze_sentiment_only(texts) | ||
| 166 | + | ||
| 167 | + # 构建DBResponse格式的响应 | ||
| 168 | + return DBResponse( | ||
| 169 | + tool_name="analyze_sentiment", | ||
| 170 | + parameters={ | ||
| 171 | + "texts": texts if isinstance(texts, list) else [texts], | ||
| 172 | + **kwargs | ||
| 173 | + }, | ||
| 174 | + results=[], # 情感分析不返回搜索结果 | ||
| 175 | + results_count=0, | ||
| 176 | + metadata=sentiment_result | ||
| 177 | + ) | ||
| 138 | 178 | ||
| 139 | # 对于需要搜索词的工具,使用关键词优化中间件 | 179 | # 对于需要搜索词的工具,使用关键词优化中间件 |
| 140 | optimized_response = keyword_optimizer.optimize_keywords( | 180 | optimized_response = keyword_optimizer.optimize_keywords( |
| @@ -154,31 +194,35 @@ class DeepSearchAgent: | @@ -154,31 +194,35 @@ class DeepSearchAgent: | ||
| 154 | 194 | ||
| 155 | try: | 195 | try: |
| 156 | if tool_name == "search_topic_globally": | 196 | if tool_name == "search_topic_globally": |
| 157 | - limit_per_table = kwargs.get("limit_per_table", 100) | 197 | + # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 198 | + limit_per_table = self.config.default_search_topic_globally_limit_per_table | ||
| 158 | response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | 199 | response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) |
| 159 | elif tool_name == "search_topic_by_date": | 200 | elif tool_name == "search_topic_by_date": |
| 160 | start_date = kwargs.get("start_date") | 201 | start_date = kwargs.get("start_date") |
| 161 | end_date = kwargs.get("end_date") | 202 | end_date = kwargs.get("end_date") |
| 162 | - limit_per_table = kwargs.get("limit_per_table", 100) | 203 | + # 使用配置文件中的默认值,忽略agent提供的limit_per_table参数 |
| 204 | + limit_per_table = self.config.default_search_topic_by_date_limit_per_table | ||
| 163 | if not start_date or not end_date: | 205 | if not start_date or not end_date: |
| 164 | raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | 206 | raise ValueError("search_topic_by_date工具需要start_date和end_date参数") |
| 165 | response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | 207 | response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) |
| 166 | elif tool_name == "get_comments_for_topic": | 208 | elif tool_name == "get_comments_for_topic": |
| 167 | - limit = kwargs.get("limit", 500) // len(optimized_response.optimized_keywords) | 209 | + # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 210 | + limit = self.config.default_get_comments_for_topic_limit // len(optimized_response.optimized_keywords) | ||
| 168 | limit = max(limit, 50) | 211 | limit = max(limit, 50) |
| 169 | response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | 212 | response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) |
| 170 | elif tool_name == "search_topic_on_platform": | 213 | elif tool_name == "search_topic_on_platform": |
| 171 | platform = kwargs.get("platform") | 214 | platform = kwargs.get("platform") |
| 172 | start_date = kwargs.get("start_date") | 215 | start_date = kwargs.get("start_date") |
| 173 | end_date = kwargs.get("end_date") | 216 | end_date = kwargs.get("end_date") |
| 174 | - limit = kwargs.get("limit", 200) // len(optimized_response.optimized_keywords) | 217 | + # 使用配置文件中的默认值,按关键词数量分配,但保证最小值 |
| 218 | + limit = self.config.default_search_topic_on_platform_limit // len(optimized_response.optimized_keywords) | ||
| 175 | limit = max(limit, 30) | 219 | limit = max(limit, 30) |
| 176 | if not platform: | 220 | if not platform: |
| 177 | raise ValueError("search_topic_on_platform工具需要platform参数") | 221 | raise ValueError("search_topic_on_platform工具需要platform参数") |
| 178 | response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | 222 | response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) |
| 179 | else: | 223 | else: |
| 180 | print(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | 224 | print(f" 未知的搜索工具: {tool_name},使用默认全局搜索") |
| 181 | - response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=100) | 225 | + response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=self.config.default_search_topic_globally_limit_per_table) |
| 182 | 226 | ||
| 183 | # 收集结果 | 227 | # 收集结果 |
| 184 | if response.results: | 228 | if response.results: |
| @@ -209,6 +253,16 @@ class DeepSearchAgent: | @@ -209,6 +253,16 @@ class DeepSearchAgent: | ||
| 209 | results_count=len(unique_results) | 253 | results_count=len(unique_results) |
| 210 | ) | 254 | ) |
| 211 | 255 | ||
| 256 | + # 检查是否需要进行情感分析 | ||
| 257 | + enable_sentiment = kwargs.get("enable_sentiment", True) | ||
| 258 | + if enable_sentiment and unique_results and len(unique_results) > 0: | ||
| 259 | + print(f" 🎭 开始对搜索结果进行情感分析...") | ||
| 260 | + sentiment_analysis = self._perform_sentiment_analysis(unique_results) | ||
| 261 | + if sentiment_analysis: | ||
| 262 | + # 将情感分析结果添加到响应的parameters中 | ||
| 263 | + integrated_response.parameters["sentiment_analysis"] = sentiment_analysis | ||
| 264 | + print(f" ✅ 情感分析完成") | ||
| 265 | + | ||
| 212 | return integrated_response | 266 | return integrated_response |
| 213 | 267 | ||
| 214 | def _deduplicate_results(self, results: List) -> List: | 268 | def _deduplicate_results(self, results: List) -> List: |
| @@ -227,6 +281,99 @@ class DeepSearchAgent: | @@ -227,6 +281,99 @@ class DeepSearchAgent: | ||
| 227 | 281 | ||
| 228 | return unique_results | 282 | return unique_results |
| 229 | 283 | ||
| 284 | + def _perform_sentiment_analysis(self, results: List) -> Optional[Dict[str, Any]]: | ||
| 285 | + """ | ||
| 286 | + 对搜索结果执行情感分析 | ||
| 287 | + | ||
| 288 | + Args: | ||
| 289 | + results: 搜索结果列表 | ||
| 290 | + | ||
| 291 | + Returns: | ||
| 292 | + 情感分析结果字典,如果失败则返回None | ||
| 293 | + """ | ||
| 294 | + try: | ||
| 295 | + # 初始化情感分析器(如果尚未初始化) | ||
| 296 | + if not self.sentiment_analyzer.is_initialized: | ||
| 297 | + print(" 初始化情感分析模型...") | ||
| 298 | + if not self.sentiment_analyzer.initialize(): | ||
| 299 | + print(" ❌ 情感分析模型初始化失败") | ||
| 300 | + return None | ||
| 301 | + | ||
| 302 | + # 将查询结果转换为字典格式 | ||
| 303 | + results_dict = [] | ||
| 304 | + for result in results: | ||
| 305 | + result_dict = { | ||
| 306 | + "content": result.title_or_content, | ||
| 307 | + "platform": result.platform, | ||
| 308 | + "author": result.author_nickname, | ||
| 309 | + "url": result.url, | ||
| 310 | + "publish_time": str(result.publish_time) if result.publish_time else None | ||
| 311 | + } | ||
| 312 | + results_dict.append(result_dict) | ||
| 313 | + | ||
| 314 | + # 执行情感分析 | ||
| 315 | + sentiment_analysis = self.sentiment_analyzer.analyze_query_results( | ||
| 316 | + query_results=results_dict, | ||
| 317 | + text_field="content", | ||
| 318 | + min_confidence=0.5 | ||
| 319 | + ) | ||
| 320 | + | ||
| 321 | + return sentiment_analysis.get("sentiment_analysis") | ||
| 322 | + | ||
| 323 | + except Exception as e: | ||
| 324 | + print(f" ❌ 情感分析过程中发生错误: {str(e)}") | ||
| 325 | + return None | ||
| 326 | + | ||
| 327 | + def analyze_sentiment_only(self, texts: Union[str, List[str]]) -> Dict[str, Any]: | ||
| 328 | + """ | ||
| 329 | + 独立的情感分析工具 | ||
| 330 | + | ||
| 331 | + Args: | ||
| 332 | + texts: 单个文本或文本列表 | ||
| 333 | + | ||
| 334 | + Returns: | ||
| 335 | + 情感分析结果 | ||
| 336 | + """ | ||
| 337 | + print(f" → 执行独立情感分析") | ||
| 338 | + | ||
| 339 | + try: | ||
| 340 | + # 初始化情感分析器(如果尚未初始化) | ||
| 341 | + if not self.sentiment_analyzer.is_initialized: | ||
| 342 | + print(" 初始化情感分析模型...") | ||
| 343 | + if not self.sentiment_analyzer.initialize(): | ||
| 344 | + return { | ||
| 345 | + "success": False, | ||
| 346 | + "error": "情感分析模型初始化失败", | ||
| 347 | + "results": [] | ||
| 348 | + } | ||
| 349 | + | ||
| 350 | + # 执行分析 | ||
| 351 | + if isinstance(texts, str): | ||
| 352 | + result = self.sentiment_analyzer.analyze_single_text(texts) | ||
| 353 | + return { | ||
| 354 | + "success": True, | ||
| 355 | + "total_analyzed": 1, | ||
| 356 | + "results": [result.__dict__] | ||
| 357 | + } | ||
| 358 | + else: | ||
| 359 | + batch_result = self.sentiment_analyzer.analyze_batch(texts, show_progress=True) | ||
| 360 | + return { | ||
| 361 | + "success": True, | ||
| 362 | + "total_analyzed": batch_result.total_processed, | ||
| 363 | + "success_count": batch_result.success_count, | ||
| 364 | + "failed_count": batch_result.failed_count, | ||
| 365 | + "average_confidence": batch_result.average_confidence, | ||
| 366 | + "results": [result.__dict__ for result in batch_result.results] | ||
| 367 | + } | ||
| 368 | + | ||
| 369 | + except Exception as e: | ||
| 370 | + print(f" ❌ 情感分析过程中发生错误: {str(e)}") | ||
| 371 | + return { | ||
| 372 | + "success": False, | ||
| 373 | + "error": str(e), | ||
| 374 | + "results": [] | ||
| 375 | + } | ||
| 376 | + | ||
| 230 | def research(self, query: str, save_report: bool = True) -> str: | 377 | def research(self, query: str, save_report: bool = True) -> str: |
| 231 | """ | 378 | """ |
| 232 | 执行深度研究 | 379 | 执行深度研究 |
| @@ -356,17 +503,23 @@ class DeepSearchAgent: | @@ -356,17 +503,23 @@ class DeepSearchAgent: | ||
| 356 | print(f" ⚠️ search_topic_on_platform工具缺少平台参数,改用全局搜索") | 503 | print(f" ⚠️ search_topic_on_platform工具缺少平台参数,改用全局搜索") |
| 357 | search_tool = "search_topic_globally" | 504 | search_tool = "search_topic_globally" |
| 358 | 505 | ||
| 359 | - # 处理限制参数 | 506 | + # 处理限制参数,使用配置文件中的默认值而不是agent提供的参数 |
| 360 | if search_tool == "search_hot_content": | 507 | if search_tool == "search_hot_content": |
| 361 | time_period = search_output.get("time_period", "week") | 508 | time_period = search_output.get("time_period", "week") |
| 362 | - limit = search_output.get("limit", 100) | 509 | + limit = self.config.default_search_hot_content_limit |
| 363 | search_kwargs["time_period"] = time_period | 510 | search_kwargs["time_period"] = time_period |
| 364 | search_kwargs["limit"] = limit | 511 | search_kwargs["limit"] = limit |
| 365 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 512 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 366 | - limit_per_table = search_output.get("limit_per_table", 100) | 513 | + if search_tool == "search_topic_globally": |
| 514 | + limit_per_table = self.config.default_search_topic_globally_limit_per_table | ||
| 515 | + else: # search_topic_by_date | ||
| 516 | + limit_per_table = self.config.default_search_topic_by_date_limit_per_table | ||
| 367 | search_kwargs["limit_per_table"] = limit_per_table | 517 | search_kwargs["limit_per_table"] = limit_per_table |
| 368 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 518 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 369 | - limit = search_output.get("limit", 200) | 519 | + if search_tool == "get_comments_for_topic": |
| 520 | + limit = self.config.default_get_comments_for_topic_limit | ||
| 521 | + else: # search_topic_on_platform | ||
| 522 | + limit = self.config.default_search_topic_on_platform_limit | ||
| 370 | search_kwargs["limit"] = limit | 523 | search_kwargs["limit"] = limit |
| 371 | 524 | ||
| 372 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 525 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -374,8 +527,11 @@ class DeepSearchAgent: | @@ -374,8 +527,11 @@ class DeepSearchAgent: | ||
| 374 | # 转换为兼容格式 | 527 | # 转换为兼容格式 |
| 375 | search_results = [] | 528 | search_results = [] |
| 376 | if search_response and search_response.results: | 529 | if search_response and search_response.results: |
| 377 | - # 每种搜索工具都有其特定的结果数量,这里取前100个作为上限 | ||
| 378 | - max_results = min(len(search_response.results), 100) | 530 | + # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 531 | + if self.config.max_search_results_for_llm > 0: | ||
| 532 | + max_results = min(len(search_response.results), self.config.max_search_results_for_llm) | ||
| 533 | + else: | ||
| 534 | + max_results = len(search_response.results) # 不限制,传递所有结果 | ||
| 379 | for result in search_response.results[:max_results]: | 535 | for result in search_response.results[:max_results]: |
| 380 | search_results.append({ | 536 | search_results.append({ |
| 381 | 'title': result.title_or_content, | 537 | 'title': result.title_or_content, |
| @@ -479,14 +635,23 @@ class DeepSearchAgent: | @@ -479,14 +635,23 @@ class DeepSearchAgent: | ||
| 479 | # 处理限制参数 | 635 | # 处理限制参数 |
| 480 | if search_tool == "search_hot_content": | 636 | if search_tool == "search_hot_content": |
| 481 | time_period = reflection_output.get("time_period", "week") | 637 | time_period = reflection_output.get("time_period", "week") |
| 482 | - limit = reflection_output.get("limit", 10) | 638 | + # 使用配置文件中的默认值,不允许agent控制limit参数 |
| 639 | + limit = self.config.default_search_hot_content_limit | ||
| 483 | search_kwargs["time_period"] = time_period | 640 | search_kwargs["time_period"] = time_period |
| 484 | search_kwargs["limit"] = limit | 641 | search_kwargs["limit"] = limit |
| 485 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 642 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 486 | - limit_per_table = reflection_output.get("limit_per_table", 5) | 643 | + # 使用配置文件中的默认值,不允许agent控制limit_per_table参数 |
| 644 | + if search_tool == "search_topic_globally": | ||
| 645 | + limit_per_table = self.config.default_search_topic_globally_limit_per_table | ||
| 646 | + else: # search_topic_by_date | ||
| 647 | + limit_per_table = self.config.default_search_topic_by_date_limit_per_table | ||
| 487 | search_kwargs["limit_per_table"] = limit_per_table | 648 | search_kwargs["limit_per_table"] = limit_per_table |
| 488 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 649 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 489 | - limit = reflection_output.get("limit", 20) | 650 | + # 使用配置文件中的默认值,不允许agent控制limit参数 |
| 651 | + if search_tool == "get_comments_for_topic": | ||
| 652 | + limit = self.config.default_get_comments_for_topic_limit | ||
| 653 | + else: # search_topic_on_platform | ||
| 654 | + limit = self.config.default_search_topic_on_platform_limit | ||
| 490 | search_kwargs["limit"] = limit | 655 | search_kwargs["limit"] = limit |
| 491 | 656 | ||
| 492 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 657 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -494,8 +659,11 @@ class DeepSearchAgent: | @@ -494,8 +659,11 @@ class DeepSearchAgent: | ||
| 494 | # 转换为兼容格式 | 659 | # 转换为兼容格式 |
| 495 | search_results = [] | 660 | search_results = [] |
| 496 | if search_response and search_response.results: | 661 | if search_response and search_response.results: |
| 497 | - # 每种搜索工具都有其特定的结果数量,这里取前100个作为上限 | ||
| 498 | - max_results = min(len(search_response.results), 100) | 662 | + # 使用配置文件控制传递给LLM的结果数量,0表示不限制 |
| 663 | + if self.config.max_search_results_for_llm > 0: | ||
| 664 | + max_results = min(len(search_response.results), self.config.max_search_results_for_llm) | ||
| 665 | + else: | ||
| 666 | + max_results = len(search_response.results) # 不限制,传递所有结果 | ||
| 499 | for result in search_response.results[:max_results]: | 667 | for result in search_response.results[:max_results]: |
| 500 | search_results.append({ | 668 | search_results.append({ |
| 501 | 'title': result.title_or_content, | 669 | 'title': result.title_or_content, |
| @@ -6,5 +6,6 @@ LLM调用模块 | @@ -6,5 +6,6 @@ LLM调用模块 | ||
| 6 | from .base import BaseLLM | 6 | from .base import BaseLLM |
| 7 | from .deepseek import DeepSeekLLM | 7 | from .deepseek import DeepSeekLLM |
| 8 | from .openai_llm import OpenAILLM | 8 | from .openai_llm import OpenAILLM |
| 9 | +from .kimi import KimiLLM | ||
| 9 | 10 | ||
| 10 | -__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"] | 11 | +__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "KimiLLM"] |
InsightEngine/llms/kimi.py
0 → 100644
| 1 | +""" | ||
| 2 | +Kimi LLM实现 | ||
| 3 | +使用Moonshot AI的Kimi API进行文本生成 | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import os | ||
| 7 | +from typing import Optional, Dict, Any | ||
| 8 | +from openai import OpenAI | ||
| 9 | +# 假设 .base 模块和 BaseLLM 类已存在 | ||
| 10 | +from .base import BaseLLM | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +class KimiLLM(BaseLLM): | ||
| 14 | + """Kimi LLM实现类""" | ||
| 15 | + | ||
| 16 | + def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None): | ||
| 17 | + """ | ||
| 18 | + 初始化Kimi客户端 | ||
| 19 | + | ||
| 20 | + Args: | ||
| 21 | + api_key: Kimi API密钥,如果不提供则从环境变量读取 | ||
| 22 | + model_name: 模型名称,默认使用kimi-k2-0711-preview | ||
| 23 | + """ | ||
| 24 | + if api_key is None: | ||
| 25 | + api_key = os.getenv("KIMI_API_KEY") | ||
| 26 | + if not api_key: | ||
| 27 | + raise ValueError("Kimi API Key未找到!请设置KIMI_API_KEY环境变量或在初始化时提供") | ||
| 28 | + | ||
| 29 | + super().__init__(api_key, model_name) | ||
| 30 | + | ||
| 31 | + # 初始化OpenAI客户端,使用Kimi的endpoint | ||
| 32 | + self.client = OpenAI( | ||
| 33 | + api_key=self.api_key, | ||
| 34 | + base_url="https://api.moonshot.cn/v1" | ||
| 35 | + ) | ||
| 36 | + | ||
| 37 | + self.default_model = model_name or self.get_default_model() | ||
| 38 | + | ||
| 39 | + def get_default_model(self) -> str: | ||
| 40 | + """获取默认模型名称""" | ||
| 41 | + return "kimi-k2-0711-preview" | ||
| 42 | + | ||
| 43 | + def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: | ||
| 44 | + """ | ||
| 45 | + 调用Kimi API生成回复 | ||
| 46 | + | ||
| 47 | + Args: | ||
| 48 | + system_prompt: 系统提示词 | ||
| 49 | + user_prompt: 用户输入 | ||
| 50 | + **kwargs: 其他参数,如temperature、max_tokens等 | ||
| 51 | + | ||
| 52 | + Returns: | ||
| 53 | + Kimi生成的回复文本 | ||
| 54 | + """ | ||
| 55 | + try: | ||
| 56 | + # 构建消息 | ||
| 57 | + messages = [ | ||
| 58 | + {"role": "system", "content": system_prompt}, | ||
| 59 | + {"role": "user", "content": user_prompt} | ||
| 60 | + ] | ||
| 61 | + | ||
| 62 | + # 智能计算max_tokens - 根据输入长度自动调整输出长度 | ||
| 63 | + input_length = len(system_prompt) + len(user_prompt) | ||
| 64 | + if input_length > 100000: # 超长文本 | ||
| 65 | + default_max_tokens = 81920 | ||
| 66 | + elif input_length > 50000: # 超长文本 | ||
| 67 | + default_max_tokens = 40960 | ||
| 68 | + elif input_length > 20000: # 长文本 | ||
| 69 | + default_max_tokens = 16384 | ||
| 70 | + elif input_length > 5000: # 中等文本 | ||
| 71 | + default_max_tokens = 8192 | ||
| 72 | + else: # 短文本 | ||
| 73 | + default_max_tokens = 4096 | ||
| 74 | + | ||
| 75 | + # 设置默认参数,针对长文本处理优化 | ||
| 76 | + params = { | ||
| 77 | + "model": self.default_model, | ||
| 78 | + "messages": messages, | ||
| 79 | + "temperature": kwargs.get("temperature", 0.6), # Kimi建议使用0.6 | ||
| 80 | + "max_tokens": kwargs.get("max_tokens", default_max_tokens), # 智能调整token限制 | ||
| 81 | + "stream": False | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + # 添加其他可选参数 | ||
| 85 | + if "top_p" in kwargs: | ||
| 86 | + params["top_p"] = kwargs["top_p"] | ||
| 87 | + if "presence_penalty" in kwargs: | ||
| 88 | + params["presence_penalty"] = kwargs["presence_penalty"] | ||
| 89 | + if "frequency_penalty" in kwargs: | ||
| 90 | + params["frequency_penalty"] = kwargs["frequency_penalty"] | ||
| 91 | + if "stop" in kwargs: | ||
| 92 | + params["stop"] = kwargs["stop"] | ||
| 93 | + | ||
| 94 | + # 输出调试信息(仅在使用Kimi时) | ||
| 95 | + print(f"[Kimi] 输入长度: {input_length}, 使用max_tokens: {params['max_tokens']}") | ||
| 96 | + | ||
| 97 | + # 调用API | ||
| 98 | + response = self.client.chat.completions.create(**params) | ||
| 99 | + | ||
| 100 | + # 提取回复内容 | ||
| 101 | + if response.choices and response.choices[0].message: | ||
| 102 | + content = response.choices[0].message.content | ||
| 103 | + return self.validate_response(content) | ||
| 104 | + else: | ||
| 105 | + return "" | ||
| 106 | + | ||
| 107 | + except Exception as e: | ||
| 108 | + print(f"Kimi API调用错误: {str(e)}") | ||
| 109 | + raise e | ||
| 110 | + | ||
| 111 | + def get_model_info(self) -> Dict[str, Any]: | ||
| 112 | + """ | ||
| 113 | + 获取当前模型信息 | ||
| 114 | + | ||
| 115 | + Returns: | ||
| 116 | + 模型信息字典 | ||
| 117 | + """ | ||
| 118 | + return { | ||
| 119 | + "provider": "Kimi", | ||
| 120 | + "model": self.default_model, | ||
| 121 | + "api_base": "https://api.moonshot.cn/v1", | ||
| 122 | + "max_context_length": "长文本支持(200K+ tokens)" | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + # ==================== 代码修改部分 ==================== | ||
| 126 | + def invoke_long_context(self, system_prompt: str, user_prompt: str, **kwargs) -> str: | ||
| 127 | + """ | ||
| 128 | + 专门用于长文本处理的调用方法 (作为invoke的兼容接口)。 | ||
| 129 | + 此方法通过设置推荐的默认参数,然后调用通用的invoke方法来处理请求。 | ||
| 130 | + | ||
| 131 | + Args: | ||
| 132 | + system_prompt: 系统提示词 | ||
| 133 | + user_prompt: 用户输入 | ||
| 134 | + **kwargs: 其他参数 | ||
| 135 | + | ||
| 136 | + Returns: | ||
| 137 | + Kimi生成的回复文本 | ||
| 138 | + """ | ||
| 139 | + # 为长文本场景,设置一个慷慨的默认 max_tokens,仅当用户未指定时生效。 | ||
| 140 | + # 您原有的16384是一个非常合理的值。 | ||
| 141 | + kwargs.setdefault("max_tokens", 16384) | ||
| 142 | + | ||
| 143 | + # 直接调用核心的invoke方法,将所有参数(包括预设的默认值)传递给它。 | ||
| 144 | + return self.invoke(system_prompt, user_prompt, **kwargs) |
| @@ -39,8 +39,8 @@ output_schema_first_search = { | @@ -39,8 +39,8 @@ output_schema_first_search = { | ||
| 39 | "end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"}, | 39 | "end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"}, |
| 40 | "platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"}, | 40 | "platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"}, |
| 41 | "time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"}, | 41 | "time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"}, |
| 42 | - "limit": {"type": "integer", "description": "结果数量限制,各工具可选参数"}, | ||
| 43 | - "limit_per_table": {"type": "integer", "description": "每表结果数量限制,search_topic_globally和search_topic_by_date工具可选"} | 42 | + "enable_sentiment": {"type": "boolean", "description": "是否启用自动情感分析,默认为true,适用于除analyze_sentiment外的所有搜索工具"}, |
| 43 | + "texts": {"type": "array", "items": {"type": "string"}, "description": "文本列表,仅用于analyze_sentiment工具"} | ||
| 44 | }, | 44 | }, |
| 45 | "required": ["search_query", "search_tool", "reasoning"] | 45 | "required": ["search_query", "search_tool", "reasoning"] |
| 46 | } | 46 | } |
| @@ -88,8 +88,8 @@ output_schema_reflection = { | @@ -88,8 +88,8 @@ output_schema_reflection = { | ||
| 88 | "end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"}, | 88 | "end_date": {"type": "string", "description": "结束日期,格式YYYY-MM-DD,search_topic_by_date和search_topic_on_platform工具可能需要"}, |
| 89 | "platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"}, | 89 | "platform": {"type": "string", "description": "平台名称,search_topic_on_platform工具必需,可选值:bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba"}, |
| 90 | "time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"}, | 90 | "time_period": {"type": "string", "description": "时间周期,search_hot_content工具可选,可选值:24h, week, year"}, |
| 91 | - "limit": {"type": "integer", "description": "结果数量限制,各工具可选参数"}, | ||
| 92 | - "limit_per_table": {"type": "integer", "description": "每表结果数量限制,search_topic_globally和search_topic_by_date工具可选"} | 91 | + "enable_sentiment": {"type": "boolean", "description": "是否启用自动情感分析,默认为true,适用于除analyze_sentiment外的所有搜索工具"}, |
| 92 | + "texts": {"type": "array", "items": {"type": "string"}, "description": "文本列表,仅用于analyze_sentiment工具"} | ||
| 93 | }, | 93 | }, |
| 94 | "required": ["search_query", "search_tool", "reasoning"] | 94 | "required": ["search_query", "search_tool", "reasoning"] |
| 95 | } | 95 | } |
| @@ -155,34 +155,40 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | @@ -155,34 +155,40 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | ||
| 155 | {json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} | 155 | {json.dumps(input_schema_first_search, indent=2, ensure_ascii=False)} |
| 156 | </INPUT JSON SCHEMA> | 156 | </INPUT JSON SCHEMA> |
| 157 | 157 | ||
| 158 | -你可以使用以下5种专业的本地舆情数据库查询工具来挖掘真实的民意和公众观点: | 158 | +你可以使用以下6种专业的本地舆情数据库查询工具来挖掘真实的民意和公众观点: |
| 159 | 159 | ||
| 160 | 1. **search_hot_content** - 查找热点内容工具 | 160 | 1. **search_hot_content** - 查找热点内容工具 |
| 161 | - 适用于:挖掘当前最受关注的舆情事件和话题 | 161 | - 适用于:挖掘当前最受关注的舆情事件和话题 |
| 162 | - - 特点:基于真实的点赞、评论、分享数据发现热门话题 | ||
| 163 | - - 参数:time_period ('24h', 'week', 'year'),limit(数量限制) | 162 | + - 特点:基于真实的点赞、评论、分享数据发现热门话题,自动进行情感分析 |
| 163 | + - 参数:time_period ('24h', 'week', 'year'),limit(数量限制),enable_sentiment(是否启用情感分析,默认True) | ||
| 164 | 164 | ||
| 165 | 2. **search_topic_globally** - 全局话题搜索工具 | 165 | 2. **search_topic_globally** - 全局话题搜索工具 |
| 166 | - 适用于:全面了解公众对特定话题的讨论和观点 | 166 | - 适用于:全面了解公众对特定话题的讨论和观点 |
| 167 | - - 特点:覆盖B站、微博、抖音、快手、小红书、知乎、贴吧等主流平台的真实用户声音 | ||
| 168 | - - 参数:limit_per_table(每个表的结果数量限制) | 167 | + - 特点:覆盖B站、微博、抖音、快手、小红书、知乎、贴吧等主流平台的真实用户声音,自动进行情感分析 |
| 168 | + - 参数:limit_per_table(每个表的结果数量限制),enable_sentiment(是否启用情感分析,默认True) | ||
| 169 | 169 | ||
| 170 | 3. **search_topic_by_date** - 按日期搜索话题工具 | 170 | 3. **search_topic_by_date** - 按日期搜索话题工具 |
| 171 | - 适用于:追踪舆情事件的时间线发展和公众情绪变化 | 171 | - 适用于:追踪舆情事件的时间线发展和公众情绪变化 |
| 172 | - - 特点:精确的时间范围控制,适合分析舆情演变过程 | 172 | + - 特点:精确的时间范围控制,适合分析舆情演变过程,自动进行情感分析 |
| 173 | - 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD' | 173 | - 特殊要求:需要提供start_date和end_date参数,格式为'YYYY-MM-DD' |
| 174 | - - 参数:limit_per_table(每个表的结果数量限制) | 174 | + - 参数:limit_per_table(每个表的结果数量限制),enable_sentiment(是否启用情感分析,默认True) |
| 175 | 175 | ||
| 176 | 4. **get_comments_for_topic** - 获取话题评论工具 | 176 | 4. **get_comments_for_topic** - 获取话题评论工具 |
| 177 | - 适用于:深度挖掘网民的真实态度、情感和观点 | 177 | - 适用于:深度挖掘网民的真实态度、情感和观点 |
| 178 | - - 特点:直接获取用户评论,了解民意走向和情感倾向 | ||
| 179 | - - 参数:limit(评论总数量限制) | 178 | + - 特点:直接获取用户评论,了解民意走向和情感倾向,自动进行情感分析 |
| 179 | + - 参数:limit(评论总数量限制),enable_sentiment(是否启用情感分析,默认True) | ||
| 180 | 180 | ||
| 181 | 5. **search_topic_on_platform** - 平台定向搜索工具 | 181 | 5. **search_topic_on_platform** - 平台定向搜索工具 |
| 182 | - 适用于:分析特定社交平台用户群体的观点特征 | 182 | - 适用于:分析特定社交平台用户群体的观点特征 |
| 183 | - - 特点:针对不同平台用户群体的观点差异进行精准分析 | 183 | + - 特点:针对不同平台用户群体的观点差异进行精准分析,自动进行情感分析 |
| 184 | - 特殊要求:需要提供platform参数,可选start_date和end_date | 184 | - 特殊要求:需要提供platform参数,可选start_date和end_date |
| 185 | - - 参数:platform(必须),start_date, end_date(可选),limit(数量限制) | 185 | + - 参数:platform(必须),start_date, end_date(可选),limit(数量限制),enable_sentiment(是否启用情感分析,默认True) |
| 186 | + | ||
| 187 | +6. **analyze_sentiment** - 多语言情感分析工具 | ||
| 188 | + - 适用于:对文本内容进行专门的情感倾向分析 | ||
| 189 | + - 特点:支持中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言的情感分析,输出5级情感等级(非常负面、负面、中性、正面、非常正面) | ||
| 190 | + - 参数:texts(文本或文本列表),search_query也可用作单个文本输入 | ||
| 191 | + - 用途:当搜索结果的情感倾向不明确或需要专门的情感分析时使用 | ||
| 186 | 192 | ||
| 187 | **你的核心使命:挖掘真实的民意和人情味** | 193 | **你的核心使命:挖掘真实的民意和人情味** |
| 188 | 194 | ||
| @@ -195,11 +201,16 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | @@ -195,11 +201,16 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | ||
| 195 | - **贴近生活语言**:用简单、直接、口语化的词汇 | 201 | - **贴近生活语言**:用简单、直接、口语化的词汇 |
| 196 | - **包含情感词汇**:网民常用的褒贬词、情绪词 | 202 | - **包含情感词汇**:网民常用的褒贬词、情绪词 |
| 197 | - **考虑话题热词**:相关的网络流行语、缩写、昵称 | 203 | - **考虑话题热词**:相关的网络流行语、缩写、昵称 |
| 198 | -4. **参数优化配置**: | 204 | +4. **情感分析策略选择**: |
| 205 | + - **自动情感分析**:默认启用(enable_sentiment: true),适用于搜索工具,能自动分析搜索结果的情感倾向 | ||
| 206 | + - **专门情感分析**:当需要对特定文本进行详细情感分析时,使用analyze_sentiment工具 | ||
| 207 | + - **关闭情感分析**:在某些特殊情况下(如纯事实性内容),可设置enable_sentiment: false | ||
| 208 | +5. **参数优化配置**: | ||
| 199 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) | 209 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) |
| 200 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) | 210 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) |
| 201 | - - 其他工具:合理配置limit参数以获取足够的样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200) | ||
| 202 | -5. **阐述选择理由**:说明为什么这样的查询能够获得最真实的民意反馈 | 211 | + - analyze_sentiment: 使用texts参数提供文本列表,或使用search_query作为单个文本 |
| 212 | + - 系统自动配置数据量参数,无需手动设置limit或limit_per_table参数 | ||
| 213 | +6. **阐述选择理由**:说明为什么这样的查询和情感分析策略能够获得最真实的民意反馈 | ||
| 203 | 214 | ||
| 204 | **搜索词设计核心原则**: | 215 | **搜索词设计核心原则**: |
| 205 | - **想象网友怎么说**:如果你是个普通网友,你会怎么讨论这个话题? | 216 | - **想象网友怎么说**:如果你是个普通网友,你会怎么讨论这个话题? |
| @@ -251,7 +262,12 @@ SYSTEM_PROMPT_FIRST_SUMMARY = f""" | @@ -251,7 +262,12 @@ SYSTEM_PROMPT_FIRST_SUMMARY = f""" | ||
| 251 | 2. **展现多元观点**:呈现不同平台、不同群体的观点差异和讨论重点 | 262 | 2. **展现多元观点**:呈现不同平台、不同群体的观点差异和讨论重点 |
| 252 | 3. **数据支撑分析**:用具体的点赞数、评论数、转发数等数据说明舆情热度 | 263 | 3. **数据支撑分析**:用具体的点赞数、评论数、转发数等数据说明舆情热度 |
| 253 | 4. **情感色彩描述**:准确描述公众的情感倾向(愤怒、支持、担忧、期待等) | 264 | 4. **情感色彩描述**:准确描述公众的情感倾向(愤怒、支持、担忧、期待等) |
| 254 | -5. **避免套话官话**:使用贴近民众的语言,避免过度官方化的表述 | 265 | +5. **智能运用情感分析**: |
| 266 | + - **整合情感数据**:如果搜索结果包含自动情感分析,要充分利用情感分布数据(如"正面情感占60%,负面情感占25%") | ||
| 267 | + - **情感趋势描述**:描述主要情感倾向和情感分布特征 | ||
| 268 | + - **高置信度引用**:优先引用高置信度的情感分析结果 | ||
| 269 | + - **情感细节分析**:结合具体的情感标签(非常正面、正面、中性、负面、非常负面)进行深度分析 | ||
| 270 | +6. **避免套话官话**:使用贴近民众的语言,避免过度官方化的表述 | ||
| 255 | 271 | ||
| 256 | 撰写风格: | 272 | 撰写风格: |
| 257 | - 语言生动,有感染力 | 273 | - 语言生动,有感染力 |
| @@ -277,13 +293,14 @@ SYSTEM_PROMPT_REFLECTION = f""" | @@ -277,13 +293,14 @@ SYSTEM_PROMPT_REFLECTION = f""" | ||
| 277 | {json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} | 293 | {json.dumps(input_schema_reflection, indent=2, ensure_ascii=False)} |
| 278 | </INPUT JSON SCHEMA> | 294 | </INPUT JSON SCHEMA> |
| 279 | 295 | ||
| 280 | -你可以使用以下5种专业的本地舆情数据库查询工具来深度挖掘民意: | 296 | +你可以使用以下6种专业的本地舆情数据库查询工具来深度挖掘民意: |
| 281 | 297 | ||
| 282 | -1. **search_hot_content** - 查找热点内容工具 | ||
| 283 | -2. **search_topic_globally** - 全局话题搜索工具 | ||
| 284 | -3. **search_topic_by_date** - 按日期搜索话题工具 | ||
| 285 | -4. **get_comments_for_topic** - 获取话题评论工具 | ||
| 286 | -5. **search_topic_on_platform** - 平台定向搜索工具 | 298 | +1. **search_hot_content** - 查找热点内容工具(自动情感分析) |
| 299 | +2. **search_topic_globally** - 全局话题搜索工具(自动情感分析) | ||
| 300 | +3. **search_topic_by_date** - 按日期搜索话题工具(自动情感分析) | ||
| 301 | +4. **get_comments_for_topic** - 获取话题评论工具(自动情感分析) | ||
| 302 | +5. **search_topic_on_platform** - 平台定向搜索工具(自动情感分析) | ||
| 303 | +6. **analyze_sentiment** - 多语言情感分析工具(专门的情感分析) | ||
| 287 | 304 | ||
| 288 | **反思的核心目标:让报告更有人情味和真实感** | 305 | **反思的核心目标:让报告更有人情味和真实感** |
| 289 | 306 | ||
| @@ -311,7 +328,7 @@ SYSTEM_PROMPT_REFLECTION = f""" | @@ -311,7 +328,7 @@ SYSTEM_PROMPT_REFLECTION = f""" | ||
| 311 | 4. **参数配置要求**: | 328 | 4. **参数配置要求**: |
| 312 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) | 329 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) |
| 313 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) | 330 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) |
| 314 | - - 其他工具:合理配置参数以获取多样化的民意样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200) | 331 | + - 系统自动配置数据量参数,无需手动设置limit或limit_per_table参数 |
| 315 | 332 | ||
| 316 | 5. **阐述补充理由**:明确说明为什么需要这些额外的民意数据 | 333 | 5. **阐述补充理由**:明确说明为什么需要这些额外的民意数据 |
| 317 | 334 | ||
| @@ -357,9 +374,13 @@ SYSTEM_PROMPT_REFLECTION_SUMMARY = f""" | @@ -357,9 +374,13 @@ SYSTEM_PROMPT_REFLECTION_SUMMARY = f""" | ||
| 357 | 优化策略: | 374 | 优化策略: |
| 358 | 1. **融入新的民意数据**:将补充搜索到的真实用户声音整合到段落中 | 375 | 1. **融入新的民意数据**:将补充搜索到的真实用户声音整合到段落中 |
| 359 | 2. **丰富情感表达**:增加具体的情感描述和社会情绪分析 | 376 | 2. **丰富情感表达**:增加具体的情感描述和社会情绪分析 |
| 360 | -3. **补充遗漏观点**:添加之前缺失的不同群体、平台的观点 | ||
| 361 | -4. **强化数据支撑**:用具体数字和案例让分析更有说服力 | ||
| 362 | -5. **优化语言表达**:让文字更生动、更贴近民众,减少官方套话 | 377 | +3. **深化情感分析**: |
| 378 | + - **整合情感变化**:如果有新的情感分析数据,对比前后情感变化趋势 | ||
| 379 | + - **细化情感层次**:区分不同群体、平台的情感差异 | ||
| 380 | + - **量化情感描述**:用具体的情感分布数据支撑分析(如"新增数据显示负面情感比例上升至40%") | ||
| 381 | +4. **补充遗漏观点**:添加之前缺失的不同群体、平台的观点 | ||
| 382 | +5. **强化数据支撑**:用具体数字和案例让分析更有说服力 | ||
| 383 | +6. **优化语言表达**:让文字更生动、更贴近民众,减少官方套话 | ||
| 363 | 384 | ||
| 364 | 注意事项: | 385 | 注意事项: |
| 365 | - 保留段落的核心观点和重要信息 | 386 | - 保留段落的核心观点和重要信息 |
| @@ -14,6 +14,13 @@ from .keyword_optimizer import ( | @@ -14,6 +14,13 @@ from .keyword_optimizer import ( | ||
| 14 | KeywordOptimizationResponse, | 14 | KeywordOptimizationResponse, |
| 15 | keyword_optimizer | 15 | keyword_optimizer |
| 16 | ) | 16 | ) |
| 17 | +from .sentiment_analyzer import ( | ||
| 18 | + WeiboMultilingualSentimentAnalyzer, | ||
| 19 | + SentimentResult, | ||
| 20 | + BatchSentimentResult, | ||
| 21 | + multilingual_sentiment_analyzer, | ||
| 22 | + analyze_sentiment | ||
| 23 | +) | ||
| 17 | 24 | ||
| 18 | __all__ = [ | 25 | __all__ = [ |
| 19 | "MediaCrawlerDB", | 26 | "MediaCrawlerDB", |
| @@ -22,5 +29,10 @@ __all__ = [ | @@ -22,5 +29,10 @@ __all__ = [ | ||
| 22 | "print_response_summary", | 29 | "print_response_summary", |
| 23 | "KeywordOptimizer", | 30 | "KeywordOptimizer", |
| 24 | "KeywordOptimizationResponse", | 31 | "KeywordOptimizationResponse", |
| 25 | - "keyword_optimizer" | 32 | + "keyword_optimizer", |
| 33 | + "WeiboMultilingualSentimentAnalyzer", | ||
| 34 | + "SentimentResult", | ||
| 35 | + "BatchSentimentResult", | ||
| 36 | + "multilingual_sentiment_analyzer", | ||
| 37 | + "analyze_sentiment" | ||
| 26 | ] | 38 | ] |
| @@ -228,7 +228,7 @@ class KeywordOptimizer: | @@ -228,7 +228,7 @@ class KeywordOptimizer: | ||
| 228 | 228 | ||
| 229 | # 清理和验证关键词 | 229 | # 清理和验证关键词 |
| 230 | cleaned_keywords = [] | 230 | cleaned_keywords = [] |
| 231 | - for keyword in keywords[:20]: # 最多5个 | 231 | + for keyword in keywords[:20]: # 最多20个 |
| 232 | keyword = keyword.strip().strip('"\'""''') | 232 | keyword = keyword.strip().strip('"\'""''') |
| 233 | if keyword and len(keyword) <= 20: # 合理长度 | 233 | if keyword and len(keyword) <= 20: # 合理长度 |
| 234 | cleaned_keywords.append(keyword) | 234 | cleaned_keywords.append(keyword) |
InsightEngine/tools/sentiment_analyzer.py
0 → 100644
| 1 | +""" | ||
| 2 | +多语言情感分析工具 | ||
| 3 | +基于WeiboMultilingualSentiment模型为InsightEngine提供情感分析功能 | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import torch | ||
| 7 | +from transformers import AutoTokenizer, AutoModelForSequenceClassification | ||
| 8 | +import os | ||
| 9 | +import sys | ||
| 10 | +from typing import List, Dict, Any, Optional, Union | ||
| 11 | +from dataclasses import dataclass | ||
| 12 | +import re | ||
| 13 | + | ||
# Make the sibling SentimentAnalysisModel/WeiboMultilingualSentiment package
# importable: climb three directory levels from this file up to the project
# root, then append the model directory to sys.path.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
weibo_sentiment_path = os.path.join(project_root, "SentimentAnalysisModel", "WeiboMultilingualSentiment")
sys.path.append(weibo_sentiment_path)
| 18 | + | ||
| 19 | + | ||
@dataclass
class SentimentResult:
    """Result of sentiment analysis for a single text.

    On failure (success=False) the label/confidence fields hold placeholder
    values and error_message explains what went wrong.
    """
    text: str                                    # original input text (unprocessed)
    sentiment_label: str                         # one of the 5 Chinese sentiment labels, or an error placeholder
    confidence: float                            # probability of the predicted label (0.0 on failure)
    probability_distribution: Dict[str, float]   # label -> probability for all 5 classes
    success: bool = True                         # False when the analysis could not be performed
    error_message: Optional[str] = None          # human-readable failure reason; None on success
| 29 | + | ||
| 30 | + | ||
@dataclass
class BatchSentimentResult:
    """Aggregate result of analyzing a list of texts."""
    results: List[SentimentResult]   # per-text results, in input order
    total_processed: int             # number of texts submitted
    success_count: int               # texts analyzed successfully
    failed_count: int                # texts that failed (total_processed - success_count)
    average_confidence: float        # mean confidence over successful results only (0.0 if none)
| 39 | + | ||
| 40 | + | ||
class WeiboMultilingualSentimentAnalyzer:
    """
    Multilingual sentiment analyzer.

    Wraps the tabularisai/multilingual-sentiment-analysis model and exposes
    sentiment analysis to the AI agent: single texts, batches, and rows
    returned by MediaCrawlerDB queries. The (heavy) model is loaded lazily
    through initialize(), so constructing this class is cheap.
    """

    def __init__(self):
        """Create the analyzer shell; call initialize() before analyzing."""
        self.model = None        # transformers sequence-classification model (set in initialize)
        self.tokenizer = None    # tokenizer paired with the model (set in initialize)
        self.device = None       # torch.device chosen in initialize()
        self.is_initialized = False

        # 5-level label mapping: model class index -> Chinese sentiment label.
        self.sentiment_map = {
            0: "非常负面",
            1: "负面",
            2: "中性",
            3: "正面",
            4: "非常正面"
        }

        print("WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型")

    def initialize(self) -> bool:
        """
        Load the model and tokenizer (idempotent).

        Downloads the model on first use and caches it under the
        WeiboMultilingualSentiment directory; subsequent calls load from disk.

        Returns:
            True on success, False if loading failed.
        """
        if self.is_initialized:
            print("模型已经初始化,无需重复加载")
            return True

        try:
            print("正在加载多语言情感分析模型...")

            model_name = "tabularisai/multilingual-sentiment-analysis"
            local_model_path = os.path.join(weibo_sentiment_path, "model")

            if os.path.exists(local_model_path):
                # Cached copy available: load from disk, no network needed.
                print("从本地加载模型...")
                self.tokenizer = AutoTokenizer.from_pretrained(local_model_path)
                self.model = AutoModelForSequenceClassification.from_pretrained(local_model_path)
            else:
                # First run: download from the hub, then persist locally so
                # later runs are offline-capable.
                print("首次使用,正在下载模型到本地...")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

                os.makedirs(local_model_path, exist_ok=True)
                self.tokenizer.save_pretrained(local_model_path)
                self.model.save_pretrained(local_model_path)
                print(f"模型已保存到: {local_model_path}")

            # Prefer GPU when available; eval() disables dropout for inference.
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(self.device)
            self.model.eval()
            self.is_initialized = True

            print(f"模型加载成功! 使用设备: {self.device}")
            print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言")
            print("情感等级: 非常负面、负面、中性、正面、非常正面")

            return True

        except Exception as e:
            # Broad catch is deliberate: loading can fail for many reasons
            # (network, disk, incompatible files) and the agent should
            # degrade gracefully instead of crashing.
            print(f"模型加载失败: {e}")
            print("请检查网络连接或模型文件")
            self.is_initialized = False
            return False

    def _preprocess_text(self, text: str) -> str:
        """
        Basic text cleanup before tokenization.

        Args:
            text: raw input text.

        Returns:
            Text with whitespace runs collapsed to single spaces, or ""
            for empty/blank input.
        """
        if not text or not text.strip():
            return ""

        return re.sub(r'\s+', ' ', text.strip())

    def analyze_single_text(self, text: str) -> "SentimentResult":
        """
        Analyze the sentiment of one text.

        Args:
            text: text to analyze.

        Returns:
            SentimentResult; success=False (with error_message) when the
            model is not initialized, the input is blank, or inference fails.
        """
        if not self.is_initialized:
            return SentimentResult(
                text=text,
                sentiment_label="未初始化",
                confidence=0.0,
                probability_distribution={},
                success=False,
                error_message="模型未初始化,请先调用 initialize() 方法"
            )

        try:
            processed_text = self._preprocess_text(text)

            if not processed_text:
                return SentimentResult(
                    text=text,
                    sentiment_label="输入错误",
                    confidence=0.0,
                    probability_distribution={},
                    success=False,
                    error_message="输入文本为空或无效"
                )

            # Tokenize, truncating to the model's 512-token window.
            inputs = self.tokenizer(
                processed_text,
                max_length=512,
                padding=True,
                truncation=True,
                return_tensors='pt'
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
                probabilities = torch.softmax(outputs.logits, dim=1)
                prediction = torch.argmax(probabilities, dim=1).item()

            confidence = probabilities[0][prediction].item()
            label = self.sentiment_map[prediction]

            # Full probability distribution keyed by sentiment label.
            prob_dist = {
                label_name: prob.item()
                for label_name, prob in zip(self.sentiment_map.values(), probabilities[0])
            }

            return SentimentResult(
                text=text,
                sentiment_label=label,
                confidence=confidence,
                probability_distribution=prob_dist,
                success=True
            )

        except Exception as e:
            return SentimentResult(
                text=text,
                sentiment_label="分析失败",
                confidence=0.0,
                probability_distribution={},
                success=False,
                error_message=f"预测时发生错误: {str(e)}"
            )

    def analyze_batch(self, texts: List[str], show_progress: bool = True) -> "BatchSentimentResult":
        """
        Analyze a list of texts sequentially.

        Args:
            texts: texts to analyze (may be empty).
            show_progress: print a progress line per item when analyzing
                more than one text.

        Returns:
            BatchSentimentResult with per-text results and aggregate stats.
        """
        if not texts:
            return BatchSentimentResult(
                results=[],
                total_processed=0,
                success_count=0,
                failed_count=0,
                average_confidence=0.0
            )

        results = []
        success_count = 0
        total_confidence = 0.0

        for i, text in enumerate(texts):
            if show_progress and len(texts) > 1:
                print(f"处理进度: {i+1}/{len(texts)}")

            result = self.analyze_single_text(text)
            results.append(result)

            if result.success:
                success_count += 1
                total_confidence += result.confidence

        # Average only over successful analyses; 0.0 when none succeeded.
        average_confidence = total_confidence / success_count if success_count > 0 else 0.0

        return BatchSentimentResult(
            results=results,
            total_processed=len(texts),
            success_count=success_count,
            failed_count=len(texts) - success_count,
            average_confidence=average_confidence
        )

    def analyze_query_results(self, query_results: List[Dict[str, Any]],
                              text_field: str = "content",
                              min_confidence: float = 0.5) -> Dict[str, Any]:
        """
        Run sentiment analysis over rows returned by MediaCrawlerDB.

        Args:
            query_results: result rows; each row should carry its text in
                text_field (several common fallback field names are tried).
            text_field: preferred text field name, default "content".
            min_confidence: rows at or above this confidence are included
                in high_confidence_results.

        Returns:
            {"sentiment_analysis": {...}} with counts, the label
            distribution, high-confidence rows, and a one-line summary.
        """
        if not query_results:
            return {
                "sentiment_analysis": {
                    "total_analyzed": 0,
                    "sentiment_distribution": {},
                    "high_confidence_results": [],
                    "summary": "没有内容需要分析"
                }
            }

        # Pull analyzable text out of each row, trying the preferred field
        # first and then the common fallbacks.
        texts_to_analyze = []
        original_data = []

        for item in query_results:
            text_content = ""
            for field in [text_field, "title_or_content", "content", "title", "text"]:
                if field in item and item[field]:
                    text_content = str(item[field])
                    break

            if text_content.strip():
                texts_to_analyze.append(text_content)
                original_data.append(item)

        if not texts_to_analyze:
            return {
                "sentiment_analysis": {
                    "total_analyzed": 0,
                    "sentiment_distribution": {},
                    "high_confidence_results": [],
                    "summary": "查询结果中没有找到可分析的文本内容"
                }
            }

        print(f"正在对{len(texts_to_analyze)}条内容进行情感分析...")
        batch_result = self.analyze_batch(texts_to_analyze, show_progress=True)

        # Tally the label distribution and collect high-confidence rows.
        sentiment_distribution = {}
        high_confidence_results = []

        for result, original_item in zip(batch_result.results, original_data):
            if not result.success:
                continue

            sentiment = result.sentiment_label
            sentiment_distribution[sentiment] = sentiment_distribution.get(sentiment, 0) + 1

            if result.confidence >= min_confidence:
                high_confidence_results.append({
                    "original_data": original_item,
                    "sentiment": result.sentiment_label,
                    "confidence": result.confidence,
                    "text_preview": result.text[:100] + "..." if len(result.text) > 100 else result.text
                })

        # One-line Chinese summary naming the dominant sentiment.
        total_analyzed = batch_result.success_count
        if total_analyzed > 0:
            dominant_sentiment = max(sentiment_distribution.items(), key=lambda x: x[1])
            sentiment_summary = f"共分析{total_analyzed}条内容,主要情感倾向为'{dominant_sentiment[0]}'({dominant_sentiment[1]}条,占{dominant_sentiment[1]/total_analyzed*100:.1f}%)"
        else:
            sentiment_summary = "情感分析失败"

        return {
            "sentiment_analysis": {
                "total_analyzed": total_analyzed,
                "success_rate": f"{batch_result.success_count}/{batch_result.total_processed}",
                "average_confidence": round(batch_result.average_confidence, 4),
                "sentiment_distribution": sentiment_distribution,
                "high_confidence_results": high_confidence_results,  # unbounded by design (config 0 = no limit)
                "summary": sentiment_summary
            }
        }

    def get_model_info(self) -> Dict[str, Any]:
        """Return static model metadata plus current init/device state."""
        return {
            "model_name": "tabularisai/multilingual-sentiment-analysis",
            "supported_languages": [
                "中文", "英文", "西班牙文", "阿拉伯文", "日文", "韩文",
                "德文", "法文", "意大利文", "葡萄牙文", "俄文", "荷兰文",
                "波兰文", "土耳其文", "丹麦文", "希腊文", "芬兰文",
                "瑞典文", "挪威文", "匈牙利文", "捷克文", "保加利亚文"
            ],
            "sentiment_levels": list(self.sentiment_map.values()),
            "is_initialized": self.is_initialized,
            "device": str(self.device) if self.device else "未设置"
        }
| 377 | + | ||
| 378 | + | ||
# Module-level singleton shared by the tools package; the underlying model is
# loaded lazily via initialize() (analyze_sentiment() triggers it on demand).
multilingual_sentiment_analyzer = WeiboMultilingualSentimentAnalyzer()
| 381 | + | ||
| 382 | + | ||
def analyze_sentiment(text_or_texts: Union[str, List[str]],
                      initialize_if_needed: bool = True) -> Union[SentimentResult, BatchSentimentResult]:
    """
    Convenience wrapper around the module-level analyzer.

    Args:
        text_or_texts: a single text, or a list of texts.
        initialize_if_needed: lazily load the model on first use.

    Returns:
        SentimentResult for a single text, BatchSentimentResult for a list;
        on model-load failure a failure-marked result of the matching type.
    """
    analyzer = multilingual_sentiment_analyzer
    is_single = isinstance(text_or_texts, str)

    needs_init = initialize_if_needed and not analyzer.is_initialized
    if needs_init and not analyzer.initialize():
        # Model could not be loaded: report failure without raising.
        if is_single:
            return SentimentResult(
                text=text_or_texts,
                sentiment_label="初始化失败",
                confidence=0.0,
                probability_distribution={},
                success=False,
                error_message="模型初始化失败"
            )
        return BatchSentimentResult(
            results=[],
            total_processed=0,
            success_count=0,
            failed_count=len(text_or_texts),
            average_confidence=0.0
        )

    if is_single:
        return analyzer.analyze_single_text(text_or_texts)
    return analyzer.analyze_batch(text_or_texts)
| 420 | + | ||
| 421 | + | ||
if __name__ == "__main__":
    # Smoke test: requires the model (local cache or download), so it only
    # runs when this module is executed as a script.
    analyzer = WeiboMultilingualSentimentAnalyzer()

    if analyzer.initialize():
        # Single-text analysis
        result = analyzer.analyze_single_text("今天天气真好,心情特别棒!")
        print(f"单个文本分析: {result.sentiment_label} (置信度: {result.confidence:.4f})")

        # Batch analysis across languages
        test_texts = [
            "这家餐厅的菜味道非常棒!",
            "服务态度太差了,很失望",
            "I absolutely love this product!",
            "The customer service was disappointing."
        ]

        batch_result = analyzer.analyze_batch(test_texts)
        print(f"\n批量分析: 成功 {batch_result.success_count}/{batch_result.total_processed}")

        for result in batch_result.results:
            print(f"'{result.text[:30]}...' -> {result.sentiment_label} ({result.confidence:.4f})")
    else:
        print("模型初始化失败,无法进行测试")
| @@ -14,6 +14,7 @@ class Config: | @@ -14,6 +14,7 @@ class Config: | ||
| 14 | # API密钥 | 14 | # API密钥 |
| 15 | deepseek_api_key: Optional[str] = None | 15 | deepseek_api_key: Optional[str] = None |
| 16 | openai_api_key: Optional[str] = None | 16 | openai_api_key: Optional[str] = None |
| 17 | + kimi_api_key: Optional[str] = None | ||
| 17 | 18 | ||
| 18 | # 数据库配置 | 19 | # 数据库配置 |
| 19 | db_host: Optional[str] = None | 20 | db_host: Optional[str] = None |
| @@ -24,13 +25,14 @@ class Config: | @@ -24,13 +25,14 @@ class Config: | ||
| 24 | db_charset: str = "utf8mb4" | 25 | db_charset: str = "utf8mb4" |
| 25 | 26 | ||
| 26 | # 模型配置 | 27 | # 模型配置 |
| 27 | - default_llm_provider: str = "deepseek" # deepseek 或 openai | 28 | + default_llm_provider: str = "deepseek" # deepseek、openai 或 kimi |
| 28 | deepseek_model: str = "deepseek-chat" | 29 | deepseek_model: str = "deepseek-chat" |
| 29 | openai_model: str = "gpt-4o-mini" | 30 | openai_model: str = "gpt-4o-mini" |
| 31 | + kimi_model: str = "kimi-k2-0711-preview" | ||
| 30 | 32 | ||
| 31 | # 搜索配置 | 33 | # 搜索配置 |
| 32 | search_timeout: int = 240 | 34 | search_timeout: int = 240 |
| 33 | - max_content_length: int = 100000 | 35 | + max_content_length: int = 500000 # 提高5倍以充分利用Kimi的长文本能力 |
| 34 | 36 | ||
| 35 | # 数据库查询限制 | 37 | # 数据库查询限制 |
| 36 | default_search_hot_content_limit: int = 100 | 38 | default_search_hot_content_limit: int = 100 |
| @@ -43,6 +45,10 @@ class Config: | @@ -43,6 +45,10 @@ class Config: | ||
| 43 | max_reflections: int = 3 | 45 | max_reflections: int = 3 |
| 44 | max_paragraphs: int = 6 | 46 | max_paragraphs: int = 6 |
| 45 | 47 | ||
| 48 | + # 结果处理限制 | ||
| 49 | + max_search_results_for_llm: int = 0 # 0表示不限制,传递所有搜索结果给LLM | ||
| 50 | + max_high_confidence_sentiment_results: int = 0 # 0表示不限制,返回所有高置信度情感分析结果 | ||
| 51 | + | ||
| 46 | # 输出配置 | 52 | # 输出配置 |
| 47 | output_dir: str = "reports" | 53 | output_dir: str = "reports" |
| 48 | save_intermediate_states: bool = True | 54 | save_intermediate_states: bool = True |
| @@ -102,6 +108,10 @@ class Config: | @@ -102,6 +108,10 @@ class Config: | ||
| 102 | 108 | ||
| 103 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), | 109 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), |
| 104 | max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5), | 110 | max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5), |
| 111 | + | ||
| 112 | + max_search_results_for_llm=getattr(config_module, "MAX_SEARCH_RESULTS_FOR_LLM", 0), | ||
| 113 | + max_high_confidence_sentiment_results=getattr(config_module, "MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", 0), | ||
| 114 | + | ||
| 105 | output_dir=getattr(config_module, "OUTPUT_DIR", "reports"), | 115 | output_dir=getattr(config_module, "OUTPUT_DIR", "reports"), |
| 106 | save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True) | 116 | save_intermediate_states=getattr(config_module, "SAVE_INTERMEDIATE_STATES", True) |
| 107 | ) | 117 | ) |
| @@ -120,6 +130,7 @@ class Config: | @@ -120,6 +130,7 @@ class Config: | ||
| 120 | return cls( | 130 | return cls( |
| 121 | deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), | 131 | deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), |
| 122 | openai_api_key=config_dict.get("OPENAI_API_KEY"), | 132 | openai_api_key=config_dict.get("OPENAI_API_KEY"), |
| 133 | + kimi_api_key=config_dict.get("KIMI_API_KEY"), | ||
| 123 | 134 | ||
| 124 | db_host=config_dict.get("DB_HOST"), | 135 | db_host=config_dict.get("DB_HOST"), |
| 125 | db_user=config_dict.get("DB_USER"), | 136 | db_user=config_dict.get("DB_USER"), |
| @@ -131,9 +142,10 @@ class Config: | @@ -131,9 +142,10 @@ class Config: | ||
| 131 | default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), | 142 | default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), |
| 132 | deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), | 143 | deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), |
| 133 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), | 144 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), |
| 145 | + kimi_model=config_dict.get("KIMI_MODEL", "kimi-k2-0711-preview"), | ||
| 134 | 146 | ||
| 135 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), | 147 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), |
| 136 | - max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "200000")), | 148 | + max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "500000")), |
| 137 | 149 | ||
| 138 | default_search_hot_content_limit=int(config_dict.get("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", "100")), | 150 | default_search_hot_content_limit=int(config_dict.get("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", "100")), |
| 139 | default_search_topic_globally_limit_per_table=int(config_dict.get("DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", "50")), | 151 | default_search_topic_globally_limit_per_table=int(config_dict.get("DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", "50")), |
| @@ -143,6 +155,10 @@ class Config: | @@ -143,6 +155,10 @@ class Config: | ||
| 143 | 155 | ||
| 144 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), | 156 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), |
| 145 | max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")), | 157 | max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")), |
| 158 | + | ||
| 159 | + max_search_results_for_llm=int(config_dict.get("MAX_SEARCH_RESULTS_FOR_LLM", "0")), | ||
| 160 | + max_high_confidence_sentiment_results=int(config_dict.get("MAX_HIGH_CONFIDENCE_SENTIMENT_RESULTS", "0")), | ||
| 161 | + | ||
| 146 | output_dir=config_dict.get("OUTPUT_DIR", "reports"), | 162 | output_dir=config_dict.get("OUTPUT_DIR", "reports"), |
| 147 | save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true" | 163 | save_intermediate_states=config_dict.get("SAVE_INTERMEDIATE_STATES", "true").lower() == "true" |
| 148 | ) | 164 | ) |
| @@ -9,7 +9,7 @@ import re | @@ -9,7 +9,7 @@ import re | ||
| 9 | from datetime import datetime | 9 | from datetime import datetime |
| 10 | from typing import Optional, Dict, Any, List | 10 | from typing import Optional, Dict, Any, List |
| 11 | 11 | ||
| 12 | -from .llms import DeepSeekLLM, OpenAILLM, BaseLLM | 12 | +from .llms import DeepSeekLLM, OpenAILLM, GeminiLLM, BaseLLM |
| 13 | from .nodes import ( | 13 | from .nodes import ( |
| 14 | ReportStructureNode, | 14 | ReportStructureNode, |
| 15 | FirstSearchNode, | 15 | FirstSearchNode, |
| @@ -67,6 +67,11 @@ class DeepSearchAgent: | @@ -67,6 +67,11 @@ class DeepSearchAgent: | ||
| 67 | api_key=self.config.openai_api_key, | 67 | api_key=self.config.openai_api_key, |
| 68 | model_name=self.config.openai_model | 68 | model_name=self.config.openai_model |
| 69 | ) | 69 | ) |
| 70 | + elif self.config.default_llm_provider == "gemini": | ||
| 71 | + return GeminiLLM( | ||
| 72 | + api_key=self.config.gemini_api_key, | ||
| 73 | + model_name=self.config.gemini_model | ||
| 74 | + ) | ||
| 70 | else: | 75 | else: |
| 71 | raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") | 76 | raise ValueError(f"不支持的LLM提供商: {self.config.default_llm_provider}") |
| 72 | 77 |
| @@ -6,5 +6,6 @@ LLM调用模块 | @@ -6,5 +6,6 @@ LLM调用模块 | ||
| 6 | from .base import BaseLLM | 6 | from .base import BaseLLM |
| 7 | from .deepseek import DeepSeekLLM | 7 | from .deepseek import DeepSeekLLM |
| 8 | from .openai_llm import OpenAILLM | 8 | from .openai_llm import OpenAILLM |
| 9 | +from .gemini_llm import GeminiLLM | ||
| 9 | 10 | ||
| 10 | -__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM"] | 11 | +__all__ = ["BaseLLM", "DeepSeekLLM", "OpenAILLM", "GeminiLLM"] |
MediaEngine/llms/gemini_llm.py
0 → 100644
| 1 | +""" | ||
| 2 | +Gemini LLM实现 | ||
| 3 | +使用Gemini 2.5-pro中转API进行文本生成 | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import os | ||
| 7 | +from typing import Optional, Dict, Any | ||
| 8 | +from openai import OpenAI | ||
| 9 | +from .base import BaseLLM | ||
| 10 | + | ||
| 11 | + | ||
| 12 | +class GeminiLLM(BaseLLM): | ||
| 13 | + """Gemini LLM实现类""" | ||
| 14 | + | ||
| 15 | + def __init__(self, api_key: Optional[str] = None, model_name: Optional[str] = None): | ||
| 16 | + """ | ||
| 17 | + 初始化Gemini客户端 | ||
| 18 | + | ||
| 19 | + Args: | ||
| 20 | + api_key: Gemini API密钥,如果不提供则从环境变量读取 | ||
| 21 | + model_name: 模型名称,默认使用gemini-2.5-pro | ||
| 22 | + """ | ||
| 23 | + if api_key is None: | ||
| 24 | + api_key = os.getenv("GEMINI_API_KEY") | ||
| 25 | + if not api_key: | ||
| 26 | + raise ValueError("Gemini API Key未找到!请设置GEMINI_API_KEY环境变量或在初始化时提供") | ||
| 27 | + | ||
| 28 | + super().__init__(api_key, model_name) | ||
| 29 | + | ||
| 30 | + # 初始化OpenAI客户端,使用Gemini的中转endpoint | ||
| 31 | + self.client = OpenAI( | ||
| 32 | + api_key=self.api_key, | ||
| 33 | + base_url="https://www.chataiapi.com/v1" | ||
| 34 | + ) | ||
| 35 | + | ||
| 36 | + self.default_model = model_name or self.get_default_model() | ||
| 37 | + | ||
| 38 | + def get_default_model(self) -> str: | ||
| 39 | + """获取默认模型名称""" | ||
| 40 | + return "gemini-2.5-pro" | ||
| 41 | + | ||
| 42 | + def invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> str: | ||
| 43 | + """ | ||
| 44 | + 调用Gemini API生成回复 | ||
| 45 | + | ||
| 46 | + Args: | ||
| 47 | + system_prompt: 系统提示词 | ||
| 48 | + user_prompt: 用户输入 | ||
| 49 | + **kwargs: 其他参数,如temperature、max_tokens等 | ||
| 50 | + | ||
| 51 | + Returns: | ||
| 52 | + Gemini生成的回复文本 | ||
| 53 | + """ | ||
| 54 | + try: | ||
| 55 | + # 构建消息 | ||
| 56 | + messages = [ | ||
| 57 | + {"role": "system", "content": system_prompt}, | ||
| 58 | + {"role": "user", "content": user_prompt} | ||
| 59 | + ] | ||
| 60 | + | ||
| 61 | + # 设置默认参数 | ||
| 62 | + params = { | ||
| 63 | + "model": self.default_model, | ||
| 64 | + "messages": messages, | ||
| 65 | + "temperature": kwargs.get("temperature", 0.7), | ||
| 66 | + "max_tokens": kwargs.get("max_tokens", 4000), | ||
| 67 | + "stream": False | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + # 调用API | ||
| 71 | + response = self.client.chat.completions.create(**params) | ||
| 72 | + | ||
| 73 | + # 提取回复内容 | ||
| 74 | + if response.choices and response.choices[0].message: | ||
| 75 | + content = response.choices[0].message.content | ||
| 76 | + return self.validate_response(content) | ||
| 77 | + else: | ||
| 78 | + return "" | ||
| 79 | + | ||
| 80 | + except Exception as e: | ||
| 81 | + print(f"Gemini API调用错误: {str(e)}") | ||
| 82 | + raise e | ||
| 83 | + | ||
| 84 | + def get_model_info(self) -> Dict[str, Any]: | ||
| 85 | + """ | ||
| 86 | + 获取当前模型信息 | ||
| 87 | + | ||
| 88 | + Returns: | ||
| 89 | + 模型信息字典 | ||
| 90 | + """ | ||
| 91 | + return { | ||
| 92 | + "provider": "Gemini", | ||
| 93 | + "model": self.default_model, | ||
| 94 | + "api_base": "https://www.chataiapi.com/v1" | ||
| 95 | + } |
| @@ -14,12 +14,14 @@ class Config: | @@ -14,12 +14,14 @@ class Config: | ||
| 14 | # API密钥 | 14 | # API密钥 |
| 15 | deepseek_api_key: Optional[str] = None | 15 | deepseek_api_key: Optional[str] = None |
| 16 | openai_api_key: Optional[str] = None | 16 | openai_api_key: Optional[str] = None |
| 17 | + gemini_api_key: Optional[str] = None | ||
| 17 | bocha_api_key: Optional[str] = None | 18 | bocha_api_key: Optional[str] = None |
| 18 | 19 | ||
| 19 | # 模型配置 | 20 | # 模型配置 |
| 20 | - default_llm_provider: str = "deepseek" # deepseek 或 openai | 21 | + default_llm_provider: str = "deepseek" # deepseek、openai 或 gemini |
| 21 | deepseek_model: str = "deepseek-chat" | 22 | deepseek_model: str = "deepseek-chat" |
| 22 | openai_model: str = "gpt-4o-mini" | 23 | openai_model: str = "gpt-4o-mini" |
| 24 | + gemini_model: str = "gemini-2.5-pro" | ||
| 23 | 25 | ||
| 24 | # 搜索配置 | 26 | # 搜索配置 |
| 25 | search_timeout: int = 240 | 27 | search_timeout: int = 240 |
| @@ -44,6 +46,10 @@ class Config: | @@ -44,6 +46,10 @@ class Config: | ||
| 44 | print("错误: OpenAI API Key未设置") | 46 | print("错误: OpenAI API Key未设置") |
| 45 | return False | 47 | return False |
| 46 | 48 | ||
| 49 | + if self.default_llm_provider == "gemini" and not self.gemini_api_key: | ||
| 50 | + print("错误: Gemini API Key未设置") | ||
| 51 | + return False | ||
| 52 | + | ||
| 47 | if not self.bocha_api_key: | 53 | if not self.bocha_api_key: |
| 48 | print("错误: Bocha API Key未设置") | 54 | print("错误: Bocha API Key未设置") |
| 49 | return False | 55 | return False |
| @@ -65,11 +71,12 @@ class Config: | @@ -65,11 +71,12 @@ class Config: | ||
| 65 | return cls( | 71 | return cls( |
| 66 | deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None), | 72 | deepseek_api_key=getattr(config_module, "DEEPSEEK_API_KEY", None), |
| 67 | openai_api_key=getattr(config_module, "OPENAI_API_KEY", None), | 73 | openai_api_key=getattr(config_module, "OPENAI_API_KEY", None), |
| 74 | + gemini_api_key=getattr(config_module, "GEMINI_API_KEY", None), | ||
| 68 | bocha_api_key=getattr(config_module, "BOCHA_API_KEY", None), | 75 | bocha_api_key=getattr(config_module, "BOCHA_API_KEY", None), |
| 69 | default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"), | 76 | default_llm_provider=getattr(config_module, "DEFAULT_LLM_PROVIDER", "deepseek"), |
| 70 | deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"), | 77 | deepseek_model=getattr(config_module, "DEEPSEEK_MODEL", "deepseek-chat"), |
| 71 | openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), | 78 | openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), |
| 72 | - | 79 | + gemini_model=getattr(config_module, "GEMINI_MODEL", "gemini-2.5-pro"), |
| 73 | search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240), | 80 | search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240), |
| 74 | max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000), | 81 | max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000), |
| 75 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), | 82 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), |
| @@ -92,11 +99,12 @@ class Config: | @@ -92,11 +99,12 @@ class Config: | ||
| 92 | return cls( | 99 | return cls( |
| 93 | deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), | 100 | deepseek_api_key=config_dict.get("DEEPSEEK_API_KEY"), |
| 94 | openai_api_key=config_dict.get("OPENAI_API_KEY"), | 101 | openai_api_key=config_dict.get("OPENAI_API_KEY"), |
| 102 | + gemini_api_key=config_dict.get("GEMINI_API_KEY"), | ||
| 95 | bocha_api_key=config_dict.get("BOCHA_API_KEY"), | 103 | bocha_api_key=config_dict.get("BOCHA_API_KEY"), |
| 96 | default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), | 104 | default_llm_provider=config_dict.get("DEFAULT_LLM_PROVIDER", "deepseek"), |
| 97 | deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), | 105 | deepseek_model=config_dict.get("DEEPSEEK_MODEL", "deepseek-chat"), |
| 98 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), | 106 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), |
| 99 | - | 107 | + gemini_model=config_dict.get("GEMINI_MODEL", "gemini-2.5-pro"), |
| 100 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), | 108 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), |
| 101 | max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")), | 109 | max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")), |
| 102 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), | 110 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), |
| @@ -161,5 +161,30 @@ def show_multilingual_demo(tokenizer, model, device, sentiment_map): | @@ -161,5 +161,30 @@ def show_multilingual_demo(tokenizer, model, device, sentiment_map): | ||
| 161 | 161 | ||
| 162 | print("\n=== 示例结束 ===") | 162 | print("\n=== 示例结束 ===") |
| 163 | 163 | ||
| 164 | + ''' | ||
| 165 | + 正在加载多语言情感分析模型... | ||
| 166 | +从本地加载模型... | ||
| 167 | +模型加载成功! 使用设备: cuda | ||
| 168 | + | ||
| 169 | +============= 多语言情感分析 ============= | ||
| 170 | +支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言 | ||
| 171 | +情感等级: 非常负面、负面、中性、正面、非常正面 | ||
| 172 | +输入文本进行分析 (输入 'q' 退出): | ||
| 173 | +输入 'demo' 查看多语言示例 | ||
| 174 | + | ||
| 175 | +请输入文本: 我喜欢你 | ||
| 176 | +C:\Users\67093\.conda\envs\pytorch_python11\Lib\site-packages\transformers\models\distilbert\modeling_distilbert.py:401: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.) | ||
| 177 | + attn_output = torch.nn.functional.scaled_dot_product_attention( | ||
| 178 | +预测结果: 正面 (置信度: 0.5204) | ||
| 179 | +详细概率分布: | ||
| 180 | + 非常负面: 0.0329 | ||
| 181 | + 负面: 0.0263 | ||
| 182 | + 中性: 0.1987 | ||
| 183 | + 正面: 0.5204 | ||
| 184 | + 非常正面: 0.2216 | ||
| 185 | + | ||
| 186 | +请输入文本: | ||
| 187 | + ''' | ||
| 188 | + | ||
| 164 | if __name__ == "__main__": | 189 | if __name__ == "__main__": |
| 165 | main() | 190 | main() |
| @@ -10,10 +10,10 @@ from datetime import datetime | @@ -10,10 +10,10 @@ from datetime import datetime | ||
| 10 | import json | 10 | import json |
| 11 | 11 | ||
| 12 | # 添加src目录到Python路径 | 12 | # 添加src目录到Python路径 |
| 13 | -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.')) | 13 | +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
| 14 | 14 | ||
| 15 | from InsightEngine import DeepSearchAgent, Config | 15 | from InsightEngine import DeepSearchAgent, Config |
| 16 | -from config import DEEPSEEK_API_KEY, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT, DB_CHARSET | 16 | +from config import DEEPSEEK_API_KEY, KIMI_API_KEY, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME, DB_PORT, DB_CHARSET |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | def main(): | 19 | def main(): |
| @@ -31,20 +31,38 @@ def main(): | @@ -31,20 +31,38 @@ def main(): | ||
| 31 | with st.sidebar: | 31 | with st.sidebar: |
| 32 | st.header("配置") | 32 | st.header("配置") |
| 33 | 33 | ||
| 34 | + # 模型选择 | ||
| 35 | + llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "kimi"]) | ||
| 36 | + | ||
| 34 | # 高级配置 | 37 | # 高级配置 |
| 35 | st.subheader("高级配置") | 38 | st.subheader("高级配置") |
| 36 | max_reflections = st.slider("反思次数", 1, 5, 2) | 39 | max_reflections = st.slider("反思次数", 1, 5, 2) |
| 37 | - max_content_length = st.number_input("最大内容长度", 10000, 500000, 200000) # 提高10倍:1000-50000-20000 → 10000-500000-200000 | ||
| 38 | 40 | ||
| 39 | - # 模型选择 | ||
| 40 | - llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"]) | 41 | + # 根据选择的模型动态调整默认值 |
| 42 | + if llm_provider == "kimi": | ||
| 43 | + default_content_length = 500000 # Kimi支持长文本,使用更大的默认值 | ||
| 44 | + max_limit = 1000000 # 提高上限 | ||
| 45 | + st.info("💡 Kimi模型支持超长文本处理,建议使用更大的内容长度以充分利用其能力") | ||
| 46 | + else: | ||
| 47 | + default_content_length = 200000 | ||
| 48 | + max_limit = 500000 | ||
| 49 | + | ||
| 50 | + max_content_length = st.number_input("最大内容长度", 10000, max_limit, default_content_length) | ||
| 51 | + | ||
| 52 | + # 初始化所有可能的变量 | ||
| 53 | + openai_key = "" | ||
| 54 | + kimi_key = "" | ||
| 41 | 55 | ||
| 42 | if llm_provider == "deepseek": | 56 | if llm_provider == "deepseek": |
| 43 | model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"]) | 57 | model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"]) |
| 44 | - else: | 58 | + elif llm_provider == "openai": |
| 45 | model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"]) | 59 | model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"]) |
| 46 | openai_key = st.text_input("OpenAI API Key", type="password", | 60 | openai_key = st.text_input("OpenAI API Key", type="password", |
| 47 | value="") | 61 | value="") |
| 62 | + else: # kimi | ||
| 63 | + model_name = st.selectbox("Kimi模型", ["kimi-k2-0711-preview"]) | ||
| 64 | + kimi_key = st.text_input("Kimi API Key", type="password", | ||
| 65 | + value="") | ||
| 48 | 66 | ||
| 49 | # 主界面 | 67 | # 主界面 |
| 50 | col1, col2 = st.columns([2, 1]) | 68 | col1, col2 = st.columns([2, 1]) |
| @@ -96,8 +114,13 @@ def main(): | @@ -96,8 +114,13 @@ def main(): | ||
| 96 | st.error("请提供OpenAI API Key") | 114 | st.error("请提供OpenAI API Key") |
| 97 | return | 115 | return |
| 98 | 116 | ||
| 117 | + if llm_provider == "kimi" and not kimi_key and not KIMI_API_KEY: | ||
| 118 | + st.error("请提供Kimi API Key或在配置文件中设置KIMI_API_KEY") | ||
| 119 | + return | ||
| 120 | + | ||
| 99 | # 自动使用配置文件中的API密钥和数据库配置 | 121 | # 自动使用配置文件中的API密钥和数据库配置 |
| 100 | deepseek_key = DEEPSEEK_API_KEY | 122 | deepseek_key = DEEPSEEK_API_KEY |
| 123 | + kimi_key_final = kimi_key if kimi_key else KIMI_API_KEY | ||
| 101 | db_host = DB_HOST | 124 | db_host = DB_HOST |
| 102 | db_user = DB_USER | 125 | db_user = DB_USER |
| 103 | db_password = DB_PASSWORD | 126 | db_password = DB_PASSWORD |
| @@ -109,6 +132,7 @@ def main(): | @@ -109,6 +132,7 @@ def main(): | ||
| 109 | config = Config( | 132 | config = Config( |
| 110 | deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None, | 133 | deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None, |
| 111 | openai_api_key=openai_key if llm_provider == "openai" else None, | 134 | openai_api_key=openai_key if llm_provider == "openai" else None, |
| 135 | + kimi_api_key=kimi_key_final if llm_provider == "kimi" else None, | ||
| 112 | db_host=db_host, | 136 | db_host=db_host, |
| 113 | db_user=db_user, | 137 | db_user=db_user, |
| 114 | db_password=db_password, | 138 | db_password=db_password, |
| @@ -118,6 +142,7 @@ def main(): | @@ -118,6 +142,7 @@ def main(): | ||
| 118 | default_llm_provider=llm_provider, | 142 | default_llm_provider=llm_provider, |
| 119 | deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat", | 143 | deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat", |
| 120 | openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini", | 144 | openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini", |
| 145 | + kimi_model=model_name if llm_provider == "kimi" else "kimi-k2-0711-preview", | ||
| 121 | max_reflections=max_reflections, | 146 | max_reflections=max_reflections, |
| 122 | max_content_length=max_content_length, | 147 | max_content_length=max_content_length, |
| 123 | output_dir="insight_engine_streamlit_reports" | 148 | output_dir="insight_engine_streamlit_reports" |
| @@ -13,7 +13,7 @@ import json | @@ -13,7 +13,7 @@ import json | ||
| 13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) | 13 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
| 14 | 14 | ||
| 15 | from MediaEngine import DeepSearchAgent, Config | 15 | from MediaEngine import DeepSearchAgent, Config |
| 16 | -from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY | 16 | +from config import DEEPSEEK_API_KEY, BOCHA_Web_Search_API_KEY, GEMINI_API_KEY |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | def main(): | 19 | def main(): |
| @@ -37,14 +37,16 @@ def main(): | @@ -37,14 +37,16 @@ def main(): | ||
| 37 | max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000) | 37 | max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000) |
| 38 | 38 | ||
| 39 | # 模型选择 | 39 | # 模型选择 |
| 40 | - llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"]) | 40 | + llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai", "gemini"]) |
| 41 | 41 | ||
| 42 | + openai_key = "" # 初始化变量 | ||
| 42 | if llm_provider == "deepseek": | 43 | if llm_provider == "deepseek": |
| 43 | model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"]) | 44 | model_name = st.selectbox("DeepSeek模型", ["deepseek-chat"]) |
| 44 | - else: | 45 | + elif llm_provider == "openai": |
| 45 | model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"]) | 46 | model_name = st.selectbox("OpenAI模型", ["gpt-4o-mini", "gpt-4o"]) |
| 46 | - openai_key = st.text_input("OpenAI API Key", type="password", | ||
| 47 | - value="") | 47 | + openai_key = st.text_input("OpenAI API Key", type="password", value="") |
| 48 | + else: # gemini | ||
| 49 | + model_name = st.selectbox("Gemini模型", ["gemini-2.5-pro"]) | ||
| 48 | 50 | ||
| 49 | # 主界面 | 51 | # 主界面 |
| 50 | col1, col2 = st.columns([2, 1]) | 52 | col1, col2 = st.columns([2, 1]) |
| @@ -98,16 +100,19 @@ def main(): | @@ -98,16 +100,19 @@ def main(): | ||
| 98 | 100 | ||
| 99 | # 自动使用配置文件中的API密钥 | 101 | # 自动使用配置文件中的API密钥 |
| 100 | deepseek_key = DEEPSEEK_API_KEY | 102 | deepseek_key = DEEPSEEK_API_KEY |
| 103 | + gemini_key = GEMINI_API_KEY # 使用config.py中的Gemini API密钥 | ||
| 101 | bocha_key = BOCHA_Web_Search_API_KEY | 104 | bocha_key = BOCHA_Web_Search_API_KEY |
| 102 | 105 | ||
| 103 | # 创建配置 | 106 | # 创建配置 |
| 104 | config = Config( | 107 | config = Config( |
| 105 | deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None, | 108 | deepseek_api_key=deepseek_key if llm_provider == "deepseek" else None, |
| 106 | openai_api_key=openai_key if llm_provider == "openai" else None, | 109 | openai_api_key=openai_key if llm_provider == "openai" else None, |
| 110 | + gemini_api_key=gemini_key if llm_provider == "gemini" else None, | ||
| 107 | bocha_api_key=bocha_key, | 111 | bocha_api_key=bocha_key, |
| 108 | default_llm_provider=llm_provider, | 112 | default_llm_provider=llm_provider, |
| 109 | deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat", | 113 | deepseek_model=model_name if llm_provider == "deepseek" else "deepseek-chat", |
| 110 | openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini", | 114 | openai_model=model_name if llm_provider == "openai" else "gpt-4o-mini", |
| 115 | + gemini_model=model_name if llm_provider == "gemini" else "gemini-2.5-pro", | ||
| 111 | max_reflections=max_reflections, | 116 | max_reflections=max_reflections, |
| 112 | max_content_length=max_content_length, | 117 | max_content_length=max_content_length, |
| 113 | output_dir="media_engine_streamlit_reports" | 118 | output_dir="media_engine_streamlit_reports" |
| 1 | # -*- coding: utf-8 -*- | 1 | # -*- coding: utf-8 -*- |
| 2 | """ | 2 | """ |
| 3 | -智能舆情分析平台配置文件 | ||
| 4 | -存储数据库连接信息和API密钥 | 3 | +Intelligence Public Opinion Analysis Platform Configuration File |
| 4 | +Stores database connection information and API keys | ||
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | -# MySQL数据库配置 | ||
| 8 | -DB_HOST = "rm-2zeib6b13f6tt9kncoo.mysql.rds.aliyuncs.com" | 7 | +# MySQL Database Configuration |
| 8 | +DB_HOST = "your_database_host" # e.g., "localhost" or "127.0.0.1" | ||
| 9 | DB_PORT = 3306 | 9 | DB_PORT = 3306 |
| 10 | -DB_USER = "root" | ||
| 11 | -DB_PASSWORD = "mneDccc7sHHANtFk" | ||
| 12 | -DB_NAME = "media_crawler" | 10 | +DB_USER = "your_database_user" |
| 11 | +DB_PASSWORD = "your_database_password" | ||
| 12 | +DB_NAME = "your_database_name" | ||
| 13 | DB_CHARSET = "utf8mb4" | 13 | DB_CHARSET = "utf8mb4" |
| 14 | 14 | ||
| 15 | -# agent1 DeepSeek API密钥 | ||
| 16 | -DEEPSEEK_API_KEY = "sk-4bbc57fadd234666a3840f1a7edc1f2e" | 15 | +# DeepSeek API Key |
| 16 | +# 申请地址https://www.deepseek.com/ | ||
| 17 | +DEEPSEEK_API_KEY = "your_deepseek_api_key" | ||
| 17 | 18 | ||
| 18 | -# agent2 DeepSeek API密钥 | ||
| 19 | -DEEPSEEK_API_KEY_2 = "sk-b26405d2e02f475c960d21c2acce61e7" | 19 | +# Tavily Search API Key |
| 20 | +# 申请地址https://www.tavily.com/ | ||
| 21 | +TAVILY_API_KEY = "your_tavily_api_key" | ||
| 20 | 22 | ||
| 21 | -# Tavily搜索API密钥 | ||
| 22 | -TAVILY_API_KEY = "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" | 23 | +# Kimi API Key |
| 24 | +# 申请地址https://www.kimi.com/ | ||
| 25 | +KIMI_API_KEY = "your_kimi_api_key" | ||
| 23 | 26 | ||
| 24 | -# 博查Web Search API密钥 | ||
| 25 | -BOCHA_Web_Search_API_KEY = "sk-496b37a2a1ee4915b438dd822b03de8d" | ||
| 27 | +# Gemini API Key (via OpenAI format proxy) | ||
| 28 | +# 申请地址:https://www.chataiapi.com/ | ||
| 29 | +GEMINI_API_KEY = "your_gemini_api_key" | ||
| 30 | + | ||
| 31 | +# Bocha Search API Key | ||
| 32 | +# 申请地址https://open.bochaai.com/ | ||
| 33 | +BOCHA_Web_Search_API_KEY = "your_bocha_web_search_api_key" | ||
| 34 | + | ||
| 35 | +# Guiji Flow API Key | ||
| 36 | +# 申请地址https://siliconflow.cn/ | ||
| 37 | +GUIJI_QWEN3_API_KEY = "your_guiji_qwen3_api_key" |
-
Please register or login to post a comment