Showing
1 changed file
with
178 additions
and
114 deletions
| @@ -11,6 +11,7 @@ import re | @@ -11,6 +11,7 @@ import re | ||
| 11 | 11 | ||
| 12 | try: | 12 | try: |
| 13 | import torch | 13 | import torch |
| 14 | + | ||
| 14 | TORCH_AVAILABLE = True | 15 | TORCH_AVAILABLE = True |
| 15 | except ImportError: | 16 | except ImportError: |
| 16 | torch = None # type: ignore | 17 | torch = None # type: ignore |
| @@ -18,6 +19,7 @@ except ImportError: | @@ -18,6 +19,7 @@ except ImportError: | ||
| 18 | 19 | ||
| 19 | try: | 20 | try: |
| 20 | from transformers import AutoTokenizer, AutoModelForSequenceClassification | 21 | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| 22 | + | ||
| 21 | TRANSFORMERS_AVAILABLE = True | 23 | TRANSFORMERS_AVAILABLE = True |
| 22 | except ImportError: | 24 | except ImportError: |
| 23 | AutoTokenizer = None # type: ignore | 25 | AutoTokenizer = None # type: ignore |
| @@ -28,6 +30,7 @@ except ImportError: | @@ -28,6 +30,7 @@ except ImportError: | ||
| 28 | # INFO:若想跳过情感分析,可手动切换此开关为False | 30 | # INFO:若想跳过情感分析,可手动切换此开关为False |
| 29 | SENTIMENT_ANALYSIS_ENABLED = True | 31 | SENTIMENT_ANALYSIS_ENABLED = True |
| 30 | 32 | ||
| 33 | + | ||
| 31 | def _describe_missing_dependencies() -> str: | 34 | def _describe_missing_dependencies() -> str: |
| 32 | missing = [] | 35 | missing = [] |
| 33 | if not TORCH_AVAILABLE: | 36 | if not TORCH_AVAILABLE: |
| @@ -36,14 +39,21 @@ def _describe_missing_dependencies() -> str: | @@ -36,14 +39,21 @@ def _describe_missing_dependencies() -> str: | ||
| 36 | missing.append("Transformers") | 39 | missing.append("Transformers") |
| 37 | return " / ".join(missing) | 40 | return " / ".join(missing) |
| 38 | 41 | ||
| 42 | + | ||
| 39 | # 添加项目根目录到路径,以便导入WeiboMultilingualSentiment | 43 | # 添加项目根目录到路径,以便导入WeiboMultilingualSentiment |
| 40 | -project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
| 41 | -weibo_sentiment_path = os.path.join(project_root, "SentimentAnalysisModel", "WeiboMultilingualSentiment") | 44 | +project_root = os.path.dirname( |
| 45 | + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
| 46 | +) | ||
| 47 | +weibo_sentiment_path = os.path.join( | ||
| 48 | + project_root, "SentimentAnalysisModel", "WeiboMultilingualSentiment" | ||
| 49 | +) | ||
| 42 | sys.path.append(weibo_sentiment_path) | 50 | sys.path.append(weibo_sentiment_path) |
| 43 | 51 | ||
| 52 | + | ||
| 44 | @dataclass | 53 | @dataclass |
| 45 | class SentimentResult: | 54 | class SentimentResult: |
| 46 | """情感分析结果数据类""" | 55 | """情感分析结果数据类""" |
| 56 | + | ||
| 47 | text: str | 57 | text: str |
| 48 | sentiment_label: str | 58 | sentiment_label: str |
| 49 | confidence: float | 59 | confidence: float |
| @@ -53,9 +63,10 @@ class SentimentResult: | @@ -53,9 +63,10 @@ class SentimentResult: | ||
| 53 | analysis_performed: bool = True | 63 | analysis_performed: bool = True |
| 54 | 64 | ||
| 55 | 65 | ||
| 56 | -@dataclass | 66 | +@dataclass |
| 57 | class BatchSentimentResult: | 67 | class BatchSentimentResult: |
| 58 | """批量情感分析结果数据类""" | 68 | """批量情感分析结果数据类""" |
| 69 | + | ||
| 59 | results: List[SentimentResult] | 70 | results: List[SentimentResult] |
| 60 | total_processed: int | 71 | total_processed: int |
| 61 | success_count: int | 72 | success_count: int |
| @@ -69,7 +80,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -69,7 +80,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 69 | 多语言情感分析器 | 80 | 多语言情感分析器 |
| 70 | 封装WeiboMultilingualSentiment模型,为AI Agent提供情感分析功能 | 81 | 封装WeiboMultilingualSentiment模型,为AI Agent提供情感分析功能 |
| 71 | """ | 82 | """ |
| 72 | - | 83 | + |
| 73 | def __init__(self): | 84 | def __init__(self): |
| 74 | """初始化情感分析器""" | 85 | """初始化情感分析器""" |
| 75 | self.model = None | 86 | self.model = None |
| @@ -78,14 +89,14 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -78,14 +89,14 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 78 | self.is_initialized = False | 89 | self.is_initialized = False |
| 79 | self.is_disabled = False | 90 | self.is_disabled = False |
| 80 | self.disable_reason: Optional[str] = None | 91 | self.disable_reason: Optional[str] = None |
| 81 | - | 92 | + |
| 82 | # 情感标签映射(5级分类) | 93 | # 情感标签映射(5级分类) |
| 83 | self.sentiment_map = { | 94 | self.sentiment_map = { |
| 84 | - 0: "非常负面", | ||
| 85 | - 1: "负面", | ||
| 86 | - 2: "中性", | ||
| 87 | - 3: "正面", | ||
| 88 | - 4: "非常正面" | 95 | + 0: "非常负面", |
| 96 | + 1: "负面", | ||
| 97 | + 2: "中性", | ||
| 98 | + 3: "正面", | ||
| 99 | + 4: "非常正面", | ||
| 89 | } | 100 | } |
| 90 | 101 | ||
| 91 | if not SENTIMENT_ANALYSIS_ENABLED: | 102 | if not SENTIMENT_ANALYSIS_ENABLED: |
| @@ -96,9 +107,13 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -96,9 +107,13 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 96 | 107 | ||
| 97 | if self.is_disabled: | 108 | if self.is_disabled: |
| 98 | reason = self.disable_reason or "Sentiment analysis disabled." | 109 | reason = self.disable_reason or "Sentiment analysis disabled." |
| 99 | - print(f"WeiboMultilingualSentimentAnalyzer initialized but disabled: {reason}") | 110 | + print( |
| 111 | + f"WeiboMultilingualSentimentAnalyzer initialized but disabled: {reason}" | ||
| 112 | + ) | ||
| 100 | else: | 113 | else: |
| 101 | - print("WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型") | 114 | + print( |
| 115 | + "WeiboMultilingualSentimentAnalyzer 已创建,调用 initialize() 来加载模型" | ||
| 116 | + ) | ||
| 102 | 117 | ||
| 103 | def disable(self, reason: Optional[str] = None, drop_state: bool = False) -> None: | 118 | def disable(self, reason: Optional[str] = None, drop_state: bool = False) -> None: |
| 104 | """Disable sentiment analysis, optionally clearing loaded resources.""" | 119 | """Disable sentiment analysis, optionally clearing loaded resources.""" |
| @@ -130,14 +145,18 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -130,14 +145,18 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 130 | if torch.cuda.is_available(): | 145 | if torch.cuda.is_available(): |
| 131 | return torch.device("cuda") | 146 | return torch.device("cuda") |
| 132 | mps_backend = getattr(torch.backends, "mps", None) | 147 | mps_backend = getattr(torch.backends, "mps", None) |
| 133 | - if mps_backend and getattr(mps_backend, "is_available", lambda: False)() and getattr(mps_backend, "is_built", lambda: False)(): | 148 | + if ( |
| 149 | + mps_backend | ||
| 150 | + and getattr(mps_backend, "is_available", lambda: False)() | ||
| 151 | + and getattr(mps_backend, "is_built", lambda: False)() | ||
| 152 | + ): | ||
| 134 | return torch.device("mps") | 153 | return torch.device("mps") |
| 135 | return torch.device("cpu") | 154 | return torch.device("cpu") |
| 136 | - | 155 | + |
| 137 | def initialize(self) -> bool: | 156 | def initialize(self) -> bool: |
| 138 | """ | 157 | """ |
| 139 | 初始化模型和分词器 | 158 | 初始化模型和分词器 |
| 140 | - | 159 | + |
| 141 | Returns: | 160 | Returns: |
| 142 | 是否初始化成功 | 161 | 是否初始化成功 |
| 143 | """ | 162 | """ |
| @@ -155,31 +174,35 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -155,31 +174,35 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 155 | if self.is_initialized: | 174 | if self.is_initialized: |
| 156 | print("模型已经初始化,无需重复加载") | 175 | print("模型已经初始化,无需重复加载") |
| 157 | return True | 176 | return True |
| 158 | - | 177 | + |
| 159 | try: | 178 | try: |
| 160 | print("正在加载多语言情感分析模型...") | 179 | print("正在加载多语言情感分析模型...") |
| 161 | - | 180 | + |
| 162 | # 使用多语言情感分析模型 | 181 | # 使用多语言情感分析模型 |
| 163 | model_name = "tabularisai/multilingual-sentiment-analysis" | 182 | model_name = "tabularisai/multilingual-sentiment-analysis" |
| 164 | local_model_path = os.path.join(weibo_sentiment_path, "model") | 183 | local_model_path = os.path.join(weibo_sentiment_path, "model") |
| 165 | - | 184 | + |
| 166 | # 检查本地是否已有模型 | 185 | # 检查本地是否已有模型 |
| 167 | if os.path.exists(local_model_path): | 186 | if os.path.exists(local_model_path): |
| 168 | print("从本地加载模型...") | 187 | print("从本地加载模型...") |
| 169 | self.tokenizer = AutoTokenizer.from_pretrained(local_model_path) | 188 | self.tokenizer = AutoTokenizer.from_pretrained(local_model_path) |
| 170 | - self.model = AutoModelForSequenceClassification.from_pretrained(local_model_path) | 189 | + self.model = AutoModelForSequenceClassification.from_pretrained( |
| 190 | + local_model_path | ||
| 191 | + ) | ||
| 171 | else: | 192 | else: |
| 172 | print("首次使用,正在下载模型到本地...") | 193 | print("首次使用,正在下载模型到本地...") |
| 173 | # 下载并保存到本地 | 194 | # 下载并保存到本地 |
| 174 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) | 195 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
| 175 | - self.model = AutoModelForSequenceClassification.from_pretrained(model_name) | ||
| 176 | - | 196 | + self.model = AutoModelForSequenceClassification.from_pretrained( |
| 197 | + model_name | ||
| 198 | + ) | ||
| 199 | + | ||
| 177 | # 保存到本地 | 200 | # 保存到本地 |
| 178 | os.makedirs(local_model_path, exist_ok=True) | 201 | os.makedirs(local_model_path, exist_ok=True) |
| 179 | self.tokenizer.save_pretrained(local_model_path) | 202 | self.tokenizer.save_pretrained(local_model_path) |
| 180 | self.model.save_pretrained(local_model_path) | 203 | self.model.save_pretrained(local_model_path) |
| 181 | print(f"模型已保存到: {local_model_path}") | 204 | print(f"模型已保存到: {local_model_path}") |
| 182 | - | 205 | + |
| 183 | # 设置设备 | 206 | # 设置设备 |
| 184 | device = self._select_device() | 207 | device = self._select_device() |
| 185 | if device is None: | 208 | if device is None: |
| @@ -198,46 +221,46 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -198,46 +221,46 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 198 | print("检测到 Apple MPS 设备,已使用 MPS 进行推理。") | 221 | print("检测到 Apple MPS 设备,已使用 MPS 进行推理。") |
| 199 | else: | 222 | else: |
| 200 | print("未检测到 GPU,自动使用 CPU 进行推理。") | 223 | print("未检测到 GPU,自动使用 CPU 进行推理。") |
| 201 | - | 224 | + |
| 202 | print(f"模型加载成功! 使用设备: {self.device}") | 225 | print(f"模型加载成功! 使用设备: {self.device}") |
| 203 | print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言") | 226 | print("支持语言: 中文、英文、西班牙文、阿拉伯文、日文、韩文等22种语言") |
| 204 | print("情感等级: 非常负面、负面、中性、正面、非常正面") | 227 | print("情感等级: 非常负面、负面、中性、正面、非常正面") |
| 205 | - | 228 | + |
| 206 | return True | 229 | return True |
| 207 | - | 230 | + |
| 208 | except Exception as e: | 231 | except Exception as e: |
| 209 | error_message = f"模型加载失败: {e}" | 232 | error_message = f"模型加载失败: {e}" |
| 210 | print(error_message) | 233 | print(error_message) |
| 211 | print("请检查网络连接或模型文件") | 234 | print("请检查网络连接或模型文件") |
| 212 | self.disable(error_message, drop_state=True) | 235 | self.disable(error_message, drop_state=True) |
| 213 | return False | 236 | return False |
| 214 | - | 237 | + |
| 215 | def _preprocess_text(self, text: str) -> str: | 238 | def _preprocess_text(self, text: str) -> str: |
| 216 | """ | 239 | """ |
| 217 | 文本预处理 | 240 | 文本预处理 |
| 218 | - | 241 | + |
| 219 | Args: | 242 | Args: |
| 220 | text: 输入文本 | 243 | text: 输入文本 |
| 221 | - | 244 | + |
| 222 | Returns: | 245 | Returns: |
| 223 | 处理后的文本 | 246 | 处理后的文本 |
| 224 | """ | 247 | """ |
| 225 | # 基本文本清理 | 248 | # 基本文本清理 |
| 226 | if not text or not text.strip(): | 249 | if not text or not text.strip(): |
| 227 | return "" | 250 | return "" |
| 228 | - | 251 | + |
| 229 | # 去除多余空格 | 252 | # 去除多余空格 |
| 230 | - text = re.sub(r'\s+', ' ', text.strip()) | ||
| 231 | - | 253 | + text = re.sub(r"\s+", " ", text.strip()) |
| 254 | + | ||
| 232 | return text | 255 | return text |
| 233 | - | 256 | + |
| 234 | def analyze_single_text(self, text: str) -> SentimentResult: | 257 | def analyze_single_text(self, text: str) -> SentimentResult: |
| 235 | """ | 258 | """ |
| 236 | 对单个文本进行情感分析 | 259 | 对单个文本进行情感分析 |
| 237 | - | 260 | + |
| 238 | Args: | 261 | Args: |
| 239 | text: 要分析的文本 | 262 | text: 要分析的文本 |
| 240 | - | 263 | + |
| 241 | Returns: | 264 | Returns: |
| 242 | SentimentResult对象 | 265 | SentimentResult对象 |
| 243 | """ | 266 | """ |
| @@ -249,7 +272,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -249,7 +272,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 249 | probability_distribution={}, | 272 | probability_distribution={}, |
| 250 | success=False, | 273 | success=False, |
| 251 | error_message=self.disable_reason or "情感分析功能已禁用", | 274 | error_message=self.disable_reason or "情感分析功能已禁用", |
| 252 | - analysis_performed=False | 275 | + analysis_performed=False, |
| 253 | ) | 276 | ) |
| 254 | 277 | ||
| 255 | if not self.is_initialized: | 278 | if not self.is_initialized: |
| @@ -260,7 +283,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -260,7 +283,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 260 | probability_distribution={}, | 283 | probability_distribution={}, |
| 261 | success=False, | 284 | success=False, |
| 262 | error_message="模型未初始化,请先调用initialize() 方法", | 285 | error_message="模型未初始化,请先调用initialize() 方法", |
| 263 | - analysis_performed=False | 286 | + analysis_performed=False, |
| 264 | ) | 287 | ) |
| 265 | 288 | ||
| 266 | try: | 289 | try: |
| @@ -275,7 +298,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -275,7 +298,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 275 | probability_distribution={}, | 298 | probability_distribution={}, |
| 276 | success=False, | 299 | success=False, |
| 277 | error_message="输入文本为空或无效内容", | 300 | error_message="输入文本为空或无效内容", |
| 278 | - analysis_performed=False | 301 | + analysis_performed=False, |
| 279 | ) | 302 | ) |
| 280 | 303 | ||
| 281 | # 分词编码 | 304 | # 分词编码 |
| @@ -284,7 +307,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -284,7 +307,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 284 | max_length=512, | 307 | max_length=512, |
| 285 | padding=True, | 308 | padding=True, |
| 286 | truncation=True, | 309 | truncation=True, |
| 287 | - return_tensors='pt' | 310 | + return_tensors="pt", |
| 288 | ) | 311 | ) |
| 289 | 312 | ||
| 290 | # 转移到设备 | 313 | # 转移到设备 |
| @@ -311,7 +334,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -311,7 +334,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 311 | sentiment_label=label, | 334 | sentiment_label=label, |
| 312 | confidence=confidence, | 335 | confidence=confidence, |
| 313 | probability_distribution=prob_dist, | 336 | probability_distribution=prob_dist, |
| 314 | - success=True | 337 | + success=True, |
| 315 | ) | 338 | ) |
| 316 | 339 | ||
| 317 | except Exception as e: | 340 | except Exception as e: |
| @@ -322,17 +345,19 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -322,17 +345,19 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 322 | probability_distribution={}, | 345 | probability_distribution={}, |
| 323 | success=False, | 346 | success=False, |
| 324 | error_message=f"预测时发生错误: {str(e)}", | 347 | error_message=f"预测时发生错误: {str(e)}", |
| 325 | - analysis_performed=False | 348 | + analysis_performed=False, |
| 326 | ) | 349 | ) |
| 327 | 350 | ||
| 328 | - def analyze_batch(self, texts: List[str], show_progress: bool = True) -> BatchSentimentResult: | 351 | + def analyze_batch( |
| 352 | + self, texts: List[str], show_progress: bool = True | ||
| 353 | + ) -> BatchSentimentResult: | ||
| 329 | """ | 354 | """ |
| 330 | 批量情感分析 | 355 | 批量情感分析 |
| 331 | - | 356 | + |
| 332 | Args: | 357 | Args: |
| 333 | texts: 文本列表 | 358 | texts: 文本列表 |
| 334 | show_progress: 是否显示进度 | 359 | show_progress: 是否显示进度 |
| 335 | - | 360 | + |
| 336 | Returns: | 361 | Returns: |
| 337 | BatchSentimentResult对象 | 362 | BatchSentimentResult对象 |
| 338 | """ | 363 | """ |
| @@ -343,9 +368,9 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -343,9 +368,9 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 343 | success_count=0, | 368 | success_count=0, |
| 344 | failed_count=0, | 369 | failed_count=0, |
| 345 | average_confidence=0.0, | 370 | average_confidence=0.0, |
| 346 | - analysis_performed=not self.is_disabled and self.is_initialized | 371 | + analysis_performed=not self.is_disabled and self.is_initialized, |
| 347 | ) | 372 | ) |
| 348 | - | 373 | + |
| 349 | if self.is_disabled or not self.is_initialized: | 374 | if self.is_disabled or not self.is_initialized: |
| 350 | passthrough_results = [ | 375 | passthrough_results = [ |
| 351 | SentimentResult( | 376 | SentimentResult( |
| @@ -355,7 +380,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -355,7 +380,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 355 | probability_distribution={}, | 380 | probability_distribution={}, |
| 356 | success=False, | 381 | success=False, |
| 357 | error_message=self.disable_reason or "情感分析功能不可用", | 382 | error_message=self.disable_reason or "情感分析功能不可用", |
| 358 | - analysis_performed=False | 383 | + analysis_performed=False, |
| 359 | ) | 384 | ) |
| 360 | for text in texts | 385 | for text in texts |
| 361 | ] | 386 | ] |
| @@ -365,42 +390,44 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -365,42 +390,44 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 365 | success_count=0, | 390 | success_count=0, |
| 366 | failed_count=len(texts), | 391 | failed_count=len(texts), |
| 367 | average_confidence=0.0, | 392 | average_confidence=0.0, |
| 368 | - analysis_performed=False | 393 | + analysis_performed=False, |
| 369 | ) | 394 | ) |
| 370 | - | 395 | + |
| 371 | results = [] | 396 | results = [] |
| 372 | success_count = 0 | 397 | success_count = 0 |
| 373 | total_confidence = 0.0 | 398 | total_confidence = 0.0 |
| 374 | - | 399 | + |
| 375 | for i, text in enumerate(texts): | 400 | for i, text in enumerate(texts): |
| 376 | if show_progress and len(texts) > 1: | 401 | if show_progress and len(texts) > 1: |
| 377 | - print(f"处理进度: {i+1}/{len(texts)}") | ||
| 378 | - | 402 | + print(f"处理进度: {i + 1}/{len(texts)}") |
| 403 | + | ||
| 379 | result = self.analyze_single_text(text) | 404 | result = self.analyze_single_text(text) |
| 380 | results.append(result) | 405 | results.append(result) |
| 381 | - | 406 | + |
| 382 | if result.success: | 407 | if result.success: |
| 383 | success_count += 1 | 408 | success_count += 1 |
| 384 | total_confidence += result.confidence | 409 | total_confidence += result.confidence |
| 385 | - | ||
| 386 | - average_confidence = total_confidence / success_count if success_count > 0 else 0.0 | 410 | + |
| 411 | + average_confidence = ( | ||
| 412 | + total_confidence / success_count if success_count > 0 else 0.0 | ||
| 413 | + ) | ||
| 387 | failed_count = len(texts) - success_count | 414 | failed_count = len(texts) - success_count |
| 388 | - | 415 | + |
| 389 | return BatchSentimentResult( | 416 | return BatchSentimentResult( |
| 390 | results=results, | 417 | results=results, |
| 391 | total_processed=len(texts), | 418 | total_processed=len(texts), |
| 392 | success_count=success_count, | 419 | success_count=success_count, |
| 393 | failed_count=failed_count, | 420 | failed_count=failed_count, |
| 394 | average_confidence=average_confidence, | 421 | average_confidence=average_confidence, |
| 395 | - analysis_performed=True | 422 | + analysis_performed=True, |
| 396 | ) | 423 | ) |
| 397 | - | 424 | + |
| 398 | def _build_passthrough_analysis( | 425 | def _build_passthrough_analysis( |
| 399 | self, | 426 | self, |
| 400 | original_data: List[Dict[str, Any]], | 427 | original_data: List[Dict[str, Any]], |
| 401 | reason: str, | 428 | reason: str, |
| 402 | texts: Optional[List[str]] = None, | 429 | texts: Optional[List[str]] = None, |
| 403 | - results: Optional[List[SentimentResult]] = None | 430 | + results: Optional[List[SentimentResult]] = None, |
| 404 | ) -> Dict[str, Any]: | 431 | ) -> Dict[str, Any]: |
| 405 | """ | 432 | """ |
| 406 | 构建在情感分析不可用时的透传结果 | 433 | 构建在情感分析不可用时的透传结果 |
| @@ -416,33 +443,36 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -416,33 +443,36 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 416 | "sentiment_distribution": {}, | 443 | "sentiment_distribution": {}, |
| 417 | "high_confidence_results": [], | 444 | "high_confidence_results": [], |
| 418 | "summary": f"情感分析未执行:{reason}", | 445 | "summary": f"情感分析未执行:{reason}", |
| 419 | - "original_texts": original_data | 446 | + "original_texts": original_data, |
| 420 | } | 447 | } |
| 421 | } | 448 | } |
| 422 | - | 449 | + |
| 423 | if texts is not None: | 450 | if texts is not None: |
| 424 | response["sentiment_analysis"]["passthrough_texts"] = texts | 451 | response["sentiment_analysis"]["passthrough_texts"] = texts |
| 425 | - | 452 | + |
| 426 | if results is not None: | 453 | if results is not None: |
| 427 | response["sentiment_analysis"]["results"] = [ | 454 | response["sentiment_analysis"]["results"] = [ |
| 428 | result.__dict__ if isinstance(result, SentimentResult) else result | 455 | result.__dict__ if isinstance(result, SentimentResult) else result |
| 429 | for result in results | 456 | for result in results |
| 430 | ] | 457 | ] |
| 431 | - | 458 | + |
| 432 | return response | 459 | return response |
| 433 | - | ||
| 434 | - def analyze_query_results(self, query_results: List[Dict[str, Any]], | ||
| 435 | - text_field: str = "content", | ||
| 436 | - min_confidence: float = 0.5) -> Dict[str, Any]: | 460 | + |
| 461 | + def analyze_query_results( | ||
| 462 | + self, | ||
| 463 | + query_results: List[Dict[str, Any]], | ||
| 464 | + text_field: str = "content", | ||
| 465 | + min_confidence: float = 0.5, | ||
| 466 | + ) -> Dict[str, Any]: | ||
| 437 | """ | 467 | """ |
| 438 | 对查询结果进行情感分析 | 468 | 对查询结果进行情感分析 |
| 439 | 专门用于分析从MediaCrawlerDB返回的查询结果 | 469 | 专门用于分析从MediaCrawlerDB返回的查询结果 |
| 440 | - | 470 | + |
| 441 | Args: | 471 | Args: |
| 442 | query_results: 查询结果列表,每个元素包含文本内容 | 472 | query_results: 查询结果列表,每个元素包含文本内容 |
| 443 | text_field: 文本内容字段名,默认为"content" | 473 | text_field: 文本内容字段名,默认为"content" |
| 444 | min_confidence: 最小置信度阈值 | 474 | min_confidence: 最小置信度阈值 |
| 445 | - | 475 | + |
| 446 | Returns: | 476 | Returns: |
| 447 | 包含情感分析结果的字典 | 477 | 包含情感分析结果的字典 |
| 448 | """ | 478 | """ |
| @@ -452,14 +482,14 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -452,14 +482,14 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 452 | "total_analyzed": 0, | 482 | "total_analyzed": 0, |
| 453 | "sentiment_distribution": {}, | 483 | "sentiment_distribution": {}, |
| 454 | "high_confidence_results": [], | 484 | "high_confidence_results": [], |
| 455 | - "summary": "没有内容需要分析" | 485 | + "summary": "没有内容需要分析", |
| 456 | } | 486 | } |
| 457 | } | 487 | } |
| 458 | - | 488 | + |
| 459 | # 提取文本内容 | 489 | # 提取文本内容 |
| 460 | texts_to_analyze = [] | 490 | texts_to_analyze = [] |
| 461 | original_data = [] | 491 | original_data = [] |
| 462 | - | 492 | + |
| 463 | for item in query_results: | 493 | for item in query_results: |
| 464 | # 尝试多个可能的文本字段 | 494 | # 尝试多个可能的文本字段 |
| 465 | text_content = "" | 495 | text_content = "" |
| @@ -467,49 +497,52 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -467,49 +497,52 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 467 | if field in item and item[field]: | 497 | if field in item and item[field]: |
| 468 | text_content = str(item[field]) | 498 | text_content = str(item[field]) |
| 469 | break | 499 | break |
| 470 | - | 500 | + |
| 471 | if text_content.strip(): | 501 | if text_content.strip(): |
| 472 | texts_to_analyze.append(text_content) | 502 | texts_to_analyze.append(text_content) |
| 473 | original_data.append(item) | 503 | original_data.append(item) |
| 474 | - | 504 | + |
| 475 | if not texts_to_analyze: | 505 | if not texts_to_analyze: |
| 476 | return { | 506 | return { |
| 477 | "sentiment_analysis": { | 507 | "sentiment_analysis": { |
| 478 | "total_analyzed": 0, | 508 | "total_analyzed": 0, |
| 479 | "sentiment_distribution": {}, | 509 | "sentiment_distribution": {}, |
| 480 | "high_confidence_results": [], | 510 | "high_confidence_results": [], |
| 481 | - "summary": "查询结果中没有找到可分析的文本内容" | 511 | + "summary": "查询结果中没有找到可分析的文本内容", |
| 482 | } | 512 | } |
| 483 | } | 513 | } |
| 484 | - | 514 | + |
| 485 | if self.is_disabled: | 515 | if self.is_disabled: |
| 486 | return self._build_passthrough_analysis( | 516 | return self._build_passthrough_analysis( |
| 487 | original_data=original_data, | 517 | original_data=original_data, |
| 488 | reason=self.disable_reason or "情感分析模型不可用", | 518 | reason=self.disable_reason or "情感分析模型不可用", |
| 489 | - texts=texts_to_analyze | 519 | + texts=texts_to_analyze, |
| 490 | ) | 520 | ) |
| 491 | - | 521 | + |
| 492 | # 执行批量情感分析 | 522 | # 执行批量情感分析 |
| 493 | print(f"正在对{len(texts_to_analyze)}条内容进行情感分析...") | 523 | print(f"正在对{len(texts_to_analyze)}条内容进行情感分析...") |
| 494 | batch_result = self.analyze_batch(texts_to_analyze, show_progress=True) | 524 | batch_result = self.analyze_batch(texts_to_analyze, show_progress=True) |
| 495 | - | 525 | + |
| 496 | if not batch_result.analysis_performed: | 526 | if not batch_result.analysis_performed: |
| 497 | reason = self.disable_reason or "情感分析功能不可用" | 527 | reason = self.disable_reason or "情感分析功能不可用" |
| 498 | if batch_result.results: | 528 | if batch_result.results: |
| 499 | - candidate_error = next((r.error_message for r in batch_result.results if r.error_message), None) | 529 | + candidate_error = next( |
| 530 | + (r.error_message for r in batch_result.results if r.error_message), | ||
| 531 | + None, | ||
| 532 | + ) | ||
| 500 | if candidate_error: | 533 | if candidate_error: |
| 501 | reason = candidate_error | 534 | reason = candidate_error |
| 502 | return self._build_passthrough_analysis( | 535 | return self._build_passthrough_analysis( |
| 503 | original_data=original_data, | 536 | original_data=original_data, |
| 504 | reason=reason, | 537 | reason=reason, |
| 505 | texts=texts_to_analyze, | 538 | texts=texts_to_analyze, |
| 506 | - results=batch_result.results | 539 | + results=batch_result.results, |
| 507 | ) | 540 | ) |
| 508 | - | 541 | + |
| 509 | # 统计情感分布 | 542 | # 统计情感分布 |
| 510 | sentiment_distribution = {} | 543 | sentiment_distribution = {} |
| 511 | high_confidence_results = [] | 544 | high_confidence_results = [] |
| 512 | - | 545 | + |
| 513 | for result, original_item in zip(batch_result.results, original_data): | 546 | for result, original_item in zip(batch_result.results, original_data): |
| 514 | if result.success: | 547 | if result.success: |
| 515 | # 统计情感分布 | 548 | # 统计情感分布 |
| @@ -517,24 +550,28 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -517,24 +550,28 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 517 | if sentiment not in sentiment_distribution: | 550 | if sentiment not in sentiment_distribution: |
| 518 | sentiment_distribution[sentiment] = 0 | 551 | sentiment_distribution[sentiment] = 0 |
| 519 | sentiment_distribution[sentiment] += 1 | 552 | sentiment_distribution[sentiment] += 1 |
| 520 | - | 553 | + |
| 521 | # 收集高置信度结果 | 554 | # 收集高置信度结果 |
| 522 | if result.confidence >= min_confidence: | 555 | if result.confidence >= min_confidence: |
| 523 | - high_confidence_results.append({ | ||
| 524 | - "original_data": original_item, | ||
| 525 | - "sentiment": result.sentiment_label, | ||
| 526 | - "confidence": result.confidence, | ||
| 527 | - "text_preview": result.text[:100] + "..." if len(result.text) > 100 else result.text | ||
| 528 | - }) | ||
| 529 | - | 556 | + high_confidence_results.append( |
| 557 | + { | ||
| 558 | + "original_data": original_item, | ||
| 559 | + "sentiment": result.sentiment_label, | ||
| 560 | + "confidence": result.confidence, | ||
| 561 | + "text_preview": result.text[:100] + "..." | ||
| 562 | + if len(result.text) > 100 | ||
| 563 | + else result.text, | ||
| 564 | + } | ||
| 565 | + ) | ||
| 566 | + | ||
| 530 | # 生成情感分析摘要 | 567 | # 生成情感分析摘要 |
| 531 | total_analyzed = batch_result.success_count | 568 | total_analyzed = batch_result.success_count |
| 532 | if total_analyzed > 0: | 569 | if total_analyzed > 0: |
| 533 | dominant_sentiment = max(sentiment_distribution.items(), key=lambda x: x[1]) | 570 | dominant_sentiment = max(sentiment_distribution.items(), key=lambda x: x[1]) |
| 534 | - sentiment_summary = f"共分析{total_analyzed}条内容,主要情感倾向为'{dominant_sentiment[0]}'({dominant_sentiment[1]}条,占{dominant_sentiment[1]/total_analyzed*100:.1f}%)" | 571 | + sentiment_summary = f"共分析{total_analyzed}条内容,主要情感倾向为'{dominant_sentiment[0]}'({dominant_sentiment[1]}条,占{dominant_sentiment[1] / total_analyzed * 100:.1f}%)" |
| 535 | else: | 572 | else: |
| 536 | sentiment_summary = "情感分析失败" | 573 | sentiment_summary = "情感分析失败" |
| 537 | - | 574 | + |
| 538 | return { | 575 | return { |
| 539 | "sentiment_analysis": { | 576 | "sentiment_analysis": { |
| 540 | "total_analyzed": total_analyzed, | 577 | "total_analyzed": total_analyzed, |
| @@ -542,28 +579,46 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -542,28 +579,46 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 542 | "average_confidence": round(batch_result.average_confidence, 4), | 579 | "average_confidence": round(batch_result.average_confidence, 4), |
| 543 | "sentiment_distribution": sentiment_distribution, | 580 | "sentiment_distribution": sentiment_distribution, |
| 544 | "high_confidence_results": high_confidence_results, # 返回所有高置信度结果,不做限制 | 581 | "high_confidence_results": high_confidence_results, # 返回所有高置信度结果,不做限制 |
| 545 | - "summary": sentiment_summary | 582 | + "summary": sentiment_summary, |
| 546 | } | 583 | } |
| 547 | } | 584 | } |
| 548 | - | 585 | + |
| 549 | def get_model_info(self) -> Dict[str, Any]: | 586 | def get_model_info(self) -> Dict[str, Any]: |
| 550 | """ | 587 | """ |
| 551 | 获取模型信息 | 588 | 获取模型信息 |
| 552 | - | 589 | + |
| 553 | Returns: | 590 | Returns: |
| 554 | 模型信息字典 | 591 | 模型信息字典 |
| 555 | """ | 592 | """ |
| 556 | return { | 593 | return { |
| 557 | "model_name": "tabularisai/multilingual-sentiment-analysis", | 594 | "model_name": "tabularisai/multilingual-sentiment-analysis", |
| 558 | "supported_languages": [ | 595 | "supported_languages": [ |
| 559 | - "中文", "英文", "西班牙文", "阿拉伯文", "日文", "韩文", | ||
| 560 | - "德文", "法文", "意大利文", "葡萄牙文", "俄文", "荷兰文", | ||
| 561 | - "波兰文", "土耳其文", "丹麦文", "希腊文", "芬兰文", | ||
| 562 | - "瑞典文", "挪威文", "匈牙利文", "捷克文", "保加利亚文" | 596 | + "中文", |
| 597 | + "英文", | ||
| 598 | + "西班牙文", | ||
| 599 | + "阿拉伯文", | ||
| 600 | + "日文", | ||
| 601 | + "韩文", | ||
| 602 | + "德文", | ||
| 603 | + "法文", | ||
| 604 | + "意大利文", | ||
| 605 | + "葡萄牙文", | ||
| 606 | + "俄文", | ||
| 607 | + "荷兰文", | ||
| 608 | + "波兰文", | ||
| 609 | + "土耳其文", | ||
| 610 | + "丹麦文", | ||
| 611 | + "希腊文", | ||
| 612 | + "芬兰文", | ||
| 613 | + "瑞典文", | ||
| 614 | + "挪威文", | ||
| 615 | + "匈牙利文", | ||
| 616 | + "捷克文", | ||
| 617 | + "保加利亚文", | ||
| 563 | ], | 618 | ], |
| 564 | "sentiment_levels": list(self.sentiment_map.values()), | 619 | "sentiment_levels": list(self.sentiment_map.values()), |
| 565 | "is_initialized": self.is_initialized, | 620 | "is_initialized": self.is_initialized, |
| 566 | - "device": str(self.device) if self.device else "未设置" | 621 | + "device": str(self.device) if self.device else "未设置", |
| 567 | } | 622 | } |
| 568 | 623 | ||
| 569 | 624 | ||
| @@ -576,20 +631,23 @@ def enable_sentiment_analysis() -> bool: | @@ -576,20 +631,23 @@ def enable_sentiment_analysis() -> bool: | ||
| 576 | return multilingual_sentiment_analyzer.enable() | 631 | return multilingual_sentiment_analyzer.enable() |
| 577 | 632 | ||
| 578 | 633 | ||
| 579 | -def disable_sentiment_analysis(reason: Optional[str] = None, drop_state: bool = False) -> None: | 634 | +def disable_sentiment_analysis( |
| 635 | + reason: Optional[str] = None, drop_state: bool = False | ||
| 636 | +) -> None: | ||
| 580 | """Public helper to disable sentiment analysis at runtime.""" | 637 | """Public helper to disable sentiment analysis at runtime.""" |
| 581 | multilingual_sentiment_analyzer.disable(reason=reason, drop_state=drop_state) | 638 | multilingual_sentiment_analyzer.disable(reason=reason, drop_state=drop_state) |
| 582 | 639 | ||
| 583 | 640 | ||
| 584 | -def analyze_sentiment(text_or_texts: Union[str, List[str]], | ||
| 585 | - initialize_if_needed: bool = True) -> Union[SentimentResult, BatchSentimentResult]: | 641 | +def analyze_sentiment( |
| 642 | + text_or_texts: Union[str, List[str]], initialize_if_needed: bool = True | ||
| 643 | +) -> Union[SentimentResult, BatchSentimentResult]: | ||
| 586 | """ | 644 | """ |
| 587 | 便捷的情感分析函数 | 645 | 便捷的情感分析函数 |
| 588 | - | 646 | + |
| 589 | Args: | 647 | Args: |
| 590 | text_or_texts: 单个文本或文本列表 | 648 | text_or_texts: 单个文本或文本列表 |
| 591 | initialize_if_needed: 如果模型未初始化,是否自动初始化 | 649 | initialize_if_needed: 如果模型未初始化,是否自动初始化 |
| 592 | - | 650 | + |
| 593 | Returns: | 651 | Returns: |
| 594 | SentimentResult或BatchSentimentResult | 652 | SentimentResult或BatchSentimentResult |
| 595 | """ | 653 | """ |
| @@ -599,7 +657,7 @@ def analyze_sentiment(text_or_texts: Union[str, List[str]], | @@ -599,7 +657,7 @@ def analyze_sentiment(text_or_texts: Union[str, List[str]], | ||
| 599 | and not multilingual_sentiment_analyzer.is_disabled | 657 | and not multilingual_sentiment_analyzer.is_disabled |
| 600 | ): | 658 | ): |
| 601 | multilingual_sentiment_analyzer.initialize() | 659 | multilingual_sentiment_analyzer.initialize() |
| 602 | - | 660 | + |
| 603 | if isinstance(text_or_texts, str): | 661 | if isinstance(text_or_texts, str): |
| 604 | return multilingual_sentiment_analyzer.analyze_single_text(text_or_texts) | 662 | return multilingual_sentiment_analyzer.analyze_single_text(text_or_texts) |
| 605 | else: | 663 | else: |
| @@ -610,24 +668,30 @@ def analyze_sentiment(text_or_texts: Union[str, List[str]], | @@ -610,24 +668,30 @@ def analyze_sentiment(text_or_texts: Union[str, List[str]], | ||
| 610 | if __name__ == "__main__": | 668 | if __name__ == "__main__": |
| 611 | # 测试代码 | 669 | # 测试代码 |
| 612 | analyzer = WeiboMultilingualSentimentAnalyzer() | 670 | analyzer = WeiboMultilingualSentimentAnalyzer() |
| 613 | - | 671 | + |
| 614 | if analyzer.initialize(): | 672 | if analyzer.initialize(): |
| 615 | # 测试单个文本 | 673 | # 测试单个文本 |
| 616 | result = analyzer.analyze_single_text("今天天气真好,心情特别棒!") | 674 | result = analyzer.analyze_single_text("今天天气真好,心情特别棒!") |
| 617 | - print(f"单个文本分析: {result.sentiment_label} (置信度: {result.confidence:.4f})") | ||
| 618 | - | 675 | + print( |
| 676 | + f"单个文本分析: {result.sentiment_label} (置信度: {result.confidence:.4f})" | ||
| 677 | + ) | ||
| 678 | + | ||
| 619 | # 测试批量文本 | 679 | # 测试批量文本 |
| 620 | test_texts = [ | 680 | test_texts = [ |
| 621 | "这家餐厅的菜味道非常棒!", | 681 | "这家餐厅的菜味道非常棒!", |
| 622 | "服务态度太差了,很失望", | 682 | "服务态度太差了,很失望", |
| 623 | "I absolutely love this product!", | 683 | "I absolutely love this product!", |
| 624 | - "The customer service was disappointing." | 684 | + "The customer service was disappointing.", |
| 625 | ] | 685 | ] |
| 626 | - | 686 | + |
| 627 | batch_result = analyzer.analyze_batch(test_texts) | 687 | batch_result = analyzer.analyze_batch(test_texts) |
| 628 | - print(f"\n批量分析: 成功 {batch_result.success_count}/{batch_result.total_processed}") | ||
| 629 | - | 688 | + print( |
| 689 | + f"\n批量分析: 成功 {batch_result.success_count}/{batch_result.total_processed}" | ||
| 690 | + ) | ||
| 691 | + | ||
| 630 | for result in batch_result.results: | 692 | for result in batch_result.results: |
| 631 | - print(f"'{result.text[:30]}...' -> {result.sentiment_label} ({result.confidence:.4f})") | 693 | + print( |
| 694 | + f"'{result.text[:30]}...' -> {result.sentiment_label} ({result.confidence:.4f})" | ||
| 695 | + ) | ||
| 632 | else: | 696 | else: |
| 633 | print("模型初始化失败,无法进行测试") | 697 | print("模型初始化失败,无法进行测试") |
-
Please register or login to post a comment