ghmark675
Committed by 666ghj

fix(sentiment_analyzer): fix type warning from pyright

@@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer: @@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer:
142 """Select the best available torch device.""" 142 """Select the best available torch device."""
143 if not TORCH_AVAILABLE: 143 if not TORCH_AVAILABLE:
144 return None 144 return None
  145 + assert torch is not None
145 if torch.cuda.is_available(): 146 if torch.cuda.is_available():
146 return torch.device("cuda") 147 return torch.device("cuda")
147 mps_backend = getattr(torch.backends, "mps", None) 148 mps_backend = getattr(torch.backends, "mps", None)
@@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer: @@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer:
177 178
178 try: 179 try:
179 print("正在加载多语言情感分析模型...") 180 print("正在加载多语言情感分析模型...")
  181 + assert AutoTokenizer is not None
  182 + assert AutoModelForSequenceClassification is not None
180 183
181 # 使用多语言情感分析模型 184 # 使用多语言情感分析模型
182 model_name = "tabularisai/multilingual-sentiment-analysis" 185 model_name = "tabularisai/multilingual-sentiment-analysis"
@@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer: @@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer:
300 error_message="输入文本为空或无效内容", 303 error_message="输入文本为空或无效内容",
301 analysis_performed=False, 304 analysis_performed=False,
302 ) 305 )
303 - 306 + assert self.tokenizer is not None
304 # 分词编码 307 # 分词编码
305 inputs = self.tokenizer( 308 inputs = self.tokenizer(
306 processed_text, 309 processed_text,
@@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer: @@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer:
314 inputs = {k: v.to(self.device) for k, v in inputs.items()} 317 inputs = {k: v.to(self.device) for k, v in inputs.items()}
315 318
316 # 预测 319 # 预测
  320 + assert torch is not None
  321 + assert self.model is not None
317 with torch.no_grad(): 322 with torch.no_grad():
318 outputs = self.model(**inputs) 323 outputs = self.model(**inputs)
319 logits = outputs.logits 324 logits = outputs.logits
320 probabilities = torch.softmax(logits, dim=1) 325 probabilities = torch.softmax(logits, dim=1)
321 - prediction = torch.argmax(probabilities, dim=1).item() 326 + prediction = int(torch.argmax(probabilities, dim=1).item())
322 327
323 # 构建结果 328 # 构建结果
324 confidence = probabilities[0][prediction].item() 329 confidence = probabilities[0][prediction].item()