Committed by
666ghj
fix(sentiment_analyzer): fix type warning from pyright
Showing
1 changed file
with
7 additions
and
2 deletions
| @@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 142 | """Select the best available torch device.""" | 142 | """Select the best available torch device.""" |
| 143 | if not TORCH_AVAILABLE: | 143 | if not TORCH_AVAILABLE: |
| 144 | return None | 144 | return None |
| 145 | + assert torch is not None | ||
| 145 | if torch.cuda.is_available(): | 146 | if torch.cuda.is_available(): |
| 146 | return torch.device("cuda") | 147 | return torch.device("cuda") |
| 147 | mps_backend = getattr(torch.backends, "mps", None) | 148 | mps_backend = getattr(torch.backends, "mps", None) |
| @@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 177 | 178 | ||
| 178 | try: | 179 | try: |
| 179 | print("正在加载多语言情感分析模型...") | 180 | print("正在加载多语言情感分析模型...") |
| 181 | + assert AutoTokenizer is not None | ||
| 182 | + assert AutoModelForSequenceClassification is not None | ||
| 180 | 183 | ||
| 181 | # 使用多语言情感分析模型 | 184 | # 使用多语言情感分析模型 |
| 182 | model_name = "tabularisai/multilingual-sentiment-analysis" | 185 | model_name = "tabularisai/multilingual-sentiment-analysis" |
| @@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 300 | error_message="输入文本为空或无效内容", | 303 | error_message="输入文本为空或无效内容", |
| 301 | analysis_performed=False, | 304 | analysis_performed=False, |
| 302 | ) | 305 | ) |
| 303 | - | 306 | + assert self.tokenizer is not None |
| 304 | # 分词编码 | 307 | # 分词编码 |
| 305 | inputs = self.tokenizer( | 308 | inputs = self.tokenizer( |
| 306 | processed_text, | 309 | processed_text, |
| @@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer: | @@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer: | ||
| 314 | inputs = {k: v.to(self.device) for k, v in inputs.items()} | 317 | inputs = {k: v.to(self.device) for k, v in inputs.items()} |
| 315 | 318 | ||
| 316 | # 预测 | 319 | # 预测 |
| 320 | + assert torch is not None | ||
| 321 | + assert self.model is not None | ||
| 317 | with torch.no_grad(): | 322 | with torch.no_grad(): |
| 318 | outputs = self.model(**inputs) | 323 | outputs = self.model(**inputs) |
| 319 | logits = outputs.logits | 324 | logits = outputs.logits |
| 320 | probabilities = torch.softmax(logits, dim=1) | 325 | probabilities = torch.softmax(logits, dim=1) |
| 321 | - prediction = torch.argmax(probabilities, dim=1).item() | 326 | + prediction = int(torch.argmax(probabilities, dim=1).item()) |
| 322 | 327 | ||
| 323 | # 构建结果 | 328 | # 构建结果 |
| 324 | confidence = probabilities[0][prediction].item() | 329 | confidence = probabilities[0][prediction].item() |
-
Please register or login to post a comment