# predict.py
from __future__ import annotations
import argparse
from pathlib import Path
# Hugging Face Hub model id, used only when no local cache is found.
MODEL_NAME = "tabularisai/multilingual-sentiment-analysis"
# Resolve the cache directory relative to this script so behavior does not
# depend on the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
DEFAULT_MODEL_DIR = SCRIPT_DIR / "model"
# Maps the classifier's five output indices to human-readable labels
# (Chinese: very negative / negative / neutral / positive / very positive).
SENTIMENT_LABELS = {
    0: "非常负面",
    1: "负面",
    2: "中性",
    3: "正面",
    4: "非常正面",
}
# (text, language hint) pairs exercised by --demo and the interactive `demo`
# command; covers Chinese, English, Spanish, and Japanese samples.
DEMO_TEXTS = (
    ("今天天气真好,心情特别棒。", "zh"),
    ("服务态度太差了,很失望。", "zh"),
    ("I absolutely love this product.", "en"),
    ("The customer service was disappointing.", "en"),
    ("Este lugar es increible y muy acogedor.", "es"),
    ("El servicio fue terrible y muy lento.", "es"),
    ("このレストランの料理は本当に美味しいです。", "ja"),
    ("このホテルのサービスにはがっかりしました。", "ja"),
)
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface.

    Returns the parsed namespace with attributes: text, lang, demo,
    model_dir, device.
    """
    cli = argparse.ArgumentParser(
        description="Run the multilingual sentiment analysis research demo."
    )
    cli.add_argument(
        "--text",
        help="Single text to analyze. If omitted, the script enters interactive mode.",
    )
    cli.add_argument(
        "--lang",
        default="auto",
        help="Compatibility-only language hint for CLI parity. Default: auto.",
    )
    cli.add_argument(
        "--demo",
        action="store_true",
        help="Run a built-in multilingual demo set and exit.",
    )
    cli.add_argument(
        "--model-dir",
        type=Path,
        default=DEFAULT_MODEL_DIR,
        help="Directory used to cache the Hugging Face model locally.",
    )
    cli.add_argument(
        "--device",
        choices=("auto", "cpu", "cuda"),
        default="auto",
        help="Execution device. Default: auto.",
    )
    return cli.parse_args()
def import_runtime_dependencies():
    """Lazily import the heavy ML libraries.

    Deferring these imports keeps --help and argument errors fast and lets
    the script fail with a clear error only when inference is attempted.

    Returns a (torch, AutoTokenizer, AutoModelForSequenceClassification) tuple.
    """
    import torch
    import transformers

    return (
        torch,
        transformers.AutoTokenizer,
        transformers.AutoModelForSequenceClassification,
    )
def choose_device(device_name: str):
    """Translate a --device value ("auto", "cpu", or "cuda") into a torch.device.

    Raises RuntimeError if "cuda" is requested but no CUDA device is available.
    "auto" picks CUDA when available, otherwise CPU.
    """
    torch = import_runtime_dependencies()[0]
    cuda_ok = torch.cuda.is_available()
    if device_name == "cuda":
        if not cuda_ok:
            raise RuntimeError("CUDA requested but not available.")
        return torch.device("cuda")
    if device_name == "cpu":
        return torch.device("cpu")
    # "auto": prefer the GPU when one is present.
    return torch.device("cuda" if cuda_ok else "cpu")
def has_local_model(model_dir: Path) -> bool:
    """Return True when a cached model checkpoint appears to exist.

    The presence of config.json is used as the marker file, matching what
    save_pretrained writes.
    """
    marker = model_dir / "config.json"
    return marker.exists()
def load_model(model_dir: Path):
    """Load the tokenizer and classifier, preferring the local cache.

    When no cached copy exists under *model_dir*, the model is fetched from
    the Hugging Face Hub (MODEL_NAME) and then persisted to *model_dir* so
    subsequent runs work offline.

    Returns (tokenizer, model, downloaded) where *downloaded* is True only
    when the Hub was consulted on this call.
    """
    _, tokenizer_cls, model_cls = import_runtime_dependencies()
    cache_dir = model_dir.resolve()
    cached = has_local_model(cache_dir)
    source = str(cache_dir) if cached else MODEL_NAME
    tokenizer = tokenizer_cls.from_pretrained(source)
    model = model_cls.from_pretrained(source)
    if not cached:
        # First run: write the freshly downloaded weights to the local cache.
        cache_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(cache_dir)
        model.save_pretrained(cache_dir)
    return tokenizer, model, not cached
def predict_text(tokenizer, model, device, text: str):
    """Classify *text* and return (label, confidence, probabilities).

    *label* is the human-readable name of the arg-max class, *confidence*
    its probability, and *probabilities* the full per-class softmax
    distribution as a plain list of floats.
    """
    torch = import_runtime_dependencies()[0]
    encoded = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Move every input tensor onto the execution device alongside the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    scores = torch.softmax(logits, dim=1)[0].cpu().tolist()
    # Arg-max over the class probabilities (first index wins on ties).
    best_index, best_score = max(enumerate(scores), key=lambda pair: pair[1])
    return SENTIMENT_LABELS[best_index], best_score, scores
def print_prediction(tokenizer, model, device, text: str, language_hint: str = "auto") -> None:
    """Classify one text and print a formatted report to stdout."""
    label, confidence, probabilities = predict_text(tokenizer, model, device, text)
    report = [
        f"文本: {text}",
        f"语言提示: {language_hint}",
        f"预测结果: {label} (置信度: {confidence:.4f})",
        "详细概率分布:",
    ]
    report.extend(
        f" - {SENTIMENT_LABELS[i]}: {p:.4f}" for i, p in enumerate(probabilities)
    )
    print("\n".join(report))
def run_demo(tokenizer, model, device) -> None:
    """Score every built-in demo sentence, separating reports with a blank line."""
    print("运行内置多语言示例...")
    for sample_text, hint in DEMO_TEXTS:
        print_prediction(tokenizer, model, device, sample_text, hint)
        print("")
def interactive_loop(tokenizer, model, device) -> None:
    """Read texts from stdin and classify them until the user quits.

    Commands (matched case-insensitively): `q` exits, `demo` runs the
    built-in sample set. Empty input is rejected with a retry prompt.

    Fix: the original crashed with an unhandled EOFError when stdin was
    closed (e.g. piped input or Ctrl-D) and printed a traceback on Ctrl-C;
    both now exit the loop cleanly.
    """
    print("进入交互模式。输入 `demo` 运行示例,输入 `q` 退出。")
    while True:
        try:
            text = input("请输入文本: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Graceful exit on Ctrl-D / Ctrl-C instead of a traceback.
            print("")
            return
        if text.lower() == "q":
            return
        if not text:
            print("输入不能为空,请重新输入。")
            continue
        if text.lower() == "demo":
            run_demo(tokenizer, model, device)
            continue
        print_prediction(tokenizer, model, device, text)
        print("")
def main() -> None:
    """CLI entry point: load the model, then dispatch to demo, single-text,
    or interactive mode."""
    args = parse_args()
    device = choose_device(args.device)
    tokenizer, model, downloaded = load_model(args.model_dir)
    model.to(device)
    model.eval()
    cache_path = args.model_dir.resolve()
    if downloaded:
        print(f"首次运行已下载模型到: {cache_path}")
    else:
        print(f"已从本地缓存加载模型: {cache_path}")
    print(f"当前执行设备: {device}")
    # Mode dispatch: --demo wins, then --text, otherwise interactive.
    if args.demo:
        run_demo(tokenizer, model, device)
    elif args.text:
        print_prediction(tokenizer, model, device, args.text, args.lang)
    else:
        interactive_loop(tokenizer, model, device)


if __name__ == "__main__":
    main()