ghmark675
Committed by 666ghj

fix(sentiment_analyzer): fix type warning from pyright

... ... @@ -142,6 +142,7 @@ class WeiboMultilingualSentimentAnalyzer:
"""Select the best available torch device."""
if not TORCH_AVAILABLE:
return None
assert torch is not None
if torch.cuda.is_available():
return torch.device("cuda")
mps_backend = getattr(torch.backends, "mps", None)
... ... @@ -177,6 +178,8 @@ class WeiboMultilingualSentimentAnalyzer:
try:
print("正在加载多语言情感分析模型...")
assert AutoTokenizer is not None
assert AutoModelForSequenceClassification is not None
# 使用多语言情感分析模型
model_name = "tabularisai/multilingual-sentiment-analysis"
... ... @@ -300,7 +303,7 @@ class WeiboMultilingualSentimentAnalyzer:
error_message="输入文本为空或无效内容",
analysis_performed=False,
)
assert self.tokenizer is not None
# 分词编码
inputs = self.tokenizer(
processed_text,
... ... @@ -314,11 +317,13 @@ class WeiboMultilingualSentimentAnalyzer:
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# 预测
assert torch is not None
assert self.model is not None
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
prediction = torch.argmax(probabilities, dim=1).item()
prediction = int(torch.argmax(probabilities, dim=1).item())
# 构建结果
confidence = probabilities[0][prediction].item()
... ...