戒酒的李白

Add options for selecting large models and prioritize using DeepSeek for analysis.

import asyncio
import json
import math
import os
from datetime import datetime
from typing import Any, Dict, List, Tuple

import anthropic
import openai

from utils.logger import app_logger as logging
8 10
9 class AIAnalyzer: 11 class AIAnalyzer:
10 def __init__(self): 12 def __init__(self):
11 - # 从环境变量获取API密钥 13 + # 尝试从环境变量中获取API密钥,如果没有则主动询问配置
12 self.openai_key = os.getenv('OPENAI_API_KEY') 14 self.openai_key = os.getenv('OPENAI_API_KEY')
  15 + if not self.openai_key:
  16 + print("未检测到 OPENAI_API_KEY。")
  17 + # 提示时允许按回车跳过输入
  18 + self.openai_key = input("请输入 OPENAI_API_KEY (按回车键跳过输入): ").strip()
  19 +
13 self.claude_key = os.getenv('ANTHROPIC_API_KEY') 20 self.claude_key = os.getenv('ANTHROPIC_API_KEY')
  21 + if not self.claude_key:
  22 + print("未检测到 ANTHROPIC_API_KEY。")
  23 + self.claude_key = input("请输入 ANTHROPIC_API_KEY (按回车键跳过输入): ").strip()
  24 +
14 self.deepseek_key = os.getenv('DEEPSEEK_API_KEY') 25 self.deepseek_key = os.getenv('DEEPSEEK_API_KEY')
  26 + if not self.deepseek_key:
  27 + print("未检测到 DEEPSEEK_API_KEY。")
  28 + self.deepseek_key = input("请输入 DEEPSEEK_API_KEY (按回车键跳过输入): ").strip()
15 29
16 - if not any([self.openai_key, self.claude_key, self.deepseek_key]):  
17 - raise ValueError("请至少设置一个API密钥 (OPENAI_API_KEY, ANTHROPIC_API_KEY 或 DEEPSEEK_API_KEY)") 30 + # 如果不希望通过交互输入,也可以直接在此处配置(注释掉下面几行即可)
  31 + # self.openai_key = "你的OpenAI_API_KEY"
  32 + # self.claude_key = "你的ANTHROPIC_API_KEY"
  33 + # self.deepseek_key = "你的DEEPSEEK_API_KEY"
18 34
  35 + # 配置各API客户端
19 if self.openai_key: 36 if self.openai_key:
20 openai.api_key = self.openai_key 37 openai.api_key = self.openai_key
21 if self.claude_key: 38 if self.claude_key:
22 self.claude_client = anthropic.Anthropic(api_key=self.claude_key) 39 self.claude_client = anthropic.Anthropic(api_key=self.claude_key)
23 if self.deepseek_key: 40 if self.deepseek_key:
24 - # 配置DeepSeek API  
25 self.deepseek_client = openai.OpenAI( 41 self.deepseek_client = openai.OpenAI(
26 api_key=self.deepseek_key, 42 api_key=self.deepseek_key,
27 base_url="https://api.deepseek.com/v1" 43 base_url="https://api.deepseek.com/v1"
28 ) 44 )
29 45
30 - # 支持的模型列表  
31 - self.supported_models = {  
32 - # OpenAI 模型 46 + # 支持的模型列表(增加了最新的 ChatGPT 和 Claude 模型)
  47 + self.supported_models: Dict[str, Dict[str, Any]] = {
  48 + # OpenAI 最新模型(ChatGPT系列)
  49 + 'gpt-4o-latest': {
  50 + 'provider': 'openai',
  51 + 'max_tokens': 128000, # 支持大窗口
  52 + 'cost_per_1k': 0.01 # 参考价格(美元)
  53 + },
  54 + 'gpt-4o-mini': {
  55 + 'provider': 'openai',
  56 + 'max_tokens': 4000, # 轻量版,适合快速任务
  57 + 'cost_per_1k': 0.00015 # 成本大幅降低
  58 + },
  59 + # 旧版OpenAI模型
33 'gpt-3.5-turbo': {'provider': 'openai', 'max_tokens': 2000, 'cost_per_1k': 0.0015}, 60 'gpt-3.5-turbo': {'provider': 'openai', 'max_tokens': 2000, 'cost_per_1k': 0.0015},
34 'gpt-3.5-turbo-16k': {'provider': 'openai', 'max_tokens': 16000, 'cost_per_1k': 0.003}, 61 'gpt-3.5-turbo-16k': {'provider': 'openai', 'max_tokens': 16000, 'cost_per_1k': 0.003},
35 'gpt-4': {'provider': 'openai', 'max_tokens': 8000, 'cost_per_1k': 0.03}, 62 'gpt-4': {'provider': 'openai', 'max_tokens': 8000, 'cost_per_1k': 0.03},
36 'gpt-4-32k': {'provider': 'openai', 'max_tokens': 32000, 'cost_per_1k': 0.06}, 63 'gpt-4-32k': {'provider': 'openai', 'max_tokens': 32000, 'cost_per_1k': 0.06},
37 'gpt-4-turbo-preview': {'provider': 'openai', 'max_tokens': 128000, 'cost_per_1k': 0.01}, 64 'gpt-4-turbo-preview': {'provider': 'openai', 'max_tokens': 128000, 'cost_per_1k': 0.01},
38 65
39 - # Claude 模型  
40 - 'claude-3-opus-20240229': {'provider': 'anthropic', 'max_tokens': 4000, 'cost_per_1k': 0.015},  
41 - 'claude-3-sonnet-20240229': {'provider': 'anthropic', 'max_tokens': 3000, 'cost_per_1k': 0.003},  
42 - 'claude-3-haiku-20240307': {'provider': 'anthropic', 'max_tokens': 2000, 'cost_per_1k': 0.0025}, 66 + # Anthropic 最新模型(Claude系列)
  67 + 'claude-3.5-sonnet-new': {
  68 + 'provider': 'anthropic',
  69 + 'max_tokens': 200000, # 新版Claude 3.5 Sonnet
  70 + 'cost_per_1k': 0.015
  71 + },
  72 + 'claude-3.5-haiku': {
  73 + 'provider': 'anthropic',
  74 + 'max_tokens': 200000, # 最新Claude 3.5 Haiku
  75 + 'cost_per_1k': 0.0025
  76 + },
  77 + # 旧版Claude模型
43 'claude-2.1': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, 78 'claude-2.1': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008},
44 'claude-2.0': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008}, 79 'claude-2.0': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.008},
45 'claude-instant-1.2': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.0015}, 80 'claude-instant-1.2': {'provider': 'anthropic', 'max_tokens': 100000, 'cost_per_1k': 0.0015},
46 81
47 # DeepSeek 模型 82 # DeepSeek 模型
48 - 'deepseek-chat': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.002}, # DeepSeek-V3  
49 - 'deepseek-reasoner': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.003} # DeepSeek-R1 83 + 'deepseek-chat': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.002},
  84 + 'deepseek-reasoner': {'provider': 'deepseek', 'max_tokens': 4000, 'cost_per_1k': 0.003}
50 } 85 }
51 86
52 # 不同深度的分析提示词 87 # 不同深度的分析提示词
53 - self.prompt_templates = { 88 + self.prompt_templates: Dict[str, str] = {
54 'basic': """你是一个专业的舆情分析助手。请对每条消息进行基础的情感分析。 89 'basic': """你是一个专业的舆情分析助手。请对每条消息进行基础的情感分析。
55 请按以下JSON格式返回: 90 请按以下JSON格式返回:
56 { 91 {
@@ -105,9 +140,19 @@ class AIAnalyzer: @@ -105,9 +140,19 @@ class AIAnalyzer:
105 140
106 async def analyze_messages(self, messages: List[Dict], batch_size: int = 50, 141 async def analyze_messages(self, messages: List[Dict], batch_size: int = 50,
107 model_type: str = "gpt-3.5-turbo", 142 model_type: str = "gpt-3.5-turbo",
108 - analysis_depth: str = "standard") -> List[Dict]:  
109 - """分析一批消息并返回分析结果""" 143 + analysis_depth: str = "standard",
  144 + prefer_deepseek: bool = True) -> List[Dict]:
  145 + """
  146 + 分析一批消息并返回分析结果。
  147 + 如果 DeepSeek API 可用且 prefer_deepseek 为 True,则优先使用 DeepSeek 模型。
  148 + """
110 try: 149 try:
  150 + # 优先使用 DeepSeek 模型以降低成本
  151 + if prefer_deepseek and self.deepseek_key:
  152 + if model_type not in ['deepseek-chat', 'deepseek-reasoner']:
  153 + logging.info("检测到 DeepSeek API, 优先使用 'deepseek-chat' 模型以降低成本。")
  154 + model_type = 'deepseek-chat'
  155 +
111 if model_type not in self.supported_models: 156 if model_type not in self.supported_models:
112 raise ValueError(f"不支持的模型类型: {model_type}") 157 raise ValueError(f"不支持的模型类型: {model_type}")
113 158
@@ -116,91 +161,85 @@ class AIAnalyzer: @@ -116,91 +161,85 @@ class AIAnalyzer:
116 max_tokens = model_info['max_tokens'] 161 max_tokens = model_info['max_tokens']
117 162
118 # 根据模型类型调整批处理大小 163 # 根据模型类型调整批处理大小
119 - adjusted_batch_size = min(batch_size, self._get_optimal_batch_size(model_type)) 164 + optimal_batch_size = self._get_optimal_batch_size(model_type)
  165 + adjusted_batch_size = min(batch_size, optimal_batch_size)
120 if adjusted_batch_size != batch_size: 166 if adjusted_batch_size != batch_size:
121 logging.info(f"已将批处理大小从 {batch_size} 调整为 {adjusted_batch_size}") 167 logging.info(f"已将批处理大小从 {batch_size} 调整为 {adjusted_batch_size}")
122 168
123 - all_results = []  
124 - total_cost = 0  
125 -  
126 - # 分批处理消息 169 + tasks = []
  170 + total_cost = 0.0
  171 + # 分批处理消息并异步调用分析任务
127 for i in range(0, len(messages), adjusted_batch_size): 172 for i in range(0, len(messages), adjusted_batch_size):
128 batch = messages[i:i + adjusted_batch_size] 173 batch = messages[i:i + adjusted_batch_size]
129 - formatted_messages = []  
130 - for msg in batch:  
131 - formatted_messages.append(f"消息ID: {msg['id']}\n内容: {msg['content']}") 174 + system_prompt = self.prompt_templates.get(analysis_depth, self.prompt_templates['standard'])
  175 + tasks.append(self._process_batch(batch, system_prompt, model_type, max_tokens, provider))
  176 +
  177 + # 并发执行所有批次任务
  178 + results = await asyncio.gather(*tasks)
  179 +
  180 + all_results = []
  181 + for batch_result, batch_cost in results:
  182 + all_results.extend(batch_result)
  183 + total_cost += batch_cost
  184 +
  185 + logging.info(f"分析完成, 总成本: ${total_cost:.4f}")
  186 + return all_results
  187 + except Exception as e:
  188 + logging.error(f"AI分析过程出错: {e}", exc_info=True)
  189 + return []
132 190
  191 + async def _process_batch(self, batch: List[Dict], system_prompt: str,
  192 + model_type: str, max_tokens: int, provider: str) -> Tuple[List[Dict], float]:
  193 + """
  194 + 处理单个批次的消息,返回 (分析结果, 本批次成本)
  195 + """
  196 + try:
  197 + formatted_messages = [
  198 + f"消息ID: {msg.get('id')}\n内容: {msg.get('content')}" for msg in batch
  199 + ]
133 messages_text = "\n---\n".join(formatted_messages) 200 messages_text = "\n---\n".join(formatted_messages)
134 - system_prompt = self.prompt_templates.get(analysis_depth, self.prompt_templates['standard'])  
135 201
136 if provider == 'openai': 202 if provider == 'openai':
137 - result = await self._analyze_with_openai(  
138 - messages_text,  
139 - system_prompt,  
140 - model_type,  
141 - max_tokens  
142 - ) 203 + result = await self._analyze_with_openai(messages_text, system_prompt, model_type, max_tokens)
143 elif provider == 'anthropic': 204 elif provider == 'anthropic':
144 - result = await self._analyze_with_claude(  
145 - messages_text,  
146 - system_prompt,  
147 - model_type,  
148 - max_tokens  
149 - ) 205 + result = await self._analyze_with_claude(messages_text, system_prompt, model_type, max_tokens)
150 elif provider == 'deepseek': 206 elif provider == 'deepseek':
151 - result = await self._analyze_with_deepseek(  
152 - messages_text,  
153 - system_prompt,  
154 - model_type,  
155 - max_tokens  
156 - ) 207 + result = await self._analyze_with_deepseek(messages_text, system_prompt, model_type, max_tokens)
  208 + else:
  209 + logging.error(f"未知的API供应商: {provider}")
  210 + return ([], 0.0)
157 211
158 - if result:  
159 - all_results.extend(result)  
160 - # 计算本批次成本  
161 batch_cost = self._calculate_cost(len(messages_text), model_type) 212 batch_cost = self._calculate_cost(len(messages_text), model_type)
162 - total_cost += batch_cost  
163 - logging.info(f"批次处理完成,成本: ${batch_cost:.4f}")  
164 -  
165 - logging.info(f"分析完成,总成本: ${total_cost:.4f}")  
166 - return all_results  
167 - 213 + logging.info(f"批次处理完成, 成本: ${batch_cost:.4f}")
  214 + return (result, batch_cost)
168 except Exception as e: 215 except Exception as e:
169 - logging.error(f"AI分析过程出错: {e}")  
170 - return [] 216 + logging.error(f"处理批次时出错: {e}", exc_info=True)
  217 + return ([], 0.0)
171 218
172 def _get_optimal_batch_size(self, model_type: str) -> int: 219 def _get_optimal_batch_size(self, model_type: str) -> int:
173 """根据模型类型获取最优批处理大小""" 220 """根据模型类型获取最优批处理大小"""
174 model_info = self.supported_models[model_type] 221 model_info = self.supported_models[model_type]
175 max_tokens = model_info['max_tokens'] 222 max_tokens = model_info['max_tokens']
176 223
177 - # 估算每条消息的平均token数(假设为200) 224 + # 估算每条消息的平均 token 数(假设为 200)
178 avg_tokens_per_message = 200 225 avg_tokens_per_message = 200
179 -  
180 - # 预留20%的token用于系统提示词和响应 226 + # 预留 20% 的 token 用于系统提示词和响应
181 available_tokens = int(max_tokens * 0.8) 227 available_tokens = int(max_tokens * 0.8)
182 -  
183 - # 计算最优批处理大小  
184 optimal_batch_size = max(1, min(100, available_tokens // avg_tokens_per_message)) 228 optimal_batch_size = max(1, min(100, available_tokens // avg_tokens_per_message))
185 -  
186 return optimal_batch_size 229 return optimal_batch_size
187 230
188 def _calculate_cost(self, input_length: int, model_type: str) -> float: 231 def _calculate_cost(self, input_length: int, model_type: str) -> float:
189 - """计算API调用成本""" 232 + """计算 API 调用成本"""
190 model_info = self.supported_models[model_type] 233 model_info = self.supported_models[model_type]
191 cost_per_1k = model_info['cost_per_1k'] 234 cost_per_1k = model_info['cost_per_1k']
192 -  
193 - # 估算token数(假设每4个字符约等于1个token)  
194 - estimated_tokens = input_length // 4  
195 -  
196 - # 计算成本(美元) 235 + # 估算 token 数(假设每 4 个字符约等于 1 个 token)
  236 + estimated_tokens = math.ceil(input_length / 4)
197 cost = (estimated_tokens / 1000) * cost_per_1k 237 cost = (estimated_tokens / 1000) * cost_per_1k
198 -  
199 return cost 238 return cost
200 239
201 async def _analyze_with_openai(self, messages_text: str, system_prompt: str, 240 async def _analyze_with_openai(self, messages_text: str, system_prompt: str,
202 model: str, max_tokens: int) -> List[Dict]: 241 model: str, max_tokens: int) -> List[Dict]:
203 - """使用OpenAI API进行分析""" 242 + """使用 OpenAI API 进行分析"""
204 try: 243 try:
205 response = await openai.ChatCompletion.acreate( 244 response = await openai.ChatCompletion.acreate(
206 model=model, 245 model=model,
@@ -210,52 +249,44 @@ class AIAnalyzer: @@ -210,52 +249,44 @@ class AIAnalyzer:
210 ], 249 ],
211 temperature=0.3, 250 temperature=0.3,
212 max_tokens=max_tokens, 251 max_tokens=max_tokens,
213 - n=1,  
214 - response_format={"type": "json_object"} # 强制JSON响应格式 252 + n=1
215 ) 253 )
216 -  
217 - result = json.loads(response.choices[0].message.content) 254 + content = response.choices[0].message.content
  255 + result = json.loads(content)
218 if isinstance(result, dict) and 'analysis_results' in result: 256 if isinstance(result, dict) and 'analysis_results' in result:
219 return result['analysis_results'] 257 return result['analysis_results']
220 else: 258 else:
221 - logging.error(f"OpenAI API返回格式不正确: {response.choices[0].message.content}") 259 + logging.error(f"OpenAI API返回格式不正确: {content}")
222 return [] 260 return []
223 -  
224 except Exception as e: 261 except Exception as e:
225 - logging.error(f"OpenAI API调用失败: {e}") 262 + logging.error(f"OpenAI API调用失败: {e}", exc_info=True)
226 return [] 263 return []
227 264
228 async def _analyze_with_claude(self, messages_text: str, system_prompt: str, 265 async def _analyze_with_claude(self, messages_text: str, system_prompt: str,
229 model: str, max_tokens: int) -> List[Dict]: 266 model: str, max_tokens: int) -> List[Dict]:
230 - """使用Claude API进行分析""" 267 + """使用 Claude API 进行分析"""
231 try: 268 try:
232 response = await self.claude_client.messages.create( 269 response = await self.claude_client.messages.create(
233 model=model, 270 model=model,
234 max_tokens=max_tokens, 271 max_tokens=max_tokens,
235 temperature=0.3, 272 temperature=0.3,
236 system=system_prompt, 273 system=system_prompt,
237 - messages=[  
238 - {  
239 - "role": "user",  
240 - "content": f"请分析以下消息:\n{messages_text}"  
241 - }  
242 - ] 274 + messages=[{"role": "user", "content": f"请分析以下消息:\n{messages_text}"}]
243 ) 275 )
244 -  
245 - result = json.loads(response.content[0].text) 276 + content = response.content[0].text
  277 + result = json.loads(content)
246 if isinstance(result, dict) and 'analysis_results' in result: 278 if isinstance(result, dict) and 'analysis_results' in result:
247 return result['analysis_results'] 279 return result['analysis_results']
248 else: 280 else:
249 - logging.error(f"Claude API返回格式不正确: {response.content[0].text}") 281 + logging.error(f"Claude API返回格式不正确: {content}")
250 return [] 282 return []
251 -  
252 except Exception as e: 283 except Exception as e:
253 - logging.error(f"Claude API调用失败: {e}") 284 + logging.error(f"Claude API调用失败: {e}", exc_info=True)
254 return [] 285 return []
255 286
256 async def _analyze_with_deepseek(self, messages_text: str, system_prompt: str, 287 async def _analyze_with_deepseek(self, messages_text: str, system_prompt: str,
257 model: str, max_tokens: int) -> List[Dict]: 288 model: str, max_tokens: int) -> List[Dict]:
258 - """使用DeepSeek API进行分析""" 289 + """使用 DeepSeek API 进行分析"""
259 try: 290 try:
260 response = await self.deepseek_client.chat.completions.create( 291 response = await self.deepseek_client.chat.completions.create(
261 model=model, 292 model=model,
@@ -264,44 +295,57 @@ class AIAnalyzer: @@ -264,44 +295,57 @@ class AIAnalyzer:
264 {"role": "user", "content": f"请分析以下消息:\n{messages_text}"} 295 {"role": "user", "content": f"请分析以下消息:\n{messages_text}"}
265 ], 296 ],
266 temperature=0.3, 297 temperature=0.3,
267 - max_tokens=max_tokens,  
268 - response_format={"type": "json_object"} # 强制JSON响应格式 298 + max_tokens=max_tokens
269 ) 299 )
270 -  
271 - result = json.loads(response.choices[0].message.content) 300 + content = response.choices[0].message.content
  301 + result = json.loads(content)
272 if isinstance(result, dict) and 'analysis_results' in result: 302 if isinstance(result, dict) and 'analysis_results' in result:
273 return result['analysis_results'] 303 return result['analysis_results']
274 else: 304 else:
275 - logging.error(f"DeepSeek API返回格式不正确: {response.choices[0].message.content}") 305 + logging.error(f"DeepSeek API返回格式不正确: {content}")
276 return [] 306 return []
277 -  
278 except Exception as e: 307 except Exception as e:
279 - logging.error(f"DeepSeek API调用失败: {e}") 308 + logging.error(f"DeepSeek API调用失败: {e}", exc_info=True)
280 return [] 309 return []
281 310
282 def format_analysis_for_display(self, analysis: Dict) -> Dict: 311 def format_analysis_for_display(self, analysis: Dict) -> Dict:
283 """将分析结果格式化为前端显示格式""" 312 """将分析结果格式化为前端显示格式"""
284 base_result = { 313 base_result = {
285 - 'id': analysis['message_id'],  
286 - 'sentiment': analysis['sentiment'],  
287 - 'sentiment_score': f"{float(analysis['sentiment_score']):.2%}",  
288 - 'keywords': ', '.join(analysis['keywords']),  
289 - 'key_points': analysis['key_points'],  
290 - 'influence': analysis['influence_analysis'],  
291 - 'risk_level': analysis['risk_level'], 314 + 'id': analysis.get('message_id', ''),
  315 + 'sentiment': analysis.get('sentiment', ''),
  316 + 'sentiment_score': f"{float(analysis.get('sentiment_score', 0)):.2%}",
  317 + 'keywords': ', '.join(analysis.get('keywords', [])),
  318 + 'key_points': analysis.get('key_points', ''),
  319 + 'influence': analysis.get('influence_analysis', ''),
  320 + 'risk_level': analysis.get('risk_level', ''),
292 'analysis_time': datetime.fromtimestamp( 321 'analysis_time': datetime.fromtimestamp(
293 - float(analysis['timestamp']) 322 + float(analysis.get('timestamp', 0))
294 ).strftime('%Y-%m-%d %H:%M:%S') 323 ).strftime('%Y-%m-%d %H:%M:%S')
295 } 324 }
296 325
297 # 如果是深度分析,添加额外信息 326 # 如果是深度分析,添加额外信息
298 if 'risk_factors' in analysis: 327 if 'risk_factors' in analysis:
299 base_result.update({ 328 base_result.update({
300 - 'risk_factors': analysis['risk_factors'],  
301 - 'suggestions': analysis['suggestions'] 329 + 'risk_factors': analysis.get('risk_factors', []),
  330 + 'suggestions': analysis.get('suggestions', [])
302 }) 331 })
303 332
304 return base_result 333 return base_result
305 334
# Global analyzer instance shared by the rest of the application.
# NOTE(review): constructing at import time may block on interactive input()
# prompts inside AIAnalyzer.__init__ when API keys are missing from the
# environment — confirm this is acceptable for non-interactive deployments.
ai_analyzer = AIAnalyzer()


if __name__ == "__main__":
    # Minimal smoke test: analyze two sample messages and print the
    # display-formatted results. Requires at least one configured API key.
    sample_messages = [
        {"id": "1", "content": "今天天气真好,我很开心。"},
        {"id": "2", "content": "经济形势不容乐观,风险较大。"}
    ]

    async def _demo():
        results = await ai_analyzer.analyze_messages(
            sample_messages, model_type="gpt-4o-latest", analysis_depth="standard"
        )
        for res in results:
            print(ai_analyzer.format_analysis_for_display(res))

    asyncio.run(_demo())