戒酒的李白

Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.
@@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/
 WeiboMultilingualSentiment/model/
 WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/
 WeiboSentiment_SmallQwen/models/
+BertTopicDetection_Finetuned/model/
 # LoRA and Adapter weights
 */adapter_model.safetensors
... ...
@@ -63,6 +63,23 @@ python train.py \
- Supports early stopping (default patience: 5 evaluations) and automatically rolls back to the best model when the evaluation and save strategies match (wiring sketched below);
- The tokenizer, weights, and `label_map.json` are saved to `model/bert-chinese-classifier/`
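A minimal sketch of that early-stopping and rollback wiring, assuming the standard `transformers` Trainer API; the argument values and the `model`/`train_ds`/`eval_ds` names are illustrative, not this repository's exact configuration:

```python
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

# `model`, `train_ds`, and `eval_ds` are assumed to be defined elsewhere.
training_args = TrainingArguments(
    output_dir="model/bert-chinese-classifier",
    evaluation_strategy="steps",       # must match save_strategy for rollback to work
    save_strategy="steps",
    load_best_model_at_end=True,       # roll back to the best checkpoint at the end
    metric_for_best_model="eval_loss",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],  # patience: 5 evaluations
)
```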
### Optional Chinese base models (interactive selection before training)
Default base model: `google-bert/bert-base-chinese`. When training starts in an interactive terminal, the script prompts you to choose from the options below (or to enter any Hugging Face model ID):
1) `google-bert/bert-base-chinese`
2) `hfl/chinese-roberta-wwm-ext-large`
3) `hfl/chinese-macbert-large`
4) `IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese`
5) `IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese`
6) `Langboat/mengzi-bert-base`
7) `BAAI/bge-base-zh` (better suited to retrieval-style/contrastive-learning paradigms)
8) `nghuyong/ernie-3.0-base-zh`
Notes:
- In non-interactive environments (e.g. job schedulers), or when `NON_INTERACTIVE=1` is set, the model specified by the command-line argument `--pretrained_name` is used directly (default: `google-bert/bert-base-chinese`).
- After selection, the base model is downloaded/cached into the `model/` directory for unified management (one plausible shape of this step is sketched below).
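One plausible shape for that download-and-cache step; the actual `ensure_base_model_local` in `train.py` may differ, and the local folder naming here is an assumption:

```python
import os
from transformers import AutoModel, AutoTokenizer

def ensure_base_model_local_sketch(model_id: str, model_root: str):
    # Mirror the Hub ID into a local folder name, e.g. "google-bert--bert-base-chinese".
    local_dir = os.path.join(model_root, model_id.replace("/", "--"))
    if not os.path.isdir(local_dir):
        # First run: pull from the Hub, then persist a copy under model/.
        AutoTokenizer.from_pretrained(model_id).save_pretrained(local_dir)
        AutoModel.from_pretrained(model_id).save_pretrained(local_dir)
    return local_dir, AutoTokenizer.from_pretrained(local_dir)
```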
### Prediction
Single input:
@@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi
- Stable single-GPU operation: only one GPU is used by default, selectable via `--gpu`; the script also clears distributed environment variables (sketched below).
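A minimal sketch of that single-GPU pinning idea; the real `_extract_gpu_arg` in `predict.py` may differ in details, but the key point is that `CUDA_VISIBLE_DEVICES` must be set before `torch` is imported:

```python
import os
import sys

def _extract_gpu_arg(argv, default: str = "0") -> str:
    # Scan argv by hand, because this must run before any heavy imports.
    for i, arg in enumerate(argv):
        if arg == "--gpu" and i + 1 < len(argv):
            return argv[i + 1]
        if arg.startswith("--gpu="):
            return arg.split("=", 1)[1]
    return default

os.environ["CUDA_VISIBLE_DEVICES"] = _extract_gpu_arg(sys.argv)
# Clear distributed-launcher variables so a single-process run stays single-GPU.
for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"):
    os.environ.pop(var, None)
```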
### Author's note (on very-large-scale multi-class classification)
- When the number of topic classes reaches the tens of thousands, attaching a single linear classification head (a large softmax) directly to the encoder tends to hit limits: long-tail classes are hard to learn, semantics become sparse, newly added topics cannot be adapted incrementally, and frequent retraining is needed after deployment.
- Improvement ideas (in recommended order of priority; the first option is sketched after this list):
  - Retrieval-style / two-tower paradigm (contrastive learning between texts and topic names/descriptions) + nearest-neighbor retrieval + a small reranking head, which naturally supports incremental class expansion and fast updates;
  - Hierarchical classification (coarse first, then fine), which markedly reduces the difficulty and compute of a single head;
  - Joint text-label modeling (using label descriptions), which improves transferability across near-synonymous topics;
  - Training details: class-balanced/focal loss, label smoothing, sampled softmax, contrastive pre-training, etc.
- Important statement: the "static classification head fine-tuning" used in this directory is only an alternative and a learning reference. In English/multilingual micro-text scenarios, topics change extremely fast and traditional static classifiers struggle to keep up; our main effort goes into generative/self-supervised topic discovery and dynamic taxonomy construction along the lines of `TopicGPT`. This implementation is meant to provide a runnable baseline and an engineering example.
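As an illustration of the retrieval-style option above, here is a minimal two-tower scoring step with an in-batch contrastive (InfoNCE-style) loss; all names are illustrative, and none of this code exists in the repository:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(text_emb: torch.Tensor, topic_emb: torch.Tensor,
                  temperature: float = 0.05) -> torch.Tensor:
    # Row i of text_emb should match row i of topic_emb (its topic description).
    text_emb = F.normalize(text_emb, dim=-1)
    topic_emb = F.normalize(topic_emb, dim=-1)
    logits = text_emb @ topic_emb.T / temperature  # (B, B) cosine-similarity logits
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)

# Serving: encode each topic description once, index the vectors, and answer
# queries by nearest-neighbor search; new topics are added by indexing their
# descriptions, so no softmax head needs retraining.
```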
... ...
@@ -3,7 +3,7 @@ import sys
 import json
 import re
 import argparse
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List

 # ========== Single-GPU lock (runs before torch/transformers are imported) ==========
 def _extract_gpu_arg(argv, default: str = "0") -> str:
@@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
     return finetuned_path, id2label

-def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    tokenizer = AutoTokenizer.from_pretrained(model_dir)
-    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
-    model.to(device)
-    model.eval()
-    processed = preprocess_text(text)
+def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]:
+    processed = preprocess_text(text or "")
     encoded = tokenizer(
         processed,
         max_length=max_length,
@@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str,
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
         logits = outputs.logits
-    probs = torch.softmax(logits, dim=-1)
-    pred = int(torch.argmax(probs, dim=-1).item())
-    conf = float(probs[0, pred].item())
-    id2label = getattr(model.config, "id2label", None)
-    label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
-    return label_name, conf
+    probs = torch.softmax(logits, dim=-1)[0]
+    k = min(top_k, probs.shape[-1])
+    confs, idxs = torch.topk(probs, k)
+    id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {}
+    results: List[Tuple[str, float]] = []
+    for i in range(k):
+        idx = int(idxs[i].item())
+        conf = float(confs[i].item())
+        label_name = id2label.get(idx, str(idx))
+        results.append((label_name, conf))
+    return results

 def main() -> None:
@@ -149,13 +148,21 @@ def main() -> None:
     ensure_base_model_local(args.pretrained_name, model_root)
     finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    tokenizer = AutoTokenizer.from_pretrained(finetuned_dir)
+    model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir)
+    model.to(device)
+    model.eval()
     if args.text is not None:
-        label, conf = predict_once(finetuned_dir, args.text, args.max_length)
-        print(f"Prediction: {label} (confidence: {conf:.4f})")
+        topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3)
+        print("Top-3 predictions:")
+        for rank, (label, conf) in enumerate(topk, 1):
+            print(f"{rank}. {label} (p={conf:.4f})")
         return
-    if args.interactive:
+    # Enter interactive mode by default (no explicit --text, and interaction not explicitly disabled)
+    if args.interactive or (args.text is None):
         print("Entering interactive mode. Enter 'q' to quit.")
         while True:
             try:
@@ -166,11 +173,13 @@ def main() -> None:
                 break
             if not text:
                 continue
-            label, conf = predict_once(finetuned_dir, text, args.max_length)
-            print(f"Prediction: {label} (confidence: {conf:.4f})")
+            topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3)
+            print("Top-3 predictions:")
+            for rank, (label, conf) in enumerate(topk, 1):
+                print(f"{rank}. {label} (p={conf:.4f})")
         return
-    print("No --text or --interactive was provided; nothing happened.")
+    # In theory this point is never reached
+    print("No input provided.")

 if __name__ == "__main__":
... ...
@@ -52,6 +52,52 @@ try:
 except Exception: # pragma: no cover
     EarlyStoppingCallback = None # type: ignore

+# Preset optional Chinese base models (extensible)
+BACKBONE_CANDIDATES: List[Tuple[str, str]] = [
+    ("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"),
+    ("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"),
+    ("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"),
+    ("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"),
+    ("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"),
+    ("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"),
+    ("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"),
+    ("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"),
+]
+
+def prompt_backbone_interactive(current_id: str) -> str:
+    """Interactively select the base model.
+
+    - In a non-interactive environment (stdin is not a TTY), or when the environment
+      variable NON_INTERACTIVE=1 is set, return current_id directly.
+    - The user may enter the index of a preset option, or directly enter any
+      Hugging Face model ID.
+    - An empty Enter keeps the current default.
+    """
+    if os.environ.get("NON_INTERACTIVE", "0") == "1":
+        return current_id
+    try:
+        if not sys.stdin.isatty():
+            return current_id
+    except Exception:
+        return current_id
+    print("\nOptional Chinese base models (press Enter to use the default):")
+    for label, hf_id in BACKBONE_CANDIDATES:
+        print(f" {label}")
+    print(f"Current default: {current_id}")
+    choice = input("Enter an index, or paste a model ID (Enter keeps the default): ").strip()
+    if not choice:
+        return current_id
+    # Numeric option
+    if choice.isdigit():
+        idx = int(choice)
+        for label, hf_id in BACKBONE_CANDIDATES:
+            if label.startswith(f"{idx})"):
+                return hf_id
+        print("Index not found; keeping the default.")
+        return current_id
+    # Custom HF model ID
+    return choice
+
 def preprocess_text(text: str) -> str:
     if text is None:
@@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--text_col", type=str, default="auto", help="Text column name; auto-detected by default")
     parser.add_argument("--label_col", type=str, default="auto", help="Label column name; auto-detected by default")
     parser.add_argument("--model_root", type=str, default="./model", help="Local model root directory")
-    parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
+    parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face model ID; leave empty to enter interactive selection")
     parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--batch_size", type=int, default=64)
@@ -216,8 +262,10 @@ def main() -> None:
     model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
     os.makedirs(model_root, exist_ok=True)

+    # Interactively select the base model (when interaction is possible and not disabled via the environment)
+    selected_model_id = prompt_backbone_interactive(args.pretrained_name)
     # Make sure the base model is ready
-    base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
+    base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root)
     print(f"[Info] Using base model directory: {base_dir}")

     # Read the data
... ...