戒酒的李白

Updated the topic prediction script to support Top-K predictions and added an in…

…teractive selection feature for optional Chinese foundation models.
@@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/ @@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/
185 WeiboMultilingualSentiment/model/ 185 WeiboMultilingualSentiment/model/
186 WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/ 186 WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/
187 WeiboSentiment_SmallQwen/models/ 187 WeiboSentiment_SmallQwen/models/
  188 +BertTopicDetection_Finetuned/model/
188 189
189 # LoRA 和 Adapter 权重 190 # LoRA 和 Adapter 权重
190 */adapter_model.safetensors 191 */adapter_model.safetensors
@@ -63,6 +63,23 @@ python train.py \ @@ -63,6 +63,23 @@ python train.py \
63 - 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型; 63 - 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型;
64 - 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/` 64 - 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`
65 65
  66 +### 可选中文基座模型(训练前交互选择)
  67 +
  68 +默认基座:`google-bert/bert-base-chinese`。启动训练时,若终端可交互,程序会提示从下列选项中选择(或输入任意 Hugging Face 模型 ID):
  69 +
  70 +1) `google-bert/bert-base-chinese`
  71 +2) `hfl/chinese-roberta-wwm-ext-large`
  72 +3) `hfl/chinese-macbert-large`
  73 +4) `IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese`
  74 +5) `IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese`
  75 +6) `Langboat/mengzi-bert-base`
  76 +7) `BAAI/bge-base-zh`(更适合检索式/对比学习范式)
  77 +8) `nghuyong/ernie-3.0-base-zh`
  78 +
  79 +说明:
  80 +- 非交互环境(如调度系统)或设置 `NON_INTERACTIVE=1` 时,会直接使用命令行参数 `--pretrained_name` 指定的模型(默认为 `google-bert/bert-base-chinese`)。
  81 +- 选择后,基础模型将下载/缓存至 `model/` 目录,统一管理。
  82 +
66 ### 预测 83 ### 预测
67 84
68 单条: 85 单条:
@@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi @@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi
96 - 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。 113 - 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。
97 114
98 115
  116 +### 作者说明(关于超大规模多分类)
  117 +
  118 +- 当话题类别达到上万级时,直接在编码器后接单一线性分类头(大 softmax)往往受限:长尾类别难学、语义稀疏、新增话题无法增量适配、上线后需频繁重训。
  119 +- 改进思路(推荐优先级):
  120 + - 检索式/双塔范式(文本 vs. 话题名称/描述 对比学习)+ 近邻检索 + 小头重排,天然支持增量扩类与快速更新;
  121 + - 分层分类(先粗分再细分),显著降低单头难度与计算;
  122 + - 文本-标签联合建模(使用标签描述),提升近义话题的可迁移性;
  123 + - 训练细节:class-balanced/focal/label smoothing、sampled softmax、对比预训练等。
  124 +- 重要声明:本目录使用的“静态分类头微调”仅作为备选与学习参考。对于英文/多语微短文场景,话题变化极快,传统静态分类器难以及时覆盖,我们的工作重点在 `TopicGPT` 等生成式/自监督话题发现与动态体系构建方向;本实现旨在提供一个可运行的基线与工程示例。
  125 +
  126 +
@@ -3,7 +3,7 @@ import sys @@ -3,7 +3,7 @@ import sys
3 import json 3 import json
4 import re 4 import re
5 import argparse 5 import argparse
6 -from typing import Dict, Tuple 6 +from typing import Dict, Tuple, List
7 7
8 # ========== 单卡锁定(在导入 torch/transformers 前执行) ========== 8 # ========== 单卡锁定(在导入 torch/transformers 前执行) ==========
9 def _extract_gpu_arg(argv, default: str = "0") -> str: 9 def _extract_gpu_arg(argv, default: str = "0") -> str:
@@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]: @@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
109 return finetuned_path, id2label 109 return finetuned_path, id2label
110 110
111 111
112 -def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:  
113 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
114 - tokenizer = AutoTokenizer.from_pretrained(model_dir)  
115 - model = AutoModelForSequenceClassification.from_pretrained(model_dir)  
116 - model.to(device)  
117 - model.eval()  
118 -  
119 - processed = preprocess_text(text) 112 +def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]:
  113 + processed = preprocess_text(text or "")
120 encoded = tokenizer( 114 encoded = tokenizer(
121 processed, 115 processed,
122 max_length=max_length, 116 max_length=max_length,
@@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, @@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str,
130 with torch.no_grad(): 124 with torch.no_grad():
131 outputs = model(input_ids=input_ids, attention_mask=attention_mask) 125 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
132 logits = outputs.logits 126 logits = outputs.logits
133 - probs = torch.softmax(logits, dim=-1)  
134 - pred = int(torch.argmax(probs, dim=-1).item())  
135 - conf = float(probs[0, pred].item())  
136 - id2label = getattr(model.config, "id2label", None)  
137 - label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)  
138 - return label_name, conf 127 + probs = torch.softmax(logits, dim=-1)[0]
  128 + k = min(top_k, probs.shape[-1])
  129 + confs, idxs = torch.topk(probs, k)
  130 + id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {}
  131 + results: List[Tuple[str, float]] = []
  132 + for i in range(k):
  133 + idx = int(idxs[i].item())
  134 + conf = float(confs[i].item())
  135 + label_name = id2label.get(idx, str(idx))
  136 + results.append((label_name, conf))
  137 + return results
139 138
140 139
141 def main() -> None: 140 def main() -> None:
@@ -149,13 +148,21 @@ def main() -> None: @@ -149,13 +148,21 @@ def main() -> None:
149 ensure_base_model_local(args.pretrained_name, model_root) 148 ensure_base_model_local(args.pretrained_name, model_root)
150 149
151 finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir) 150 finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
  151 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  152 + tokenizer = AutoTokenizer.from_pretrained(finetuned_dir)
  153 + model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir)
  154 + model.to(device)
  155 + model.eval()
152 156
153 if args.text is not None: 157 if args.text is not None:
154 - label, conf = predict_once(finetuned_dir, args.text, args.max_length)  
155 - print(f"预测结果: {label} (置信度: {conf:.4f})") 158 + topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3)
  159 + print("Top-3 预测:")
  160 + for rank, (label, conf) in enumerate(topk, 1):
  161 + print(f"{rank}. {label} (p={conf:.4f})")
156 return 162 return
157 163
158 - if args.interactive: 164 + # 默认进入交互模式(未显式指定 --text 且未显式关闭交互)
  165 + if args.interactive or (args.text is None):
159 print("进入交互模式。输入 'q' 退出。") 166 print("进入交互模式。输入 'q' 退出。")
160 while True: 167 while True:
161 try: 168 try:
@@ -166,11 +173,13 @@ def main() -> None: @@ -166,11 +173,13 @@ def main() -> None:
166 break 173 break
167 if not text: 174 if not text:
168 continue 175 continue
169 - label, conf = predict_once(finetuned_dir, text, args.max_length)  
170 - print(f"预测结果: {label} (置信度: {conf:.4f})") 176 + topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3)
  177 + print("Top-3 预测:")
  178 + for rank, (label, conf) in enumerate(topk, 1):
  179 + print(f"{rank}. {label} (p={conf:.4f})")
171 return 180 return
172 -  
173 - print("未提供 --text 或 --interactive,什么也没有发生。") 181 + # 理论上不会到达这里
  182 + print("未提供输入。")
174 183
175 184
176 if __name__ == "__main__": 185 if __name__ == "__main__":
@@ -52,6 +52,52 @@ try: @@ -52,6 +52,52 @@ try:
52 except Exception: # pragma: no cover 52 except Exception: # pragma: no cover
53 EarlyStoppingCallback = None # type: ignore 53 EarlyStoppingCallback = None # type: ignore
54 54
  55 +# 预置可选中文基座模型(可扩展)
  56 +BACKBONE_CANDIDATES: List[Tuple[str, str]] = [
  57 + ("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"),
  58 + ("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"),
  59 + ("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"),
  60 + ("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"),
  61 + ("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"),
  62 + ("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"),
  63 + ("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"),
  64 + ("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"),
  65 +]
  66 +
  67 +
  68 +def prompt_backbone_interactive(current_id: str) -> str:
  69 + """交互式选择基座模型。
  70 +
  71 + - 当处于非交互环境(stdin 非 TTY)或设置了环境变量 NON_INTERACTIVE=1 时,直接返回 current_id。
  72 + - 用户可输入序号选择预置项,或直接输入任意 Hugging Face 模型 ID。
  73 + - 空回车使用当前默认。
  74 + """
  75 + if os.environ.get("NON_INTERACTIVE", "0") == "1":
  76 + return current_id
  77 + try:
  78 + if not sys.stdin.isatty():
  79 + return current_id
  80 + except Exception:
  81 + return current_id
  82 +
  83 + print("\n可选中文基座模型(直接回车使用默认):")
  84 + for label, hf_id in BACKBONE_CANDIDATES:
  85 + print(f" {label}")
  86 + print(f"当前默认: {current_id}")
  87 + choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip()
  88 + if not choice:
  89 + return current_id
  90 + # 数字选项
  91 + if choice.isdigit():
  92 + idx = int(choice)
  93 + for label, hf_id in BACKBONE_CANDIDATES:
  94 + if label.startswith(f"{idx})"):
  95 + return hf_id
  96 + print("未找到该序号,沿用默认。")
  97 + return current_id
  98 + # 自定义 HF 模型 ID
  99 + return choice
  100 +
55 101
56 def preprocess_text(text: str) -> str: 102 def preprocess_text(text: str) -> str:
57 if text is None: 103 if text is None:
@@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace: @@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace:
191 parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别") 237 parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
192 parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别") 238 parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
193 parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录") 239 parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
194 - parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese") 240 + parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择")
195 parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier") 241 parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
196 parser.add_argument("--max_length", type=int, default=128) 242 parser.add_argument("--max_length", type=int, default=128)
197 parser.add_argument("--batch_size", type=int, default=64) 243 parser.add_argument("--batch_size", type=int, default=64)
@@ -216,8 +262,10 @@ def main() -> None: @@ -216,8 +262,10 @@ def main() -> None:
216 model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root) 262 model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
217 os.makedirs(model_root, exist_ok=True) 263 os.makedirs(model_root, exist_ok=True)
218 264
  265 + # 交互式选择基座模型(若允许交互且未通过环境禁用)
  266 + selected_model_id = prompt_backbone_interactive(args.pretrained_name)
219 # 确保基础模型就绪 267 # 确保基础模型就绪
220 - base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root) 268 + base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root)
221 print(f"[Info] 使用基础模型目录: {base_dir}") 269 print(f"[Info] 使用基础模型目录: {base_dir}")
222 270
223 # 读取数据 271 # 读取数据