Showing
7 changed files
with
437 additions
and
54 deletions
| @@ -19,7 +19,7 @@ from .nodes import ( | @@ -19,7 +19,7 @@ from .nodes import ( | ||
| 19 | ReportFormattingNode | 19 | ReportFormattingNode |
| 20 | ) | 20 | ) |
| 21 | from .state import State | 21 | from .state import State |
| 22 | -from .tools import MediaCrawlerDB, DBResponse | 22 | +from .tools import MediaCrawlerDB, DBResponse, keyword_optimizer |
| 23 | from .utils import Config, load_config, format_search_results_for_prompt | 23 | from .utils import Config, load_config, format_search_results_for_prompt |
| 24 | 24 | ||
| 25 | 25 | ||
| @@ -113,7 +113,7 @@ class DeepSearchAgent: | @@ -113,7 +113,7 @@ class DeepSearchAgent: | ||
| 113 | 113 | ||
| 114 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: | 114 | def execute_search_tool(self, tool_name: str, query: str, **kwargs) -> DBResponse: |
| 115 | """ | 115 | """ |
| 116 | - 执行指定的数据库查询工具 | 116 | + 执行指定的数据库查询工具(集成关键词优化中间件) |
| 117 | 117 | ||
| 118 | Args: | 118 | Args: |
| 119 | tool_name: 工具名称,可选值: | 119 | tool_name: 工具名称,可选值: |
| @@ -130,34 +130,102 @@ class DeepSearchAgent: | @@ -130,34 +130,102 @@ class DeepSearchAgent: | ||
| 130 | """ | 130 | """ |
| 131 | print(f" → 执行数据库查询工具: {tool_name}") | 131 | print(f" → 执行数据库查询工具: {tool_name}") |
| 132 | 132 | ||
| 133 | + # 对于热点内容搜索,不需要关键词优化(因为不需要query参数) | ||
| 133 | if tool_name == "search_hot_content": | 134 | if tool_name == "search_hot_content": |
| 134 | time_period = kwargs.get("time_period", "week") | 135 | time_period = kwargs.get("time_period", "week") |
| 135 | - limit = kwargs.get("limit", 10) | 136 | + limit = kwargs.get("limit", 100) |
| 136 | return self.search_agency.search_hot_content(time_period=time_period, limit=limit) | 137 | return self.search_agency.search_hot_content(time_period=time_period, limit=limit) |
| 137 | - elif tool_name == "search_topic_globally": | ||
| 138 | - limit_per_table = kwargs.get("limit_per_table", 5) | ||
| 139 | - return self.search_agency.search_topic_globally(topic=query, limit_per_table=limit_per_table) | ||
| 140 | - elif tool_name == "search_topic_by_date": | ||
| 141 | - start_date = kwargs.get("start_date") | ||
| 142 | - end_date = kwargs.get("end_date") | ||
| 143 | - limit_per_table = kwargs.get("limit_per_table", 10) | ||
| 144 | - if not start_date or not end_date: | ||
| 145 | - raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | ||
| 146 | - return self.search_agency.search_topic_by_date(topic=query, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | ||
| 147 | - elif tool_name == "get_comments_for_topic": | ||
| 148 | - limit = kwargs.get("limit", 50) | ||
| 149 | - return self.search_agency.get_comments_for_topic(topic=query, limit=limit) | ||
| 150 | - elif tool_name == "search_topic_on_platform": | ||
| 151 | - platform = kwargs.get("platform") | ||
| 152 | - start_date = kwargs.get("start_date") | ||
| 153 | - end_date = kwargs.get("end_date") | ||
| 154 | - limit = kwargs.get("limit", 20) | ||
| 155 | - if not platform: | ||
| 156 | - raise ValueError("search_topic_on_platform工具需要platform参数") | ||
| 157 | - return self.search_agency.search_topic_on_platform(platform=platform, topic=query, start_date=start_date, end_date=end_date, limit=limit) | ||
| 158 | - else: | ||
| 159 | - print(f" ⚠️ 未知的搜索工具: {tool_name},使用默认全局搜索") | ||
| 160 | - return self.search_agency.search_topic_globally(topic=query) | 138 | + |
| 139 | + # 对于需要搜索词的工具,使用关键词优化中间件 | ||
| 140 | + optimized_response = keyword_optimizer.optimize_keywords( | ||
| 141 | + original_query=query, | ||
| 142 | + context=f"使用{tool_name}工具进行查询" | ||
| 143 | + ) | ||
| 144 | + | ||
| 145 | + print(f" 🔍 原始查询: '{query}'") | ||
| 146 | + print(f" ✨ 优化后关键词: {optimized_response.optimized_keywords}") | ||
| 147 | + | ||
| 148 | + # 使用优化后的关键词进行多次查询并整合结果 | ||
| 149 | + all_results = [] | ||
| 150 | + total_count = 0 | ||
| 151 | + | ||
| 152 | + for keyword in optimized_response.optimized_keywords: | ||
| 153 | + print(f" 查询关键词: '{keyword}'") | ||
| 154 | + | ||
| 155 | + try: | ||
| 156 | + if tool_name == "search_topic_globally": | ||
| 157 | + limit_per_table = kwargs.get("limit_per_table", 100) | ||
| 158 | + response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=limit_per_table) | ||
| 159 | + elif tool_name == "search_topic_by_date": | ||
| 160 | + start_date = kwargs.get("start_date") | ||
| 161 | + end_date = kwargs.get("end_date") | ||
| 162 | + limit_per_table = kwargs.get("limit_per_table", 100) | ||
| 163 | + if not start_date or not end_date: | ||
| 164 | + raise ValueError("search_topic_by_date工具需要start_date和end_date参数") | ||
| 165 | + response = self.search_agency.search_topic_by_date(topic=keyword, start_date=start_date, end_date=end_date, limit_per_table=limit_per_table) | ||
| 166 | + elif tool_name == "get_comments_for_topic": | ||
| 167 | + limit = kwargs.get("limit", 500) // len(optimized_response.optimized_keywords) | ||
| 168 | + limit = max(limit, 50) | ||
| 169 | + response = self.search_agency.get_comments_for_topic(topic=keyword, limit=limit) | ||
| 170 | + elif tool_name == "search_topic_on_platform": | ||
| 171 | + platform = kwargs.get("platform") | ||
| 172 | + start_date = kwargs.get("start_date") | ||
| 173 | + end_date = kwargs.get("end_date") | ||
| 174 | + limit = kwargs.get("limit", 200) // len(optimized_response.optimized_keywords) | ||
| 175 | + limit = max(limit, 30) | ||
| 176 | + if not platform: | ||
| 177 | + raise ValueError("search_topic_on_platform工具需要platform参数") | ||
| 178 | + response = self.search_agency.search_topic_on_platform(platform=platform, topic=keyword, start_date=start_date, end_date=end_date, limit=limit) | ||
| 179 | + else: | ||
| 180 | + print(f" 未知的搜索工具: {tool_name},使用默认全局搜索") | ||
| 181 | + response = self.search_agency.search_topic_globally(topic=keyword, limit_per_table=100) | ||
| 182 | + | ||
| 183 | + # 收集结果 | ||
| 184 | + if response.results: | ||
| 185 | + print(f" 找到 {len(response.results)} 条结果") | ||
| 186 | + all_results.extend(response.results) | ||
| 187 | + total_count += len(response.results) | ||
| 188 | + else: | ||
| 189 | + print(f" 未找到结果") | ||
| 190 | + | ||
| 191 | + except Exception as e: | ||
| 192 | + print(f" 查询'{keyword}'时出错: {str(e)}") | ||
| 193 | + continue | ||
| 194 | + | ||
| 195 | + # 去重和整合结果 | ||
| 196 | + unique_results = self._deduplicate_results(all_results) | ||
| 197 | + print(f" 总计找到 {total_count} 条结果,去重后 {len(unique_results)} 条") | ||
| 198 | + | ||
| 199 | + # 构建整合后的响应 | ||
| 200 | + integrated_response = DBResponse( | ||
| 201 | + tool_name=f"{tool_name}_optimized", | ||
| 202 | + parameters={ | ||
| 203 | + "original_query": query, | ||
| 204 | + "optimized_keywords": optimized_response.optimized_keywords, | ||
| 205 | + "optimization_reasoning": optimized_response.reasoning, | ||
| 206 | + **kwargs | ||
| 207 | + }, | ||
| 208 | + results=unique_results, | ||
| 209 | + results_count=len(unique_results) | ||
| 210 | + ) | ||
| 211 | + | ||
| 212 | + return integrated_response | ||
| 213 | + | ||
| 214 | + def _deduplicate_results(self, results: List) -> List: | ||
| 215 | + """ | ||
| 216 | + 去重搜索结果 | ||
| 217 | + """ | ||
| 218 | + seen = set() | ||
| 219 | + unique_results = [] | ||
| 220 | + | ||
| 221 | + for result in results: | ||
| 222 | + # 使用URL或内容作为去重标识 | ||
| 223 | + identifier = result.url if result.url else result.title_or_content[:100] | ||
| 224 | + if identifier not in seen: | ||
| 225 | + seen.add(identifier) | ||
| 226 | + unique_results.append(result) | ||
| 227 | + | ||
| 228 | + return unique_results | ||
| 161 | 229 | ||
| 162 | def research(self, query: str, save_report: bool = True) -> str: | 230 | def research(self, query: str, save_report: bool = True) -> str: |
| 163 | """ | 231 | """ |
| @@ -291,14 +359,14 @@ class DeepSearchAgent: | @@ -291,14 +359,14 @@ class DeepSearchAgent: | ||
| 291 | # 处理限制参数 | 359 | # 处理限制参数 |
| 292 | if search_tool == "search_hot_content": | 360 | if search_tool == "search_hot_content": |
| 293 | time_period = search_output.get("time_period", "week") | 361 | time_period = search_output.get("time_period", "week") |
| 294 | - limit = search_output.get("limit", 10) | 362 | + limit = search_output.get("limit", 100) |
| 295 | search_kwargs["time_period"] = time_period | 363 | search_kwargs["time_period"] = time_period |
| 296 | search_kwargs["limit"] = limit | 364 | search_kwargs["limit"] = limit |
| 297 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: | 365 | elif search_tool in ["search_topic_globally", "search_topic_by_date"]: |
| 298 | - limit_per_table = search_output.get("limit_per_table", 5) | 366 | + limit_per_table = search_output.get("limit_per_table", 100) |
| 299 | search_kwargs["limit_per_table"] = limit_per_table | 367 | search_kwargs["limit_per_table"] = limit_per_table |
| 300 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: | 368 | elif search_tool in ["get_comments_for_topic", "search_topic_on_platform"]: |
| 301 | - limit = search_output.get("limit", 20) | 369 | + limit = search_output.get("limit", 200) |
| 302 | search_kwargs["limit"] = limit | 370 | search_kwargs["limit"] = limit |
| 303 | 371 | ||
| 304 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) | 372 | search_response = self.execute_search_tool(search_tool, search_query, **search_kwargs) |
| @@ -306,8 +374,8 @@ class DeepSearchAgent: | @@ -306,8 +374,8 @@ class DeepSearchAgent: | ||
| 306 | # 转换为兼容格式 | 374 | # 转换为兼容格式 |
| 307 | search_results = [] | 375 | search_results = [] |
| 308 | if search_response and search_response.results: | 376 | if search_response and search_response.results: |
| 309 | - # 每种搜索工具都有其特定的结果数量,这里取前10个作为上限 | ||
| 310 | - max_results = min(len(search_response.results), 10) | 377 | + # 每种搜索工具都有其特定的结果数量,这里取前100个作为上限 |
| 378 | + max_results = min(len(search_response.results), 100) | ||
| 311 | for result in search_response.results[:max_results]: | 379 | for result in search_response.results[:max_results]: |
| 312 | search_results.append({ | 380 | search_results.append({ |
| 313 | 'title': result.title_or_content, | 381 | 'title': result.title_or_content, |
| @@ -426,8 +494,8 @@ class DeepSearchAgent: | @@ -426,8 +494,8 @@ class DeepSearchAgent: | ||
| 426 | # 转换为兼容格式 | 494 | # 转换为兼容格式 |
| 427 | search_results = [] | 495 | search_results = [] |
| 428 | if search_response and search_response.results: | 496 | if search_response and search_response.results: |
| 429 | - # 每种搜索工具都有其特定的结果数量,这里取前10个作为上限 | ||
| 430 | - max_results = min(len(search_response.results), 10) | 497 | + # 每种搜索工具都有其特定的结果数量,这里取前100个作为上限 |
| 498 | + max_results = min(len(search_response.results), 100) | ||
| 431 | for result in search_response.results[:max_results]: | 499 | for result in search_response.results[:max_results]: |
| 432 | search_results.append({ | 500 | search_results.append({ |
| 433 | 'title': result.title_or_content, | 501 | 'title': result.title_or_content, |
| @@ -198,7 +198,7 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | @@ -198,7 +198,7 @@ SYSTEM_PROMPT_FIRST_SEARCH = f""" | ||
| 198 | 4. **参数优化配置**: | 198 | 4. **参数优化配置**: |
| 199 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) | 199 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) |
| 200 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) | 200 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) |
| 201 | - - 其他工具:合理配置limit参数以获取足够的样本 | 201 | + - 其他工具:合理配置limit参数以获取足够的样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200) |
| 202 | 5. **阐述选择理由**:说明为什么这样的查询能够获得最真实的民意反馈 | 202 | 5. **阐述选择理由**:说明为什么这样的查询能够获得最真实的民意反馈 |
| 203 | 203 | ||
| 204 | **搜索词设计核心原则**: | 204 | **搜索词设计核心原则**: |
| @@ -311,7 +311,7 @@ SYSTEM_PROMPT_REFLECTION = f""" | @@ -311,7 +311,7 @@ SYSTEM_PROMPT_REFLECTION = f""" | ||
| 311 | 4. **参数配置要求**: | 311 | 4. **参数配置要求**: |
| 312 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) | 312 | - search_topic_by_date: 必须提供start_date和end_date参数(格式:YYYY-MM-DD) |
| 313 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) | 313 | - search_topic_on_platform: 必须提供platform参数(bilibili, weibo, douyin, kuaishou, xhs, zhihu, tieba之一) |
| 314 | - - 其他工具:合理配置参数以获取多样化的民意样本 | 314 | + - 其他工具:合理配置参数以获取多样化的民意样本(建议:search_hot_content limit>=100,search_topic_globally limit_per_table>=50,search_topic_by_date limit_per_table>=100,get_comments_for_topic limit>=500,search_topic_on_platform limit>=200) |
| 315 | 315 | ||
| 316 | 5. **阐述补充理由**:明确说明为什么需要这些额外的民意数据 | 316 | 5. **阐述补充理由**:明确说明为什么需要这些额外的民意数据 |
| 317 | 317 |
| @@ -9,10 +9,18 @@ from .search import ( | @@ -9,10 +9,18 @@ from .search import ( | ||
| 9 | DBResponse, | 9 | DBResponse, |
| 10 | print_response_summary | 10 | print_response_summary |
| 11 | ) | 11 | ) |
| 12 | +from .keyword_optimizer import ( | ||
| 13 | + KeywordOptimizer, | ||
| 14 | + KeywordOptimizationResponse, | ||
| 15 | + keyword_optimizer | ||
| 16 | +) | ||
| 12 | 17 | ||
| 13 | __all__ = [ | 18 | __all__ = [ |
| 14 | "MediaCrawlerDB", | 19 | "MediaCrawlerDB", |
| 15 | "QueryResult", | 20 | "QueryResult", |
| 16 | "DBResponse", | 21 | "DBResponse", |
| 17 | - "print_response_summary" | 22 | + "print_response_summary", |
| 23 | + "KeywordOptimizer", | ||
| 24 | + "KeywordOptimizationResponse", | ||
| 25 | + "keyword_optimizer" | ||
| 18 | ] | 26 | ] |
InsightEngine/tools/keyword_optimizer.py
0 → 100644
| 1 | +""" | ||
| 2 | +关键词优化中间件 | ||
| 3 | +使用Qwen AI将Agent生成的搜索词优化为更适合舆情数据库查询的关键词 | ||
| 4 | +""" | ||
| 5 | + | ||
| 6 | +import requests | ||
| 7 | +import json | ||
| 8 | +import sys | ||
| 9 | +import os | ||
| 10 | +from typing import List, Dict, Any | ||
| 11 | +from dataclasses import dataclass | ||
| 12 | + | ||
| 13 | +# 添加项目根目录到Python路径以导入config | ||
| 14 | +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) | ||
| 15 | +from config import GUIJI_QWEN3_API_KEY | ||
| 16 | + | ||
| 17 | +@dataclass | ||
| 18 | +class KeywordOptimizationResponse: | ||
| 19 | + """关键词优化响应""" | ||
| 20 | + original_query: str | ||
| 21 | + optimized_keywords: List[str] | ||
| 22 | + reasoning: str | ||
| 23 | + success: bool | ||
| 24 | + error_message: str = "" | ||
| 25 | + | ||
| 26 | +class KeywordOptimizer: | ||
| 27 | + """ | ||
| 28 | + 关键词优化器 | ||
| 29 | + 使用硅基流动的Qwen3模型将Agent生成的搜索词优化为更贴近真实舆情的关键词 | ||
| 30 | + """ | ||
| 31 | + | ||
| 32 | + def __init__(self, api_key: str = None): | ||
| 33 | + """ | ||
| 34 | + 初始化关键词优化器 | ||
| 35 | + | ||
| 36 | + Args: | ||
| 37 | + api_key: 硅基流动API密钥,如果不提供则从配置文件读取 | ||
| 38 | + """ | ||
| 39 | + self.api_key = api_key or GUIJI_QWEN3_API_KEY | ||
| 40 | + self.base_url = "https://api.siliconflow.cn/v1/chat/completions" | ||
| 41 | + self.model = "Qwen/Qwen3-30B-A3B-Instruct-2507" | ||
| 42 | + | ||
| 43 | + if not self.api_key: | ||
| 44 | + raise ValueError("未找到硅基流动API密钥,请在config.py中设置GUIJI_QWEN3_API_KEY") | ||
| 45 | + | ||
| 46 | + def optimize_keywords(self, original_query: str, context: str = "") -> KeywordOptimizationResponse: | ||
| 47 | + """ | ||
| 48 | + 优化搜索关键词 | ||
| 49 | + | ||
| 50 | + Args: | ||
| 51 | + original_query: Agent生成的原始搜索查询 | ||
| 52 | + context: 额外的上下文信息(如段落标题、内容描述等) | ||
| 53 | + | ||
| 54 | + Returns: | ||
| 55 | + KeywordOptimizationResponse: 优化后的关键词列表 | ||
| 56 | + """ | ||
| 57 | + print(f"🔍 关键词优化中间件: 处理查询 '{original_query}'") | ||
| 58 | + | ||
| 59 | + try: | ||
| 60 | + # 构建优化prompt | ||
| 61 | + system_prompt = self._build_system_prompt() | ||
| 62 | + user_prompt = self._build_user_prompt(original_query, context) | ||
| 63 | + | ||
| 64 | + # 调用Qwen API | ||
| 65 | + response = self._call_qwen_api(system_prompt, user_prompt) | ||
| 66 | + | ||
| 67 | + if response["success"]: | ||
| 68 | + # 解析响应 | ||
| 69 | + content = response["content"] | ||
| 70 | + try: | ||
| 71 | + # 尝试解析JSON格式的响应 | ||
| 72 | + if content.strip().startswith('{'): | ||
| 73 | + parsed = json.loads(content) | ||
| 74 | + keywords = parsed.get("keywords", []) | ||
| 75 | + reasoning = parsed.get("reasoning", "") | ||
| 76 | + else: | ||
| 77 | + # 如果不是JSON格式,尝试从文本中提取关键词 | ||
| 78 | + keywords = self._extract_keywords_from_text(content) | ||
| 79 | + reasoning = content | ||
| 80 | + | ||
| 81 | + # 验证关键词质量 | ||
| 82 | + validated_keywords = self._validate_keywords(keywords) | ||
| 83 | + | ||
| 84 | + print(f"✅ 优化成功: {len(validated_keywords)}个关键词") | ||
| 85 | + for i, keyword in enumerate(validated_keywords, 1): | ||
| 86 | + print(f" {i}. '{keyword}'") | ||
| 87 | + | ||
| 88 | + return KeywordOptimizationResponse( | ||
| 89 | + original_query=original_query, | ||
| 90 | + optimized_keywords=validated_keywords, | ||
| 91 | + reasoning=reasoning, | ||
| 92 | + success=True | ||
| 93 | + ) | ||
| 94 | + | ||
| 95 | + except Exception as e: | ||
| 96 | + print(f"⚠️ 解析响应失败,使用备用方案: {str(e)}") | ||
| 97 | + # 备用方案:从原始查询中提取关键词 | ||
| 98 | + fallback_keywords = self._fallback_keyword_extraction(original_query) | ||
| 99 | + return KeywordOptimizationResponse( | ||
| 100 | + original_query=original_query, | ||
| 101 | + optimized_keywords=fallback_keywords, | ||
| 102 | + reasoning="API响应解析失败,使用备用关键词提取", | ||
| 103 | + success=True | ||
| 104 | + ) | ||
| 105 | + else: | ||
| 106 | + print(f"❌ API调用失败: {response['error']}") | ||
| 107 | + # 使用备用方案 | ||
| 108 | + fallback_keywords = self._fallback_keyword_extraction(original_query) | ||
| 109 | + return KeywordOptimizationResponse( | ||
| 110 | + original_query=original_query, | ||
| 111 | + optimized_keywords=fallback_keywords, | ||
| 112 | + reasoning="API调用失败,使用备用关键词提取", | ||
| 113 | + success=True, | ||
| 114 | + error_message=response['error'] | ||
| 115 | + ) | ||
| 116 | + | ||
| 117 | + except Exception as e: | ||
| 118 | + print(f"❌ 关键词优化失败: {str(e)}") | ||
| 119 | + # 最终备用方案 | ||
| 120 | + fallback_keywords = self._fallback_keyword_extraction(original_query) | ||
| 121 | + return KeywordOptimizationResponse( | ||
| 122 | + original_query=original_query, | ||
| 123 | + optimized_keywords=fallback_keywords, | ||
| 124 | + reasoning="系统错误,使用备用关键词提取", | ||
| 125 | + success=False, | ||
| 126 | + error_message=str(e) | ||
| 127 | + ) | ||
| 128 | + | ||
| 129 | + def _build_system_prompt(self) -> str: | ||
| 130 | + """构建系统prompt""" | ||
| 131 | + return """你是一位专业的舆情数据挖掘专家。你的任务是将用户提供的搜索查询优化为更适合在社交媒体舆情数据库中查找的关键词。 | ||
| 132 | + | ||
| 133 | +**核心原则**: | ||
| 134 | +1. **贴近网民语言**:使用普通网友在社交媒体上会使用的词汇 | ||
| 135 | +2. **避免专业术语**:不使用"舆情"、"传播"、"倾向"、"展望"等官方词汇 | ||
| 136 | +3. **简洁具体**:每个关键词要非常简洁明了,便于数据库匹配 | ||
| 137 | +4. **情感丰富**:包含网民常用的情感表达词汇 | ||
| 138 | +5. **数量控制**:最少提供10个关键词,最多提供20个关键词 | ||
| 139 | +6. **避免重复**:不要脱离初始查询的主题 | ||
| 140 | + | ||
| 141 | +**输出格式**: | ||
| 142 | +请以JSON格式返回结果: | ||
| 143 | +{ | ||
| 144 | + "keywords": ["关键词1", "关键词2", "关键词3"], | ||
| 145 | + "reasoning": "选择这些关键词的理由" | ||
| 146 | +} | ||
| 147 | + | ||
| 148 | +**示例**: | ||
| 149 | +输入:"武汉大学舆情管理 未来展望 发展趋势" | ||
| 150 | +输出: | ||
| 151 | +{ | ||
| 152 | + "keywords": ["武大", "武汉大学", "学校管理", "大学", "教育"], | ||
| 153 | + "reasoning": "选择'武大'和'武汉大学'作为核心词汇,这是网民最常使用的称呼;'学校管理'比'舆情管理'更贴近日常表达;避免使用'未来展望'、'发展趋势'等网民很少使用的专业术语" | ||
| 154 | +}""" | ||
| 155 | + | ||
| 156 | + def _build_user_prompt(self, original_query: str, context: str) -> str: | ||
| 157 | + """构建用户prompt""" | ||
| 158 | + prompt = f"请将以下搜索查询优化为适合舆情数据库查询的关键词:\n\n原始查询:{original_query}" | ||
| 159 | + | ||
| 160 | + if context: | ||
| 161 | + prompt += f"\n\n上下文信息:{context}" | ||
| 162 | + | ||
| 163 | + prompt += "\n\n请记住:要使用网民在社交媒体上真实使用的词汇,避免官方术语和专业词汇。" | ||
| 164 | + | ||
| 165 | + return prompt | ||
| 166 | + | ||
| 167 | + def _call_qwen_api(self, system_prompt: str, user_prompt: str) -> Dict[str, Any]: | ||
| 168 | + """调用Qwen API""" | ||
| 169 | + headers = { | ||
| 170 | + "Authorization": f"Bearer {self.api_key}", | ||
| 171 | + "Content-Type": "application/json" | ||
| 172 | + } | ||
| 173 | + | ||
| 174 | + data = { | ||
| 175 | + "model": self.model, | ||
| 176 | + "messages": [ | ||
| 177 | + {"role": "system", "content": system_prompt}, | ||
| 178 | + {"role": "user", "content": user_prompt} | ||
| 179 | + ], | ||
| 180 | + "max_tokens": 10000, | ||
| 181 | + "temperature": 0.7 | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + try: | ||
| 185 | + response = requests.post(self.base_url, headers=headers, json=data, timeout=30) | ||
| 186 | + response.raise_for_status() | ||
| 187 | + | ||
| 188 | + result = response.json() | ||
| 189 | + | ||
| 190 | + if "choices" in result and len(result["choices"]) > 0: | ||
| 191 | + content = result["choices"][0]["message"]["content"] | ||
| 192 | + return {"success": True, "content": content} | ||
| 193 | + else: | ||
| 194 | + return {"success": False, "error": "API返回格式异常"} | ||
| 195 | + | ||
| 196 | + except requests.exceptions.RequestException as e: | ||
| 197 | + return {"success": False, "error": f"网络请求错误: {str(e)}"} | ||
| 198 | + except Exception as e: | ||
| 199 | + return {"success": False, "error": f"API调用异常: {str(e)}"} | ||
| 200 | + | ||
| 201 | + def _extract_keywords_from_text(self, text: str) -> List[str]: | ||
| 202 | +        """从文本中提取关键词(当响应内容不是JSON格式时使用)""" | |
| 203 | + # 简单的关键词提取逻辑 | ||
| 204 | + lines = text.split('\n') | ||
| 205 | + keywords = [] | ||
| 206 | + | ||
| 207 | + for line in lines: | ||
| 208 | + line = line.strip() | ||
| 209 | + # 查找可能的关键词 | ||
| 210 | + if ':' in line or ':' in line: | ||
| 211 | + parts = line.split(':') if ':' in line else line.split(':') | ||
| 212 | + if len(parts) > 1: | ||
| 213 | + potential_keywords = parts[1].strip() | ||
| 214 | + # 尝试分割关键词 | ||
| 215 | + if '、' in potential_keywords: | ||
| 216 | + keywords.extend([k.strip() for k in potential_keywords.split('、')]) | ||
| 217 | + elif ',' in potential_keywords: | ||
| 218 | + keywords.extend([k.strip() for k in potential_keywords.split(',')]) | ||
| 219 | + else: | ||
| 220 | + keywords.append(potential_keywords) | ||
| 221 | + | ||
| 222 | + # 如果没有找到,尝试其他方法 | ||
| 223 | + if not keywords: | ||
| 224 | + # 查找引号中的内容 | ||
| 225 | + import re | ||
| 226 | + quoted_content = re.findall(r'["""\'](.*?)["""\']', text) | ||
| 227 | + keywords.extend(quoted_content) | ||
| 228 | + | ||
| 229 | + # 清理和验证关键词 | ||
| 230 | + cleaned_keywords = [] | ||
| 231 | +        for keyword in keywords[:20]:  # 最多20个 | |
| 232 | + keyword = keyword.strip().strip('"\'""''') | ||
| 233 | + if keyword and len(keyword) <= 20: # 合理长度 | ||
| 234 | + cleaned_keywords.append(keyword) | ||
| 235 | + | ||
| 236 | + return cleaned_keywords[:20] | ||
| 237 | + | ||
| 238 | + def _validate_keywords(self, keywords: List[str]) -> List[str]: | ||
| 239 | + """验证和清理关键词""" | ||
| 240 | + validated = [] | ||
| 241 | + | ||
| 242 | + # 不良关键词(过于专业或官方) | ||
| 243 | + bad_keywords = { | ||
| 244 | + '态度分析', '公众反应', '情绪倾向', | ||
| 245 | + '未来展望', '发展趋势', '战略规划', '政策导向', '管理机制' | ||
| 246 | + } | ||
| 247 | + | ||
| 248 | + for keyword in keywords: | ||
| 249 | + if isinstance(keyword, str): | ||
| 250 | + keyword = keyword.strip().strip('"\'""''') | ||
| 251 | + | ||
| 252 | + # 基本验证 | ||
| 253 | + if (keyword and | ||
| 254 | + len(keyword) <= 20 and | ||
| 255 | + len(keyword) >= 1 and | ||
| 256 | + not any(bad_word in keyword for bad_word in bad_keywords)): | ||
| 257 | + validated.append(keyword) | ||
| 258 | + | ||
| 259 | + return validated[:20] # 最多返回20个关键词 | ||
| 260 | + | ||
| 261 | + def _fallback_keyword_extraction(self, original_query: str) -> List[str]: | ||
| 262 | + """备用关键词提取方案""" | ||
| 263 | + # 简单的关键词提取逻辑 | ||
| 264 | + # 移除常见的无用词汇 | ||
| 265 | + stop_words = {'、'} | ||
| 266 | + | ||
| 267 | + # 分割查询 | ||
| 268 | + import re | ||
| 269 | + # 按空格、标点分割 | ||
| 270 | + tokens = re.split(r'[\s,。!?;:、]+', original_query) | ||
| 271 | + | ||
| 272 | + keywords = [] | ||
| 273 | + for token in tokens: | ||
| 274 | + token = token.strip() | ||
| 275 | + if token and token not in stop_words and len(token) >= 2: | ||
| 276 | + keywords.append(token) | ||
| 277 | + | ||
| 278 | + # 如果没有有效关键词,使用原始查询的第一个词 | ||
| 279 | + if not keywords: | ||
| 280 | + first_word = original_query.split()[0] if original_query.split() else original_query | ||
| 281 | + keywords = [first_word] if first_word else ["热门"] | ||
| 282 | + | ||
| 283 | + return keywords[:20] | ||
| 284 | + | ||
| 285 | +# 全局实例 | ||
| 286 | +keyword_optimizer = KeywordOptimizer() |
| @@ -2,7 +2,7 @@ | @@ -2,7 +2,7 @@ | ||
| 2 | 专为 AI Agent 设计的本地舆情数据库查询工具集 (MediaCrawlerDB) | 2 | 专为 AI Agent 设计的本地舆情数据库查询工具集 (MediaCrawlerDB) |
| 3 | 3 | ||
| 4 | 版本: 3.0 | 4 | 版本: 3.0 |
| 5 | -最后更新: 2025-08-22 | 5 | +最后更新: 2025-08-23 |
| 6 | 6 | ||
| 7 | 此脚本将复杂的本地MySQL数据库查询功能封装成一系列目标明确、参数清晰的独立工具, | 7 | 此脚本将复杂的本地MySQL数据库查询功能封装成一系列目标明确、参数清晰的独立工具, |
| 8 | 专为AI Agent调用而设计。Agent只需根据任务意图(如搜索热点、全局搜索话题、 | 8 | 专为AI Agent调用而设计。Agent只需根据任务意图(如搜索热点、全局搜索话题、 |
| @@ -44,7 +44,7 @@ class QueryResult: | @@ -44,7 +44,7 @@ class QueryResult: | ||
| 44 | publish_time: Optional[datetime] = None | 44 | publish_time: Optional[datetime] = None |
| 45 | engagement: Dict[str, int] = field(default_factory=dict) | 45 | engagement: Dict[str, int] = field(default_factory=dict) |
| 46 | source_keyword: Optional[str] = None | 46 | source_keyword: Optional[str] = None |
| 47 | - hotness_score: float = 0.0 # 新增:综合热度分 | 47 | + hotness_score: float = 0.0 |
| 48 | source_table: str = "" | 48 | source_table: str = "" |
| 49 | 49 | ||
| 50 | @dataclass | 50 | @dataclass |
| @@ -136,14 +136,14 @@ class MediaCrawlerDB: | @@ -136,14 +136,14 @@ class MediaCrawlerDB: | ||
| 136 | def search_hot_content( | 136 | def search_hot_content( |
| 137 | self, | 137 | self, |
| 138 | time_period: Literal['24h', 'week', 'year'] = 'week', | 138 | time_period: Literal['24h', 'week', 'year'] = 'week', |
| 139 | - limit: int = 10 | 139 | + limit: int = 50 |
| 140 | ) -> DBResponse: | 140 | ) -> DBResponse: |
| 141 | """ | 141 | """ |
| 142 | - 【工具】查找热点内容: (已简化) 获取最近一段时间内综合热度最高的内容。 | 142 | + 【工具】查找热点内容: 获取最近一段时间内综合热度最高的内容。 |
| 143 | 143 | ||
| 144 | Args: | 144 | Args: |
| 145 | time_period (Literal['24h', 'week', 'year']): 时间范围,默认为 'week'。 | 145 | time_period (Literal['24h', 'week', 'year']): 时间范围,默认为 'week'。 |
| 146 | - limit (int): 返回结果的最大数量,默认为 10。 | 146 | + limit (int): 返回结果的最大数量,默认为 50。 |
| 147 | 147 | ||
| 148 | Returns: | 148 | Returns: |
| 149 | DBResponse: 包含按综合热度排序后的内容列表。 | 149 | DBResponse: 包含按综合热度排序后的内容列表。 |
| @@ -190,13 +190,13 @@ class MediaCrawlerDB: | @@ -190,13 +190,13 @@ class MediaCrawlerDB: | ||
| 190 | formatted_results = [QueryResult(platform=r['p'], content_type=r['t'], title_or_content=r['title'], author_nickname=r.get('author'), url=r['url'], publish_time=self._to_datetime(r['ts']), engagement=self._extract_engagement(r), hotness_score=r.get('hotness_score', 0.0), source_keyword=r.get('source_keyword'), source_table=r['tbl']) for r in raw_results] | 190 | formatted_results = [QueryResult(platform=r['p'], content_type=r['t'], title_or_content=r['title'], author_nickname=r.get('author'), url=r['url'], publish_time=self._to_datetime(r['ts']), engagement=self._extract_engagement(r), hotness_score=r.get('hotness_score', 0.0), source_keyword=r.get('source_keyword'), source_table=r['tbl']) for r in raw_results] |
| 191 | return DBResponse("search_hot_content", params_for_log, results=formatted_results, results_count=len(formatted_results)) | 191 | return DBResponse("search_hot_content", params_for_log, results=formatted_results, results_count=len(formatted_results)) |
| 192 | 192 | ||
| 193 | - def search_topic_globally(self, topic: str, limit_per_table: int = 5) -> DBResponse: | 193 | + def search_topic_globally(self, topic: str, limit_per_table: int = 100) -> DBResponse: |
| 194 | """ | 194 | """ |
| 195 | 【工具】全局话题搜索: 在数据库中(内容、评论、标签、来源关键字)全面搜索指定话题。 | 195 | 【工具】全局话题搜索: 在数据库中(内容、评论、标签、来源关键字)全面搜索指定话题。 |
| 196 | 196 | ||
| 197 | Args: | 197 | Args: |
| 198 | topic (str): 要搜索的话题关键词。 | 198 | topic (str): 要搜索的话题关键词。 |
| 199 | - limit_per_table (int): 从每个相关表中返回的最大记录数,默认为 5。 | 199 | + limit_per_table (int): 从每个相关表中返回的最大记录数,默认为 100。 |
| 200 | 200 | ||
| 201 | Returns: | 201 | Returns: |
| 202 | DBResponse: 包含所有匹配结果的聚合列表。 | 202 | DBResponse: 包含所有匹配结果的聚合列表。 |
| @@ -227,7 +227,7 @@ class MediaCrawlerDB: | @@ -227,7 +227,7 @@ class MediaCrawlerDB: | ||
| 227 | )) | 227 | )) |
| 228 | return DBResponse("search_topic_globally", params_for_log, results=all_results, results_count=len(all_results)) | 228 | return DBResponse("search_topic_globally", params_for_log, results=all_results, results_count=len(all_results)) |
| 229 | 229 | ||
| 230 | - def search_topic_by_date(self, topic: str, start_date: str, end_date: str, limit_per_table: int = 10) -> DBResponse: | 230 | + def search_topic_by_date(self, topic: str, start_date: str, end_date: str, limit_per_table: int = 100) -> DBResponse: |
| 231 | """ | 231 | """ |
| 232 | 【工具】按日期搜索话题: 在明确的历史时间段内,搜索与特定话题相关的内容。 | 232 | 【工具】按日期搜索话题: 在明确的历史时间段内,搜索与特定话题相关的内容。 |
| 233 | 233 | ||
| @@ -235,7 +235,7 @@ class MediaCrawlerDB: | @@ -235,7 +235,7 @@ class MediaCrawlerDB: | ||
| 235 | topic (str): 要搜索的话题关键词。 | 235 | topic (str): 要搜索的话题关键词。 |
| 236 | start_date (str): 开始日期,格式 'YYYY-MM-DD'。 | 236 | start_date (str): 开始日期,格式 'YYYY-MM-DD'。 |
| 237 | end_date (str): 结束日期,格式 'YYYY-MM-DD'。 | 237 | end_date (str): 结束日期,格式 'YYYY-MM-DD'。 |
| 238 | - limit_per_table (int): 从每个相关表中返回的最大记录数,默认为 10。 | 238 | + limit_per_table (int): 从每个相关表中返回的最大记录数,默认为 100。 |
| 239 | 239 | ||
| 240 | Returns: | 240 | Returns: |
| 241 | DBResponse: 包含在指定日期范围内找到的结果的聚合列表。 | 241 | DBResponse: 包含在指定日期范围内找到的结果的聚合列表。 |
| @@ -282,13 +282,13 @@ class MediaCrawlerDB: | @@ -282,13 +282,13 @@ class MediaCrawlerDB: | ||
| 282 | )) | 282 | )) |
| 283 | return DBResponse("search_topic_by_date", params_for_log, results=all_results, results_count=len(all_results)) | 283 | return DBResponse("search_topic_by_date", params_for_log, results=all_results, results_count=len(all_results)) |
| 284 | 284 | ||
| 285 | - def get_comments_for_topic(self, topic: str, limit: int = 50) -> DBResponse: | 285 | + def get_comments_for_topic(self, topic: str, limit: int = 500) -> DBResponse: |
| 286 | """ | 286 | """ |
| 287 | 【工具】获取话题评论: 专门搜索并返回所有平台中与特定话题相关的公众评论数据。 | 287 | 【工具】获取话题评论: 专门搜索并返回所有平台中与特定话题相关的公众评论数据。 |
| 288 | 288 | ||
| 289 | Args: | 289 | Args: |
| 290 | topic (str): 要搜索的话题关键词。 | 290 | topic (str): 要搜索的话题关键词。 |
| 291 | - limit (int): 返回评论的总数量上限,默认为 50。 | 291 | + limit (int): 返回评论的总数量上限,默认为 500。 |
| 292 | 292 | ||
| 293 | Returns: | 293 | Returns: |
| 294 | DBResponse: 包含匹配的评论列表。 | 294 | DBResponse: 包含匹配的评论列表。 |
| @@ -30,11 +30,18 @@ class Config: | @@ -30,11 +30,18 @@ class Config: | ||
| 30 | 30 | ||
| 31 | # 搜索配置 | 31 | # 搜索配置 |
| 32 | search_timeout: int = 240 | 32 | search_timeout: int = 240 |
| 33 | - max_content_length: int = 20000 | 33 | + max_content_length: int = 100000 |
| 34 | + | ||
| 35 | + # 数据库查询限制 | ||
| 36 | + default_search_hot_content_limit: int = 100 | ||
| 37 | + default_search_topic_globally_limit_per_table: int = 50 | ||
| 38 | + default_search_topic_by_date_limit_per_table: int = 100 | ||
| 39 | + default_get_comments_for_topic_limit: int = 500 | ||
| 40 | + default_search_topic_on_platform_limit: int = 200 | ||
| 34 | 41 | ||
| 35 | # Agent配置 | 42 | # Agent配置 |
| 36 | - max_reflections: int = 2 | ||
| 37 | - max_paragraphs: int = 5 | 43 | + max_reflections: int = 3 |
| 44 | + max_paragraphs: int = 6 | ||
| 38 | 45 | ||
| 39 | # 输出配置 | 46 | # 输出配置 |
| 40 | output_dir: str = "reports" | 47 | output_dir: str = "reports" |
| @@ -85,7 +92,14 @@ class Config: | @@ -85,7 +92,14 @@ class Config: | ||
| 85 | openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), | 92 | openai_model=getattr(config_module, "OPENAI_MODEL", "gpt-4o-mini"), |
| 86 | 93 | ||
| 87 | search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240), | 94 | search_timeout=getattr(config_module, "SEARCH_TIMEOUT", 240), |
| 88 | - max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 20000), | 95 | + max_content_length=getattr(config_module, "SEARCH_CONTENT_MAX_LENGTH", 200000), |
| 96 | + | ||
| 97 | + default_search_hot_content_limit=getattr(config_module, "DEFAULT_SEARCH_HOT_CONTENT_LIMIT", 100), | ||
| 98 | + default_search_topic_globally_limit_per_table=getattr(config_module, "DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", 50), | ||
| 99 | + default_search_topic_by_date_limit_per_table=getattr(config_module, "DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", 100), | ||
| 100 | + default_get_comments_for_topic_limit=getattr(config_module, "DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", 500), | ||
| 101 | + default_search_topic_on_platform_limit=getattr(config_module, "DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", 200), | ||
| 102 | + | ||
| 89 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), | 103 | max_reflections=getattr(config_module, "MAX_REFLECTIONS", 2), |
| 90 | max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5), | 104 | max_paragraphs=getattr(config_module, "MAX_PARAGRAPHS", 5), |
| 91 | output_dir=getattr(config_module, "OUTPUT_DIR", "reports"), | 105 | output_dir=getattr(config_module, "OUTPUT_DIR", "reports"), |
| @@ -119,7 +133,14 @@ class Config: | @@ -119,7 +133,14 @@ class Config: | ||
| 119 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), | 133 | openai_model=config_dict.get("OPENAI_MODEL", "gpt-4o-mini"), |
| 120 | 134 | ||
| 121 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), | 135 | search_timeout=int(config_dict.get("SEARCH_TIMEOUT", "240")), |
| 122 | - max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "20000")), | 136 | + max_content_length=int(config_dict.get("SEARCH_CONTENT_MAX_LENGTH", "200000")), |
| 137 | + | ||
| 138 | + default_search_hot_content_limit=int(config_dict.get("DEFAULT_SEARCH_HOT_CONTENT_LIMIT", "100")), | ||
| 139 | + default_search_topic_globally_limit_per_table=int(config_dict.get("DEFAULT_SEARCH_TOPIC_GLOBALLY_LIMIT_PER_TABLE", "50")), | ||
| 140 | + default_search_topic_by_date_limit_per_table=int(config_dict.get("DEFAULT_SEARCH_TOPIC_BY_DATE_LIMIT_PER_TABLE", "100")), | ||
| 141 | + default_get_comments_for_topic_limit=int(config_dict.get("DEFAULT_GET_COMMENTS_FOR_TOPIC_LIMIT", "500")), | ||
| 142 | + default_search_topic_on_platform_limit=int(config_dict.get("DEFAULT_SEARCH_TOPIC_ON_PLATFORM_LIMIT", "200")), | ||
| 143 | + | ||
| 123 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), | 144 | max_reflections=int(config_dict.get("MAX_REFLECTIONS", "2")), |
| 124 | max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")), | 145 | max_paragraphs=int(config_dict.get("MAX_PARAGRAPHS", "5")), |
| 125 | output_dir=config_dict.get("OUTPUT_DIR", "reports"), | 146 | output_dir=config_dict.get("OUTPUT_DIR", "reports"), |
| @@ -34,7 +34,7 @@ def main(): | @@ -34,7 +34,7 @@ def main(): | ||
| 34 | # 高级配置 | 34 | # 高级配置 |
| 35 | st.subheader("高级配置") | 35 | st.subheader("高级配置") |
| 36 | max_reflections = st.slider("反思次数", 1, 5, 2) | 36 | max_reflections = st.slider("反思次数", 1, 5, 2) |
| 37 | - max_content_length = st.number_input("最大内容长度", 1000, 50000, 20000) | 37 | + max_content_length = st.number_input("最大内容长度", 10000, 500000, 200000) # 提高10倍:1000-50000-20000 → 10000-500000-200000 |
| 38 | 38 | ||
| 39 | # 模型选择 | 39 | # 模型选择 |
| 40 | llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"]) | 40 | llm_provider = st.selectbox("LLM提供商", ["deepseek", "openai"]) |
-
Please register or login to post a comment