Add options for selecting large models and prioritize using DeepSeek for analysis.
Showing 1 changed file with 151 additions and 107 deletions.
| 1 | import openai | 1 | import openai |
| 2 | import anthropic | 2 | import anthropic |
| 3 | import json | 3 | import json |
| 4 | -from typing import List, Dict | 4 | +from typing import List, Dict, Tuple, Any |
| 5 | import os | 5 | import os |
| 6 | +import asyncio | ||
| 7 | +import math | ||
| 6 | from datetime import datetime | 8 | from datetime import datetime |
| 7 | from utils.logger import app_logger as logging | 9 | from utils.logger import app_logger as logging |
| 8 | 10 | ||
| 9 | class AIAnalyzer: | 11 | class AIAnalyzer: |
| 10 | def __init__(self): | 12 | def __init__(self): |
| 11 | - # 从环境变量获取API密钥 | 13 | + # 尝试从环境变量中获取API密钥,如果没有则主动询问配置 |
| 12 | self.openai_key = os.getenv('OPENAI_API_KEY') | 14 | self.openai_key = os.getenv('OPENAI_API_KEY') |
| 15 | + if not self.openai_key: | ||
| 16 | + print("未检测到 OPENAI_API_KEY。") | ||
| 17 | + # 提示时允许按回车跳过输入 | ||
| 18 | + self.openai_key = input("请输入 OPENAI_API_KEY (按回车键跳过输入): ").strip() | ||
| 19 | + | ||
| 13 | self.claude_key = os.getenv('ANTHROPIC_API_KEY') | 20 | self.claude_key = os.getenv('ANTHROPIC_API_KEY') |
| 21 | + if not self.claude_key: | ||
| 22 | + print("未检测到 ANTHROPIC_API_KEY。") | ||
| 23 | + self.claude_key = input("请输入 ANTHROPIC_API_KEY (按回车键跳过输入): ").strip() | ||
| 24 | + | ||
| 14 | self.deepseek_key = os.getenv('DEEPSEEK_API_KEY') | 25 | self.deepseek_key = os.getenv('DEEPSEEK_API_KEY') |
| 26 | + if not self.deepseek_key: | ||
| 27 | + print("未检测到 DEEPSEEK_API_KEY。") | ||
| 28 | + self.deepseek_key = input("请输入 DEEPSEEK_API_KEY (按回车键跳过输入): ").strip() | ||
| 15 | 29 | ||
| 16 | - if not any([self.openai_key, self.claude_key, self.deepseek_key]): | ||
| 17 | - raise ValueError("请至少设置一个API密钥 (OPENAI_API_KEY, ANTHROPIC_API_KEY 或 DEEPSEEK_API_KEY)") | 30 | + # 如果不希望通过交互输入,也可以直接在此处配置(注释掉下面几行即可) |
| 31 | + # self.openai_key = "你的OpenAI_API_KEY" | ||
| 32 | + # self.claude_key = "你的ANTHROPIC_API_KEY" | ||
| 33 | + # self.deepseek_key = "你的DEEPSEEK_API_KEY" | ||
| 18 | 34 | ||
| 35 | + # 配置各API客户端 | ||
| 19 | if self.openai_key: | 36 | if self.openai_key: |
| 20 | openai.api_key = self.openai_key | 37 | openai.api_key = self.openai_key |
| 21 | if self.claude_key: | 38 | if self.claude_key: |
| 22 | self.claude_client = anthropic.Anthropic(api_key=self.claude_key) | 39 | self.claude_client = anthropic.Anthropic(api_key=self.claude_key) |
| 23 | if self.deepseek_key: | 40 | if self.deepseek_key: |
| 24 | - # 配置DeepSeek API | ||
| 25 | self.deepseek_client = openai.OpenAI( | 41 | self.deepseek_client = openai.OpenAI( |
| 26 | api_key=self.deepseek_key, | 42 | api_key=self.deepseek_key, |
| 27 | base_url="https://api.deepseek.com/v1" | 43 | base_url="https://api.deepseek.com/v1" |
| 28 | ) | 44 | ) |
| 29 | 45 | ||
| 30 | - # 支持的模型列表 | ||
| 31 | - self.supported_models = { | ||
| 32 | - # OpenAI 模型 | 46 | + # 支持的模型列表(增加了最新的 ChatGPT 和 Claude 模型) |
| 47 | + self.supported_models: Dict[str, Dict[str, Any]] = { | ||
| 48 | + # OpenAI 最新模型(ChatGPT系列) | ||
| 49 | + 'gpt-4o-latest': { | ||
| 50 | + 'provider': 'openai', | ||
| 51 | + 'max_tokens': 128000, # 支持大窗口 | ||
| 52 | + 'cost_per_1k': 0.01 # 参考价格(美元) | ||
| 53 | + }, | ||
| 54 | + 'gpt-4o-mini': { | ||
| 55 | + 'provider': 'openai', | ||
| 56 | + 'max_tokens': 4000, # 轻量版,适合快速任务 | ||
| 57 | + 'cost_per_1k': 0.00015 # 成本大幅降低 | ||
| 58 | + }, | ||
| 59 | + # 旧版OpenAI模型 | ||
| 33 | 'gpt-3.5-turbo': {'provider': 'openai', 'max_tokens': 2000, 'cost_per_1k': 0.0015}, | 60 | 'gpt-3.5-turbo': {'provider': 'openai', 'max_tokens': 2000, 'cost_per_1k': 0.0015}, |
| 34 | 'gpt-3.5-turbo-16k': {'provider': 'openai', 'max_tokens': 16000, 'cost_per_1k': 0.003}, | 61 | 'gpt-3.5-turbo-16k': {'provider': 'openai', 'max_tokens': 16000, 'cost_per_1k': 0.003}, |
| 35 | 'gpt-4': {'provider': 'openai', 'max_tokens': 8000, 'cost_per_1k': 0.03}, | 62 | 'gpt-4': {'provider': 'openai', 'max_tokens': 8000, 'cost_per_1k': 0.03}, |
| 36 | 'gpt-4-32k': {'provider': 'openai', 'max_tokens': 32000, 'cost_per_1k': 0.06}, | 63 | 'gpt-4-32k': {'provider': 'openai', 'max_tokens': 32000, 'cost_per_1k': 0.06}, |
| 37 | 'gpt-4-turbo-preview': {'provider': 'openai', 'max_tokens': 128000, 'cost_per_1k': 0.01}, | 64 | 'gpt-4-turbo-preview': {'provider': 'openai', 'max_tokens': 128000, 'cost_per_1k': 0.01}, |
| 38 | 65 | ||
| 39 | - # Claude 模型 | ||
| 40 | - 'claude-3-opus-20240229': {'provider': 'anthropic', 'max_tokens': 4000, 'cost_per_1k': 0.015}, | ||
| 41 | - 'claude-3-sonnet-20240229': {'provider': 'anthropic', 'max_tokens': 3000, 'cost_per_1k': 0.003}, | ||
| 42 | - 'claude-3-haiku-20240307': {'provider': 'anthropic', 'max_tokens': 2000, 'cost_per_1k': 0.0025}, | 66 | + # Anthropic 最新模型(Claude系列) |
| 67 | + 'claude-3.5-sonnet-new': { | ||
| 68 | + 'provider': 'anthropic', | ||
| 69 | + 'max_tokens': 200000, # 新版Claude 3.5 Sonnet | ||
| 70 | + 'cost_per_1k': 0.015 | ||
| 71 | + }, | ||
| 72 | + 'claude-3.5-haiku': { | ||
| 73 | + 'provider': 'anthropic', | ||
| 74 | + 'max_tokens': 200000, # 最新Claude 3.5 Haiku | ||
| 75 | + 'cost_per_1k': 0.0025 | ||
| 76 | + }, | ||
| 77 | + # 旧版Claude模型 | ||
| 43 | 'claude-2.1': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, | 78 | 'claude-2.1': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, |
| 44 | 'claude-2.0': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, | 79 | 'claude-2.0': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, |
| 45 | 'claude-instant-1.2': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.0015}, | 80 | 'claude-instant-1.2': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.0015}, |
| 46 | 81 | ||
| 47 | # DeepSeek 模型 | 82 | # DeepSeek 模型 |
| 48 | - 'deepseek-chat': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.002}, # DeepSeek-V3 | ||
| 49 | - 'deepseek-reasoner': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.003} # DeepSeek-R1 | 83 | + 'deepseek-chat': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.002}, |
| 84 | + 'deepseek-reasoner': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.003} | ||
| 50 | } | 85 | } |
| 51 | 86 | ||
| 52 | # 不同深度的分析提示词 | 87 | # 不同深度的分析提示词 |
| 53 | - self.prompt_templates = { | 88 | + self.prompt_templates: Dict[str, str] = { |
| 54 | 'basic': """你是一个专业的舆情分析助手。请对每条消息进行基础的情感分析。 | 89 | 'basic': """你是一个专业的舆情分析助手。请对每条消息进行基础的情感分析。 |
| 55 | 请按以下JSON格式返回: | 90 | 请按以下JSON格式返回: |
| 56 | { | 91 | { |
| @@ -105,9 +140,19 @@ class AIAnalyzer: | @@ -105,9 +140,19 @@ class AIAnalyzer: | ||
| 105 | 140 | ||
| 106 | async def analyze_messages(self, messages: List[Dict], batch_size: int = 50, | 141 | async def analyze_messages(self, messages: List[Dict], batch_size: int = 50, |
| 107 | model_type: str = "gpt-3.5-turbo", | 142 | model_type: str = "gpt-3.5-turbo", |
| 108 | - analysis_depth: str = "standard") -> List[Dict]: | ||
| 109 | - """分析一批消息并返回分析结果""" | 143 | + analysis_depth: str = "standard", |
| 144 | + prefer_deepseek: bool = True) -> List[Dict]: | ||
| 145 | + """ | ||
| 146 | + 分析一批消息并返回分析结果。 | ||
| 147 | + 如果 DeepSeek API 可用且 prefer_deepseek 为 True,则优先使用 DeepSeek 模型。 | ||
| 148 | + """ | ||
| 110 | try: | 149 | try: |
| 150 | + # 优先使用 DeepSeek 模型以降低成本 | ||
| 151 | + if prefer_deepseek and self.deepseek_key: | ||
| 152 | + if model_type not in ['deepseek-chat', 'deepseek-reasoner']: | ||
| 153 | + logging.info("检测到 DeepSeek API, 优先使用 'deepseek-chat' 模型以降低成本。") | ||
| 154 | + model_type = 'deepseek-chat' | ||
| 155 | + | ||
| 111 | if model_type not in self.supported_models: | 156 | if model_type not in self.supported_models: |
| 112 | raise ValueError(f"不支持的模型类型: {model_type}") | 157 | raise ValueError(f"不支持的模型类型: {model_type}") |
| 113 | 158 | ||
| @@ -116,91 +161,85 @@ class AIAnalyzer: | @@ -116,91 +161,85 @@ class AIAnalyzer: | ||
| 116 | max_tokens = model_info['max_tokens'] | 161 | max_tokens = model_info['max_tokens'] |
| 117 | 162 | ||
| 118 | # 根据模型类型调整批处理大小 | 163 | # 根据模型类型调整批处理大小 |
| 119 | - adjusted_batch_size = min(batch_size, self._get_optimal_batch_size(model_type)) | 164 | + optimal_batch_size = self._get_optimal_batch_size(model_type) |
| 165 | + adjusted_batch_size = min(batch_size, optimal_batch_size) | ||
| 120 | if adjusted_batch_size != batch_size: | 166 | if adjusted_batch_size != batch_size: |
| 121 | logging.info(f"已将批处理大小从 {batch_size} 调整为 {adjusted_batch_size}") | 167 | logging.info(f"已将批处理大小从 {batch_size} 调整为 {adjusted_batch_size}") |
| 122 | 168 | ||
| 123 | - all_results = [] | ||
| 124 | - total_cost = 0 | ||
| 125 | - | ||
| 126 | - # 分批处理消息 | 169 | + tasks = [] |
| 170 | + total_cost = 0.0 | ||
| 171 | + # 分批处理消息并异步调用分析任务 | ||
| 127 | for i in range(0, len(messages), adjusted_batch_size): | 172 | for i in range(0, len(messages), adjusted_batch_size): |
| 128 | batch = messages[i:i + adjusted_batch_size] | 173 | batch = messages[i:i + adjusted_batch_size] |
| 129 | - formatted_messages = [] | ||
| 130 | - for msg in batch: | ||
| 131 | - formatted_messages.append(f"消息ID: {msg['id']}\n内容: {msg['content']}") | 174 | + system_prompt = self.prompt_templates.get(analysis_depth, self.prompt_templates['standard']) |
| 175 | + tasks.append(self._process_batch(batch, system_prompt, model_type, max_tokens, provider)) | ||
| 176 | + | ||
| 177 | + # 并发执行所有批次任务 | ||
| 178 | + results = await asyncio.gather(*tasks) | ||
| 179 | + | ||
| 180 | + all_results = [] | ||
| 181 | + for batch_result, batch_cost in results: | ||
| 182 | + all_results.extend(batch_result) | ||
| 183 | + total_cost += batch_cost | ||
| 184 | + | ||
| 185 | + logging.info(f"分析完成, 总成本: ${total_cost:.4f}") | ||
| 186 | + return all_results | ||
| 187 | + except Exception as e: | ||
| 188 | + logging.error(f"AI分析过程出错: {e}", exc_info=True) | ||
| 189 | + return [] | ||
| 132 | 190 | ||
| 191 | + async def _process_batch(self, batch: List[Dict], system_prompt: str, | ||
| 192 | + model_type: str, max_tokens: int, provider: str) -> Tuple[List[Dict], float]: | ||
| 193 | + """ | ||
| 194 | + 处理单个批次的消息,返回 (分析结果, 本批次成本) | ||
| 195 | + """ | ||
| 196 | + try: | ||
| 197 | + formatted_messages = [ | ||
| 198 | + f"消息ID: {msg.get('id')}\n内容: {msg.get('content')}" for msg in batch | ||
| 199 | + ] | ||
| 133 | messages_text = "\n---\n".join(formatted_messages) | 200 | messages_text = "\n---\n".join(formatted_messages) |
| 134 | - system_prompt = self.prompt_templates.get(analysis_depth, self.prompt_templates['standard']) | ||
| 135 | 201 | ||
| 136 | if provider == 'openai': | 202 | if provider == 'openai': |
| 137 | - result = await self._analyze_with_openai( | ||
| 138 | - messages_text, | ||
| 139 | - system_prompt, | ||
| 140 | - model_type, | ||
| 141 | - max_tokens | ||
| 142 | - ) | 203 | + result = await self._analyze_with_openai(messages_text, system_prompt, model_type, max_tokens) |
| 143 | elif provider == 'anthropic': | 204 | elif provider == 'anthropic': |
| 144 | - result = await self._analyze_with_claude( | ||
| 145 | - messages_text, | ||
| 146 | - system_prompt, | ||
| 147 | - model_type, | ||
| 148 | - max_tokens | ||
| 149 | - ) | 205 | + result = await self._analyze_with_claude(messages_text, system_prompt, model_type, max_tokens) |
| 150 | elif provider == 'deepseek': | 206 | elif provider == 'deepseek': |
| 151 | - result = await self._analyze_with_deepseek( | ||
| 152 | - messages_text, | ||
| 153 | - system_prompt, | ||
| 154 | - model_type, | ||
| 155 | - max_tokens | ||
| 156 | - ) | 207 | + result = await self._analyze_with_deepseek(messages_text, system_prompt, model_type, max_tokens) |
| 208 | + else: | ||
| 209 | + logging.error(f"未知的API供应商: {provider}") | ||
| 210 | + return ([], 0.0) | ||
| 157 | 211 | ||
| 158 | - if result: | ||
| 159 | - all_results.extend(result) | ||
| 160 | - # 计算本批次成本 | ||
| 161 | batch_cost = self._calculate_cost(len(messages_text), model_type) | 212 | batch_cost = self._calculate_cost(len(messages_text), model_type) |
| 162 | - total_cost += batch_cost | ||
| 163 | - logging.info(f"批次处理完成,成本: ${batch_cost:.4f}") | ||
| 164 | - | ||
| 165 | - logging.info(f"分析完成,总成本: ${total_cost:.4f}") | ||
| 166 | - return all_results | ||
| 167 | - | 213 | + logging.info(f"批次处理完成, 成本: ${batch_cost:.4f}") |
| 214 | + return (result, batch_cost) | ||
| 168 | except Exception as e: | 215 | except Exception as e: |
| 169 | - logging.error(f"AI分析过程出错: {e}") | ||
| 170 | - return [] | 216 | + logging.error(f"处理批次时出错: {e}", exc_info=True) |
| 217 | + return ([], 0.0) | ||
| 171 | 218 | ||
| 172 | def _get_optimal_batch_size(self, model_type: str) -> int: | 219 | def _get_optimal_batch_size(self, model_type: str) -> int: |
| 173 | """根据模型类型获取最优批处理大小""" | 220 | """根据模型类型获取最优批处理大小""" |
| 174 | model_info = self.supported_models[model_type] | 221 | model_info = self.supported_models[model_type] |
| 175 | max_tokens = model_info['max_tokens'] | 222 | max_tokens = model_info['max_tokens'] |
| 176 | 223 | ||
| 177 | - # 估算每条消息的平均token数(假设为200) | 224 | + # 估算每条消息的平均 token 数(假设为 200) |
| 178 | avg_tokens_per_message = 200 | 225 | avg_tokens_per_message = 200 |
| 179 | - | ||
| 180 | - # 预留20%的token用于系统提示词和响应 | 226 | + # 预留 20% 的 token 用于系统提示词和响应 |
| 181 | available_tokens = int(max_tokens * 0.8) | 227 | available_tokens = int(max_tokens * 0.8) |
| 182 | - | ||
| 183 | - # 计算最优批处理大小 | ||
| 184 | optimal_batch_size = max(1, min(100, available_tokens // avg_tokens_per_message)) | 228 | optimal_batch_size = max(1, min(100, available_tokens // avg_tokens_per_message)) |
| 185 | - | ||
| 186 | return optimal_batch_size | 229 | return optimal_batch_size |
| 187 | 230 | ||
| 188 | def _calculate_cost(self, input_length: int, model_type: str) -> float: | 231 | def _calculate_cost(self, input_length: int, model_type: str) -> float: |
| 189 | - """计算API调用成本""" | 232 | + """计算 API 调用成本""" |
| 190 | model_info = self.supported_models[model_type] | 233 | model_info = self.supported_models[model_type] |
| 191 | cost_per_1k = model_info['cost_per_1k'] | 234 | cost_per_1k = model_info['cost_per_1k'] |
| 192 | - | ||
| 193 | - # 估算token数(假设每4个字符约等于1个token) | ||
| 194 | - estimated_tokens = input_length // 4 | ||
| 195 | - | ||
| 196 | - # 计算成本(美元) | 235 | + # 估算 token 数(假设每 4 个字符约等于 1 个 token) |
| 236 | + estimated_tokens = math.ceil(input_length / 4) | ||
| 197 | cost = (estimated_tokens / 1000) * cost_per_1k | 237 | cost = (estimated_tokens / 1000) * cost_per_1k |
| 198 | - | ||
| 199 | return cost | 238 | return cost |
| 200 | 239 | ||
| 201 | async def _analyze_with_openai(self, messages_text: str, system_prompt: str, | 240 | async def _analyze_with_openai(self, messages_text: str, system_prompt: str, |
| 202 | model: str, max_tokens: int) -> List[Dict]: | 241 | model: str, max_tokens: int) -> List[Dict]: |
| 203 | - """使用OpenAI API进行分析""" | 242 | + """使用 OpenAI API 进行分析""" |
| 204 | try: | 243 | try: |
| 205 | response = await openai.ChatCompletion.acreate( | 244 | response = await openai.ChatCompletion.acreate( |
| 206 | model=model, | 245 | model=model, |
| @@ -210,52 +249,44 @@ class AIAnalyzer: | @@ -210,52 +249,44 @@ class AIAnalyzer: | ||
| 210 | ], | 249 | ], |
| 211 | temperature=0.3, | 250 | temperature=0.3, |
| 212 | max_tokens=max_tokens, | 251 | max_tokens=max_tokens, |
| 213 | - n=1, | ||
| 214 | - response_format={"type": "json_object"} # 强制JSON响应格式 | 252 | + n=1 |
| 215 | ) | 253 | ) |
| 216 | - | ||
| 217 | - result = json.loads(response.choices[0].message.content) | 254 | + content = response.choices[0].message.content |
| 255 | + result = json.loads(content) | ||
| 218 | if isinstance(result, dict) and 'analysis_results' in result: | 256 | if isinstance(result, dict) and 'analysis_results' in result: |
| 219 | return result['analysis_results'] | 257 | return result['analysis_results'] |
| 220 | else: | 258 | else: |
| 221 | - logging.error(f"OpenAI API返回格式不正确: {response.choices[0].message.content}") | 259 | + logging.error(f"OpenAI API返回格式不正确: {content}") |
| 222 | return [] | 260 | return [] |
| 223 | - | ||
| 224 | except Exception as e: | 261 | except Exception as e: |
| 225 | - logging.error(f"OpenAI API调用失败: {e}") | 262 | + logging.error(f"OpenAI API调用失败: {e}", exc_info=True) |
| 226 | return [] | 263 | return [] |
| 227 | 264 | ||
| 228 | async def _analyze_with_claude(self, messages_text: str, system_prompt: str, | 265 | async def _analyze_with_claude(self, messages_text: str, system_prompt: str, |
| 229 | model: str, max_tokens: int) -> List[Dict]: | 266 | model: str, max_tokens: int) -> List[Dict]: |
| 230 | - """使用Claude API进行分析""" | 267 | + """使用 Claude API 进行分析""" |
| 231 | try: | 268 | try: |
| 232 | response = await self.claude_client.messages.create( | 269 | response = await self.claude_client.messages.create( |
| 233 | model=model, | 270 | model=model, |
| 234 | max_tokens=max_tokens, | 271 | max_tokens=max_tokens, |
| 235 | temperature=0.3, | 272 | temperature=0.3, |
| 236 | system=system_prompt, | 273 | system=system_prompt, |
| 237 | - messages=[ | ||
| 238 | - { | ||
| 239 | - "role": "user", | ||
| 240 | - "content": f"请分析以下消息:\n{messages_text}" | ||
| 241 | - } | ||
| 242 | - ] | 274 | + messages=[{"role": "user", "content": f"请分析以下消息:\n{messages_text}"}] |
| 243 | ) | 275 | ) |
| 244 | - | ||
| 245 | - result = json.loads(response.content[0].text) | 276 | + content = response.content[0].text |
| 277 | + result = json.loads(content) | ||
| 246 | if isinstance(result, dict) and 'analysis_results' in result: | 278 | if isinstance(result, dict) and 'analysis_results' in result: |
| 247 | return result['analysis_results'] | 279 | return result['analysis_results'] |
| 248 | else: | 280 | else: |
| 249 | - logging.error(f"Claude API返回格式不正确: {response.content[0].text}") | 281 | + logging.error(f"Claude API返回格式不正确: {content}") |
| 250 | return [] | 282 | return [] |
| 251 | - | ||
| 252 | except Exception as e: | 283 | except Exception as e: |
| 253 | - logging.error(f"Claude API调用失败: {e}") | 284 | + logging.error(f"Claude API调用失败: {e}", exc_info=True) |
| 254 | return [] | 285 | return [] |
| 255 | 286 | ||
| 256 | async def _analyze_with_deepseek(self, messages_text: str, system_prompt: str, | 287 | async def _analyze_with_deepseek(self, messages_text: str, system_prompt: str, |
| 257 | model: str, max_tokens: int) -> List[Dict]: | 288 | model: str, max_tokens: int) -> List[Dict]: |
| 258 | - """使用DeepSeek API进行分析""" | 289 | + """使用 DeepSeek API 进行分析""" |
| 259 | try: | 290 | try: |
| 260 | response = await self.deepseek_client.chat.completions.create( | 291 | response = await self.deepseek_client.chat.completions.create( |
| 261 | model=model, | 292 | model=model, |
| @@ -264,44 +295,57 @@ class AIAnalyzer: | @@ -264,44 +295,57 @@ class AIAnalyzer: | ||
| 264 | {"role": "user", "content": f"请分析以下消息:\n{messages_text}"} | 295 | {"role": "user", "content": f"请分析以下消息:\n{messages_text}"} |
| 265 | ], | 296 | ], |
| 266 | temperature=0.3, | 297 | temperature=0.3, |
| 267 | - max_tokens=max_tokens, | ||
| 268 | - response_format={"type": "json_object"} # 强制JSON响应格式 | 298 | + max_tokens=max_tokens |
| 269 | ) | 299 | ) |
| 270 | - | ||
| 271 | - result = json.loads(response.choices[0].message.content) | 300 | + content = response.choices[0].message.content |
| 301 | + result = json.loads(content) | ||
| 272 | if isinstance(result, dict) and 'analysis_results' in result: | 302 | if isinstance(result, dict) and 'analysis_results' in result: |
| 273 | return result['analysis_results'] | 303 | return result['analysis_results'] |
| 274 | else: | 304 | else: |
| 275 | - logging.error(f"DeepSeek API返回格式不正确: {response.choices[0].message.content}") | 305 | + logging.error(f"DeepSeek API返回格式不正确: {content}") |
| 276 | return [] | 306 | return [] |
| 277 | - | ||
| 278 | except Exception as e: | 307 | except Exception as e: |
| 279 | - logging.error(f"DeepSeek API调用失败: {e}") | 308 | + logging.error(f"DeepSeek API调用失败: {e}", exc_info=True) |
| 280 | return [] | 309 | return [] |
| 281 | 310 | ||
| 282 | def format_analysis_for_display(self, analysis: Dict) -> Dict: | 311 | def format_analysis_for_display(self, analysis: Dict) -> Dict: |
| 283 | """将分析结果格式化为前端显示格式""" | 312 | """将分析结果格式化为前端显示格式""" |
| 284 | base_result = { | 313 | base_result = { |
| 285 | - 'id': analysis['message_id'], | ||
| 286 | - 'sentiment': analysis['sentiment'], | ||
| 287 | - 'sentiment_score': f"{float(analysis['sentiment_score']):.2%}", | ||
| 288 | - 'keywords': ', '.join(analysis['keywords']), | ||
| 289 | - 'key_points': analysis['key_points'], | ||
| 290 | - 'influence': analysis['influence_analysis'], | ||
| 291 | - 'risk_level': analysis['risk_level'], | 314 | + 'id': analysis.get('message_id', ''), |
| 315 | + 'sentiment': analysis.get('sentiment', ''), | ||
| 316 | + 'sentiment_score': f"{float(analysis.get('sentiment_score', 0)):.2%}", | ||
| 317 | + 'keywords': ', '.join(analysis.get('keywords', [])), | ||
| 318 | + 'key_points': analysis.get('key_points', ''), | ||
| 319 | + 'influence': analysis.get('influence_analysis', ''), | ||
| 320 | + 'risk_level': analysis.get('risk_level', ''), | ||
| 292 | 'analysis_time': datetime.fromtimestamp( | 321 | 'analysis_time': datetime.fromtimestamp( |
| 293 | - float(analysis['timestamp']) | 322 | + float(analysis.get('timestamp', 0)) |
| 294 | ).strftime('%Y-%m-%d %H:%M:%S') | 323 | ).strftime('%Y-%m-%d %H:%M:%S') |
| 295 | } | 324 | } |
| 296 | 325 | ||
| 297 | # 如果是深度分析,添加额外信息 | 326 | # 如果是深度分析,添加额外信息 |
| 298 | if 'risk_factors' in analysis: | 327 | if 'risk_factors' in analysis: |
| 299 | base_result.update({ | 328 | base_result.update({ |
| 300 | - 'risk_factors': analysis['risk_factors'], | ||
| 301 | - 'suggestions': analysis['suggestions'] | 329 | + 'risk_factors': analysis.get('risk_factors', []), |
| 330 | + 'suggestions': analysis.get('suggestions', []) | ||
| 302 | }) | 331 | }) |
| 303 | 332 | ||
| 304 | return base_result | 333 | return base_result |
| 305 | 334 | ||
| 306 | -# 创建全局AI分析器实例 | 335 | +# 创建全局 AI 分析器实例 |
| 307 | ai_analyzer = AIAnalyzer() | 336 | ai_analyzer = AIAnalyzer() |
| 337 | + | ||
| 338 | +# 若需要直接配置或测试,可在此处编写测试代码 | ||
| 339 | +if __name__ == "__main__": | ||
| 340 | + # 示例:直接配置并调用分析器(可替换为实际测试代码) | ||
| 341 | + sample_messages = [ | ||
| 342 | + {"id": "1", "content": "今天天气真好,我很开心。"}, | ||
| 343 | + {"id": "2", "content": "经济形势不容乐观,风险较大。"} | ||
| 344 | + ] | ||
| 345 | + | ||
| 346 | + async def test(): | ||
| 347 | + results = await ai_analyzer.analyze_messages(sample_messages, model_type="gpt-4o-latest", analysis_depth="standard") | ||
| 348 | + for res in results: | ||
| 349 | + print(ai_analyzer.format_analysis_for_display(res)) | ||
| 350 | + | ||
| 351 | + asyncio.run(test()) |
-
Please register or login to post a comment