# predict.py
import argparse
import json
import os
import re
import sys
from typing import Dict, Optional, Tuple
# ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
def _extract_gpu_arg(argv, default: str = "0") -> str:
for i, arg in enumerate(argv):
if arg.startswith("--gpu="):
return arg.split("=", 1)[1]
if arg == "--gpu" and i + 1 < len(argv):
return argv[i + 1]
return default
# Pin the process to a single GPU: if CUDA_VISIBLE_DEVICES is unset or lists
# several devices, override it with the --gpu value (parsed from raw argv,
# since argparse has not run yet at this point).
env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
try:
    gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
except Exception:
    # Best-effort: any argv-parsing problem falls back to GPU 0.
    gpu_to_use = "0"
if (not env_vis) or ("," in env_vis):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
# Make device numbering match nvidia-smi (PCI bus order).
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
# Remove distributed-training variables so frameworks do not try to
# initialize a multi-process (DDP) run.
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
    os.environ.pop(_k, None)
import torch
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForSequenceClassification,
)
def preprocess_text(text: str) -> str:
    """Normalize one raw text sample before tokenization.

    Removes template tags, @-mentions, CJK bracketed annotations, zero-width
    spaces and emoji, then collapses runs of whitespace. ``None`` maps to "".
    """
    if text is None:
        return ""
    cleaned = str(text)
    # (pattern, replacement) pairs applied in this exact order.
    substitutions = (
        (r"\{%.+?%\}", " "),   # template tags like {% ... %}
        (r"@.+?( |$)", " "),   # @-mentions up to the next space or end of line
        (r"【.+?】", " "),      # CJK bracketed annotations
        (r"\u200b", " "),      # zero-width space
        (
            r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
            "",
        ),                     # emoji / pictograph ranges
        (r"\s+", " "),         # collapse whitespace
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
    """Guarantee a local copy of bert-base-chinese under *local_model_root*.

    Resolution order: existing local snapshot -> machine's HF cache (offline)
    -> remote download. Returns ``(local_dir, tokenizer)``.
    """
    os.makedirs(local_model_root, exist_ok=True)
    base_dir = os.path.join(local_model_root, "bert-base-chinese")

    def _snapshot_ready(path: str) -> bool:
        # A directory containing config.json is treated as a usable snapshot.
        return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))

    if _snapshot_ready(base_dir):
        return base_dir, AutoTokenizer.from_pretrained(base_dir)

    def _fetch_and_persist(local_only: bool) -> Tuple[str, AutoTokenizer]:
        # Load tokenizer + base weights, then persist both into base_dir.
        tok = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=local_only)
        net = AutoModel.from_pretrained(model_name_or_path, local_files_only=local_only)
        os.makedirs(base_dir, exist_ok=True)
        tok.save_pretrained(base_dir)
        net.save_pretrained(base_dir)
        return base_dir, tok

    # Try the machine's HF cache first, without touching the network.
    try:
        return _fetch_and_persist(local_only=True)
    except Exception:
        pass
    # Fall back to downloading from the hub.
    return _fetch_and_persist(local_only=False)
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="使用本地/缓存/远程加载的中文 BERT 分类模型进行预测")
    add = parser.add_argument
    add("--model_root", type=str, default="./model", help="本地模型根目录")
    add("--finetuned_subdir", type=str, default="bert-chinese-classifier", help="微调结果子目录")
    add("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="预训练模型名称或路径")
    add("--text", type=str, default=None, help="直接输入一条要预测的文本")
    add("--interactive", action="store_true", help="进入交互式预测模式")
    add("--max_length", type=int, default=128)
    # Default mirrors the visibility already pinned at import time.
    add("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
    return parser.parse_args()
def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Optional[Dict[int, str]]]:
    """Locate the fine-tuned model directory and its optional label map.

    Args:
        model_root: Root directory holding model subdirectories.
        subdir: Name of the fine-tuned model subdirectory.

    Returns:
        ``(finetuned_path, id2label)`` where ``id2label`` maps class index ->
        label name, or ``None`` when no ``label_map.json`` file is present.
        (The original annotation claimed a plain ``Dict`` despite the
        ``None`` path — fixed to ``Optional``.)

    Raises:
        FileNotFoundError: if the fine-tuned directory does not exist.
    """
    finetuned_path = os.path.join(model_root, subdir)
    if not os.path.isdir(finetuned_path):
        raise FileNotFoundError(
            f"未找到微调模型目录: {finetuned_path},请先运行训练脚本。"
        )
    id2label: Optional[Dict[int, str]] = None
    label_map_path = os.path.join(finetuned_path, "label_map.json")
    if os.path.isfile(label_map_path):
        with open(label_map_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # JSON object keys are strings; normalize to int -> str.
        id2label = {int(k): str(v) for k, v in data.get("id2label", {}).items()}
    return finetuned_path, id2label
# Cache of (tokenizer, model, device) keyed by model directory so that
# interactive mode does not reload the full BERT weights from disk for
# every single prediction (the original reloaded on each call).
_PREDICT_CACHE: Dict[str, Tuple] = {}


def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
    """Classify one text with the fine-tuned model stored in *model_dir*.

    Args:
        model_dir: Directory containing the fine-tuned sequence-classification
            model (as saved by ``save_pretrained``).
        text: Raw input text; preprocessed before tokenization.
        max_length: Tokenizer truncation/padding length.

    Returns:
        ``(label_name, confidence)`` — the predicted label (falling back to
        the stringified class index when the config has no id2label map) and
        its softmax probability.
    """
    cached = _PREDICT_CACHE.get(model_dir)
    if cached is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        model.to(device)
        model.eval()
        cached = (tokenizer, model, device)
        _PREDICT_CACHE[model_dir] = cached
    tokenizer, model, device = cached

    processed = preprocess_text(text)
    encoded = tokenizer(
        processed,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        pred = int(torch.argmax(probs, dim=-1).item())
        conf = float(probs[0, pred].item())

    id2label = getattr(model.config, "id2label", None)
    label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
    return label_name, conf
def main() -> None:
    """Entry point: resolve model paths, then predict from --text or interactively."""
    args = parse_args()

    here = os.path.dirname(os.path.abspath(__file__))
    model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(here, args.model_root)
    os.makedirs(model_root, exist_ok=True)

    # Make sure the base pretrained model exists locally, then resolve the
    # fine-tuned directory (raises if training has not been run yet).
    ensure_base_model_local(args.pretrained_name, model_root)
    finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)

    def _classify_and_print(sample: str) -> None:
        # Run one prediction and report label + confidence.
        label, conf = predict_once(finetuned_dir, sample, args.max_length)
        print(f"预测结果: {label} (置信度: {conf:.4f})")

    if args.text is not None:
        _classify_and_print(args.text)
        return

    if args.interactive:
        print("进入交互模式。输入 'q' 退出。")
        while True:
            try:
                text = input("请输入文本: ").strip()
            except EOFError:
                break
            if text.lower() == "q":
                break
            if not text:
                continue
            _classify_and_print(text)
        return

    print("未提供 --text 或 --interactive,什么也没有发生。")
# Script entry guard: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()