Updated the topic prediction script to support Top-K predictions and added an interactive selection feature for optional Chinese foundation models.
Showing 4 changed files with 110 additions and 24 deletions.
| @@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/ | @@ -185,6 +185,7 @@ WeiboSentiment_Finetuned/BertChinese-Lora/model/ | ||
| 185 | WeiboMultilingualSentiment/model/ | 185 | WeiboMultilingualSentiment/model/ |
| 186 | WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/ | 186 | WeiboSentiment_MachineLearning/model/chinese_wwm_pytorch/ |
| 187 | WeiboSentiment_SmallQwen/models/ | 187 | WeiboSentiment_SmallQwen/models/ |
| 188 | +BertTopicDetection_Finetuned/model/ | ||
| 188 | 189 | ||
| 189 | # LoRA 和 Adapter 权重 | 190 | # LoRA 和 Adapter 权重 |
| 190 | */adapter_model.safetensors | 191 | */adapter_model.safetensors |
| @@ -63,6 +63,23 @@ python train.py \ | @@ -63,6 +63,23 @@ python train.py \ | ||
| 63 | - 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型; | 63 | - 支持早停(默认耐心 5 次评估),并在评估/保存策略一致时自动回滚到最佳模型; |
| 64 | - 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`。 | 64 | - 分词器、权重与 `label_map.json` 保存到 `model/bert-chinese-classifier/`。 |
| 65 | 65 | ||
| 66 | +### 可选中文基座模型(训练前交互选择) | ||
| 67 | + | ||
| 68 | +默认基座:`google-bert/bert-base-chinese`。启动训练时,若终端可交互,程序会提示从下列选项中选择(或输入任意 Hugging Face 模型 ID): | ||
| 69 | + | ||
| 70 | +1) `google-bert/bert-base-chinese` | ||
| 71 | +2) `hfl/chinese-roberta-wwm-ext-large` | ||
| 72 | +3) `hfl/chinese-macbert-large` | ||
| 73 | +4) `IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese` | ||
| 74 | +5) `IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese` | ||
| 75 | +6) `Langboat/mengzi-bert-base` | ||
| 76 | +7) `BAAI/bge-base-zh`(更适合检索式/对比学习范式) | ||
| 77 | +8) `nghuyong/ernie-3.0-base-zh` | ||
| 78 | + | ||
| 79 | +说明: | ||
| 80 | +- 非交互环境(如调度系统)或设置 `NON_INTERACTIVE=1` 时,会直接使用命令行参数 `--pretrained_name` 指定的模型(默认为 `google-bert/bert-base-chinese`)。 | ||
| 81 | +- 选择后,基础模型将下载/缓存至 `model/` 目录,统一管理。 | ||
| 82 | + | ||
| 66 | ### 预测 | 83 | ### 预测 |
| 67 | 84 | ||
| 68 | 单条: | 85 | 单条: |
| @@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi | @@ -96,3 +113,14 @@ python predict.py --interactive --model_root ./model --finetuned_subdir bert-chi | ||
| 96 | - 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。 | 113 | - 单卡稳定运行:默认仅使用一张 GPU,可通过 `--gpu` 指定;脚本会清理分布式环境变量。 |
| 97 | 114 | ||
| 98 | 115 | ||
| 116 | +### 作者说明(关于超大规模多分类) | ||
| 117 | + | ||
| 118 | +- 当话题类别达到上万级时,直接在编码器后接单一线性分类头(大 softmax)往往受限:长尾类别难学、语义稀疏、新增话题无法增量适配、上线后需频繁重训。 | ||
| 119 | +- 改进思路(推荐优先级): | ||
| 120 | + - 检索式/双塔范式(文本 vs. 话题名称/描述 对比学习)+ 近邻检索 + 小头重排,天然支持增量扩类与快速更新; | ||
| 121 | + - 分层分类(先粗分再细分),显著降低单头难度与计算; | ||
| 122 | + - 文本-标签联合建模(使用标签描述),提升近义话题的可迁移性; | ||
| 123 | + - 训练细节:class-balanced/focal/label smoothing、sampled softmax、对比预训练等。 | ||
| 124 | +- 重要声明:本目录使用的“静态分类头微调”仅作为备选与学习参考。对于英文/多语微短文场景,话题变化极快,传统静态分类器难以及时覆盖,我们的工作重点在 `TopicGPT` 等生成式/自监督话题发现与动态体系构建方向;本实现旨在提供一个可运行的基线与工程示例。 | ||
| 125 | + | ||
| 126 | + |
| @@ -3,7 +3,7 @@ import sys | @@ -3,7 +3,7 @@ import sys | ||
| 3 | import json | 3 | import json |
| 4 | import re | 4 | import re |
| 5 | import argparse | 5 | import argparse |
| 6 | -from typing import Dict, Tuple | 6 | +from typing import Dict, Tuple, List |
| 7 | 7 | ||
| 8 | # ========== 单卡锁定(在导入 torch/transformers 前执行) ========== | 8 | # ========== 单卡锁定(在导入 torch/transformers 前执行) ========== |
| 9 | def _extract_gpu_arg(argv, default: str = "0") -> str: | 9 | def _extract_gpu_arg(argv, default: str = "0") -> str: |
| @@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]: | @@ -109,14 +109,8 @@ def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]: | ||
| 109 | return finetuned_path, id2label | 109 | return finetuned_path, id2label |
| 110 | 110 | ||
| 111 | 111 | ||
| 112 | -def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]: | ||
| 113 | - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 114 | - tokenizer = AutoTokenizer.from_pretrained(model_dir) | ||
| 115 | - model = AutoModelForSequenceClassification.from_pretrained(model_dir) | ||
| 116 | - model.to(device) | ||
| 117 | - model.eval() | ||
| 118 | - | ||
| 119 | - processed = preprocess_text(text) | 112 | +def predict_topk(model: AutoModelForSequenceClassification, tokenizer: AutoTokenizer, device: torch.device, text: str, max_length: int = 128, top_k: int = 3) -> List[Tuple[str, float]]: |
| 113 | + processed = preprocess_text(text or "") | ||
| 120 | encoded = tokenizer( | 114 | encoded = tokenizer( |
| 121 | processed, | 115 | processed, |
| 122 | max_length=max_length, | 116 | max_length=max_length, |
| @@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, | @@ -130,12 +124,17 @@ def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, | ||
| 130 | with torch.no_grad(): | 124 | with torch.no_grad(): |
| 131 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) | 125 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) |
| 132 | logits = outputs.logits | 126 | logits = outputs.logits |
| 133 | - probs = torch.softmax(logits, dim=-1) | ||
| 134 | - pred = int(torch.argmax(probs, dim=-1).item()) | ||
| 135 | - conf = float(probs[0, pred].item()) | ||
| 136 | - id2label = getattr(model.config, "id2label", None) | ||
| 137 | - label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred) | ||
| 138 | - return label_name, conf | 127 | + probs = torch.softmax(logits, dim=-1)[0] |
| 128 | + k = min(top_k, probs.shape[-1]) | ||
| 129 | + confs, idxs = torch.topk(probs, k) | ||
| 130 | + id2label = getattr(model.config, "id2label", {}) if isinstance(getattr(model.config, "id2label", None), dict) else {} | ||
| 131 | + results: List[Tuple[str, float]] = [] | ||
| 132 | + for i in range(k): | ||
| 133 | + idx = int(idxs[i].item()) | ||
| 134 | + conf = float(confs[i].item()) | ||
| 135 | + label_name = id2label.get(idx, str(idx)) | ||
| 136 | + results.append((label_name, conf)) | ||
| 137 | + return results | ||
| 139 | 138 | ||
| 140 | 139 | ||
| 141 | def main() -> None: | 140 | def main() -> None: |
| @@ -149,13 +148,21 @@ def main() -> None: | @@ -149,13 +148,21 @@ def main() -> None: | ||
| 149 | ensure_base_model_local(args.pretrained_name, model_root) | 148 | ensure_base_model_local(args.pretrained_name, model_root) |
| 150 | 149 | ||
| 151 | finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir) | 150 | finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir) |
| 151 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 152 | + tokenizer = AutoTokenizer.from_pretrained(finetuned_dir) | ||
| 153 | + model = AutoModelForSequenceClassification.from_pretrained(finetuned_dir) | ||
| 154 | + model.to(device) | ||
| 155 | + model.eval() | ||
| 152 | 156 | ||
| 153 | if args.text is not None: | 157 | if args.text is not None: |
| 154 | - label, conf = predict_once(finetuned_dir, args.text, args.max_length) | ||
| 155 | - print(f"预测结果: {label} (置信度: {conf:.4f})") | 158 | + topk = predict_topk(model, tokenizer, device, args.text, args.max_length, top_k=3) |
| 159 | + print("Top-3 预测:") | ||
| 160 | + for rank, (label, conf) in enumerate(topk, 1): | ||
| 161 | + print(f"{rank}. {label} (p={conf:.4f})") | ||
| 156 | return | 162 | return |
| 157 | 163 | ||
| 158 | - if args.interactive: | 164 | + # 默认进入交互模式(未显式指定 --text 且未显式关闭交互) |
| 165 | + if args.interactive or (args.text is None): | ||
| 159 | print("进入交互模式。输入 'q' 退出。") | 166 | print("进入交互模式。输入 'q' 退出。") |
| 160 | while True: | 167 | while True: |
| 161 | try: | 168 | try: |
| @@ -166,11 +173,13 @@ def main() -> None: | @@ -166,11 +173,13 @@ def main() -> None: | ||
| 166 | break | 173 | break |
| 167 | if not text: | 174 | if not text: |
| 168 | continue | 175 | continue |
| 169 | - label, conf = predict_once(finetuned_dir, text, args.max_length) | ||
| 170 | - print(f"预测结果: {label} (置信度: {conf:.4f})") | 176 | + topk = predict_topk(model, tokenizer, device, text, args.max_length, top_k=3) |
| 177 | + print("Top-3 预测:") | ||
| 178 | + for rank, (label, conf) in enumerate(topk, 1): | ||
| 179 | + print(f"{rank}. {label} (p={conf:.4f})") | ||
| 171 | return | 180 | return |
| 172 | - | ||
| 173 | - print("未提供 --text 或 --interactive,什么也没有发生。") | 181 | + # 理论上不会到达这里 |
| 182 | + print("未提供输入。") | ||
| 174 | 183 | ||
| 175 | 184 | ||
| 176 | if __name__ == "__main__": | 185 | if __name__ == "__main__": |
| @@ -52,6 +52,52 @@ try: | @@ -52,6 +52,52 @@ try: | ||
| 52 | except Exception: # pragma: no cover | 52 | except Exception: # pragma: no cover |
| 53 | EarlyStoppingCallback = None # type: ignore | 53 | EarlyStoppingCallback = None # type: ignore |
| 54 | 54 | ||
| 55 | +# 预置可选中文基座模型(可扩展) | ||
| 56 | +BACKBONE_CANDIDATES: List[Tuple[str, str]] = [ | ||
| 57 | + ("1) google-bert/bert-base-chinese", "google-bert/bert-base-chinese"), | ||
| 58 | + ("2) hfl/chinese-roberta-wwm-ext-large", "hfl/chinese-roberta-wwm-ext-large"), | ||
| 59 | + ("3) hfl/chinese-macbert-large", "hfl/chinese-macbert-large"), | ||
| 60 | + ("4) IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"), | ||
| 61 | + ("5) IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese", "IDEA-CCNL/Erlangshen-DeBERTa-v3-Base-Chinese"), | ||
| 62 | + ("6) Langboat/mengzi-bert-base", "Langboat/mengzi-bert-base"), | ||
| 63 | + ("7) BAAI/bge-base-zh", "BAAI/bge-base-zh"), | ||
| 64 | + ("8) nghuyong/ernie-3.0-base-zh", "nghuyong/ernie-3.0-base-zh"), | ||
| 65 | +] | ||
| 66 | + | ||
| 67 | + | ||
| 68 | +def prompt_backbone_interactive(current_id: str) -> str: | ||
| 69 | + """交互式选择基座模型。 | ||
| 70 | + | ||
| 71 | + - 当处于非交互环境(stdin 非 TTY)或设置了环境变量 NON_INTERACTIVE=1 时,直接返回 current_id。 | ||
| 72 | + - 用户可输入序号选择预置项,或直接输入任意 Hugging Face 模型 ID。 | ||
| 73 | + - 空回车使用当前默认。 | ||
| 74 | + """ | ||
| 75 | + if os.environ.get("NON_INTERACTIVE", "0") == "1": | ||
| 76 | + return current_id | ||
| 77 | + try: | ||
| 78 | + if not sys.stdin.isatty(): | ||
| 79 | + return current_id | ||
| 80 | + except Exception: | ||
| 81 | + return current_id | ||
| 82 | + | ||
| 83 | + print("\n可选中文基座模型(直接回车使用默认):") | ||
| 84 | + for label, hf_id in BACKBONE_CANDIDATES: | ||
| 85 | + print(f" {label}") | ||
| 86 | + print(f"当前默认: {current_id}") | ||
| 87 | + choice = input("请输入序号或直接粘贴模型ID(回车沿用默认): ").strip() | ||
| 88 | + if not choice: | ||
| 89 | + return current_id | ||
| 90 | + # 数字选项 | ||
| 91 | + if choice.isdigit(): | ||
| 92 | + idx = int(choice) | ||
| 93 | + for label, hf_id in BACKBONE_CANDIDATES: | ||
| 94 | + if label.startswith(f"{idx})"): | ||
| 95 | + return hf_id | ||
| 96 | + print("未找到该序号,沿用默认。") | ||
| 97 | + return current_id | ||
| 98 | + # 自定义 HF 模型 ID | ||
| 99 | + return choice | ||
| 100 | + | ||
| 55 | 101 | ||
| 56 | def preprocess_text(text: str) -> str: | 102 | def preprocess_text(text: str) -> str: |
| 57 | if text is None: | 103 | if text is None: |
| @@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace: | @@ -191,7 +237,7 @@ def parse_args() -> argparse.Namespace: | ||
| 191 | parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别") | 237 | parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别") |
| 192 | parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别") | 238 | parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别") |
| 193 | parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录") | 239 | parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录") |
| 194 | - parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese") | 240 | + parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Hugging Face 模型ID;留空则进入交互选择") |
| 195 | parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier") | 241 | parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier") |
| 196 | parser.add_argument("--max_length", type=int, default=128) | 242 | parser.add_argument("--max_length", type=int, default=128) |
| 197 | parser.add_argument("--batch_size", type=int, default=64) | 243 | parser.add_argument("--batch_size", type=int, default=64) |
| @@ -216,8 +262,10 @@ def main() -> None: | @@ -216,8 +262,10 @@ def main() -> None: | ||
| 216 | model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root) | 262 | model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root) |
| 217 | os.makedirs(model_root, exist_ok=True) | 263 | os.makedirs(model_root, exist_ok=True) |
| 218 | 264 | ||
| 265 | + # 交互式选择基座模型(若允许交互且未通过环境禁用) | ||
| 266 | + selected_model_id = prompt_backbone_interactive(args.pretrained_name) | ||
| 219 | # 确保基础模型就绪 | 267 | # 确保基础模型就绪 |
| 220 | - base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root) | 268 | + base_dir, tokenizer = ensure_base_model_local(selected_model_id, model_root) |
| 221 | print(f"[Info] 使用基础模型目录: {base_dir}") | 269 | print(f"[Info] 使用基础模型目录: {base_dir}") |
| 222 | 270 | ||
| 223 | # 读取数据 | 271 | # 读取数据 |
-
Please register or log in to post a comment