Training and prediction scripts for a topic classification model based on bert-base-chinese.
Showing 8 changed files with 810 additions and 0 deletions
BertTopicDetection_Finetuned/README.md
0 → 100644
## Topic Classification (Chinese BERT base)

This directory provides a Chinese topic-classification implementation based on `google-bert/bert-base-chinese`:
- three-stage loading (local directory / local cache / remote download) is handled automatically;
- `train.py` performs fine-tuning; `predict.py` performs single-shot or interactive prediction;
- all models and weights are saved under this directory's `model/`.

Model card: [google-bert/bert-base-chinese](https://huggingface.co/google-bert/bert-base-chinese)

### Dataset highlights

- About **4.1 million** pre-filtered, high-quality questions and replies;
- each question carries one topic tag, covering about **28,000** diverse topics;
- filtered from **14 million** raw Q&A pairs, keeping only answers with **at least 3 upvotes** to ensure quality and interest;
- besides the question, the topic, and one or more replies, each reply also carries an upvote count, a reply ID, and a replier tag;
- after cleaning and deduplication the data is split into three parts; the reference split puts about **4.12 million** examples in the training set, with validation/test sizes adjustable as needed.

> For actual training, treat the CSVs under `dataset/` as authoritative; the scripts auto-detect common column names, or you can specify them explicitly via command-line arguments.

### Directory layout

```
BertTopicDetection_Finetuned/
 ├─ dataset/      # data already in place
 ├─ model/        # created by training; also caches the base BERT
 ├─ train.py
 ├─ predict.py
 └─ README.md
```

### Environment

```
pip install torch transformers scikit-learn pandas
```

Or use your existing Conda environment.

### Data format

The CSV must contain at least a text column and a label column; the scripts try to detect them automatically:
- text column candidates: `text`/`content`/`sentence`/`title`/`desc`/`question`
- label column candidates: `label`/`labels`/`category`/`topic`/`class`

To specify them explicitly, use `--text_col` and `--label_col`; a minimal example follows below.

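As a minimal sketch of a file both scripts can consume directly (the rows and labels here are illustrative, not taken from the shipped dataset):

```
text,label
这场比赛的越位判罚是否合理?,体育-足球
新手如何挑选第一把吉他?,音乐-乐器
```
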
### Training

```
python train.py \
  --train_file ./dataset/web_text_zh_train.csv \
  --valid_file ./dataset/web_text_zh_valid.csv \
  --text_col auto \
  --label_col auto \
  --model_root ./model \
  --save_subdir bert-chinese-classifier \
  --num_epochs 10 --batch_size 16 --learning_rate 2e-5 --fp16
```

Key points:
- The first run checks `model/bert-base-chinese`; if absent, it falls back to the local Hugging Face cache, and failing that downloads the model and saves it there;
- evaluation and checkpointing run on a step schedule (every 1/4 epoch by default), keeping at most the 5 most recent checkpoints (`save_total_limit=5`, hard-coded in `train.py`);
- early stopping is supported (default patience: 5 evaluations), and the best model is automatically restored at the end whenever the evaluation and save strategies match;
- the tokenizer, weights, and `label_map.json` are saved to `model/bert-chinese-classifier/`; the file's layout is sketched below.

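`label_map.json` stores the two mappings the scripts use to translate between label strings and class indices. A sketch of its layout as written by `train.py` (the label names are illustrative):

```
{
  "label2id": { "体育-足球": 0, "音乐-乐器": 1 },
  "id2label": { "0": "体育-足球", "1": "音乐-乐器" }
}
```

Note that `json.dump` serializes the integer keys of `id2label` as strings; `predict.py` converts them back with `int(k)` when it reads the file.
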
### Prediction

Single input:
```
python predict.py --text "这条微博讨论的是哪个话题?" --model_root ./model --finetuned_subdir bert-chinese-classifier
```

Interactive:
```
python predict.py --interactive --model_root ./model --finetuned_subdir bert-chinese-classifier
```

Sample output:
```
Prediction: 体育-足球 (confidence: 0.9412)
```

### Notes

- Training and prediction both apply the same lightweight built-in Chinese text cleaning; an example follows below.
- The label set is derived from the training set; the script generates and saves `label_map.json` automatically.

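A minimal sketch of what the built-in `preprocess_text` does (the sample string is made up): it strips `{%...%}` template fragments, `@mentions`, `【...】` bracketed tags, zero-width spaces, and emoji, then collapses whitespace:

```
>>> from predict import preprocess_text
>>> preprocess_text("@小明 【转发】这场比赛太精彩了🔥🔥")
'这场比赛太精彩了'
```
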
### Training strategy (brief)

- Base model: `google-bert/bert-base-chinese`; the classification head's dimension equals the number of unique labels in the training set.
- Learning rate and regularization: `lr=2e-5`, `weight_decay=0.01`; on large datasets it can be tuned between `1e-5` and `3e-5`.
- Sequence length and batch size: `max_length=128`, `batch_size=16`; if truncation is severe, raise the length to 256 (at higher cost).
- Warmup: `warmup_ratio=0.1` when the installed `transformers` supports it; otherwise falls back to `warmup_steps=0`.
- Evaluation/saving: the step interval is derived from `--eval_fraction` (default 0.25; see the worked example below); `save_total_limit=5` caps disk usage.
- Early stopping: monitors weighted F1 (higher is better); default patience 5, improvement threshold 0.0.
- Stable single-GPU runs: only one GPU is used by default, selectable via `--gpu`; the script also clears distributed-launch environment variables.
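
A worked sketch of the step arithmetic in `train.py` (the dataset size is illustrative): with 100,000 training examples, `batch_size=16`, and `eval_fraction=0.25`, evaluation and checkpointing fire every quarter epoch:

```
import math
steps_per_epoch = math.ceil(100_000 / 16)             # 6250
eval_every_steps = math.ceil(steps_per_epoch * 0.25)  # 1563
```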
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.51.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
This diff could not be displayed because it is too large.
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "do_lower_case": false,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
This diff could not be displayed because it is too large.
BertTopicDetection_Finetuned/predict.py
0 → 100644
import os
import sys
import json
import re
import argparse
from typing import Dict, Optional, Tuple

# ========== Pin a single GPU (must run before torch/transformers are imported) ==========
def _extract_gpu_arg(argv, default: str = "0") -> str:
    for i, arg in enumerate(argv):
        if arg.startswith("--gpu="):
            return arg.split("=", 1)[1]
        if arg == "--gpu" and i + 1 < len(argv):
            return argv[i + 1]
    return default

env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
try:
    gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
except Exception:
    gpu_to_use = "0"
if (not env_vis) or ("," in env_vis):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

# Clear distributed-launch variables that an external launcher may have injected.
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
    os.environ.pop(_k, None)

import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
)


def preprocess_text(text: str) -> str:
    """Lightweight Chinese text cleaning, shared with train.py."""
    if text is None:
        return ""
    text = str(text)
    text = re.sub(r"\{%.+?%\}", " ", text)  # {%...%} template fragments
    text = re.sub(r"@.+?( |$)", " ", text)  # @mentions up to the next space
    text = re.sub(r"【.+?】", " ", text)  # 【...】 bracketed tags
    text = re.sub(r"\u200b", " ", text)  # zero-width spaces
    text = re.sub(  # emoji and pictographs
        r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
        "",
        text,
    )
    text = re.sub(r"\s+", " ", text)  # collapse whitespace
    return text.strip()


def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
    """Resolve the base model in three stages: local directory, local cache, remote download."""
    os.makedirs(local_model_root, exist_ok=True)
    base_dir = os.path.join(local_model_root, "bert-base-chinese")

    def is_ready(path: str) -> bool:
        return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))

    # 1) Already saved locally
    if is_ready(base_dir):
        tokenizer = AutoTokenizer.from_pretrained(base_dir)
        return base_dir, tokenizer

    # 2) Local Hugging Face cache
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
        base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
        os.makedirs(base_dir, exist_ok=True)
        tokenizer.save_pretrained(base_dir)
        base.save_pretrained(base_dir)
        return base_dir, tokenizer
    except Exception:
        pass

    # 3) Remote download
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    base = AutoModel.from_pretrained(model_name_or_path)
    os.makedirs(base_dir, exist_ok=True)
    tokenizer.save_pretrained(base_dir)
    base.save_pretrained(base_dir)
    return base_dir, tokenizer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Predict with a Chinese BERT classifier loaded locally, from cache, or remotely")
    parser.add_argument("--model_root", type=str, default="./model", help="Local model root directory")
    parser.add_argument("--finetuned_subdir", type=str, default="bert-chinese-classifier", help="Subdirectory holding the fine-tuned model")
    parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="Pretrained model name or path")
    parser.add_argument("--text", type=str, default=None, help="A single text to classify")
    parser.add_argument("--interactive", action="store_true", help="Run in interactive prediction mode")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="Single GPU to use, e.g. 0 or 1")
    return parser.parse_args()


def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Optional[Dict[int, str]]]:
    finetuned_path = os.path.join(model_root, subdir)
    if not os.path.isdir(finetuned_path):
        raise FileNotFoundError(
            f"Fine-tuned model directory not found: {finetuned_path}. Run the training script first."
        )
    label_map_path = os.path.join(finetuned_path, "label_map.json")
    id2label = None
    if os.path.isfile(label_map_path):
        with open(label_map_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        id2label = {int(k): str(v) for k, v in data.get("id2label", {}).items()}
    return finetuned_path, id2label


def load_classifier(model_dir: str) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification, torch.device]:
    """Load tokenizer and model once, so interactive mode does not reload them per input."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)
    model.eval()
    return tokenizer, model, device


def predict_once(tokenizer, model, device, text: str, max_length: int = 128) -> Tuple[str, float]:
    processed = preprocess_text(text)
    encoded = tokenizer(
        processed,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        pred = int(torch.argmax(probs, dim=-1).item())
        conf = float(probs[0, pred].item())
    id2label = getattr(model.config, "id2label", None)
    label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
    return label_name, conf


def main() -> None:
    args = parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
    os.makedirs(model_root, exist_ok=True)

    # Make sure the base model is available locally
    ensure_base_model_local(args.pretrained_name, model_root)

    finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
    tokenizer, model, device = load_classifier(finetuned_dir)

    if args.text is not None:
        label, conf = predict_once(tokenizer, model, device, args.text, args.max_length)
        print(f"Prediction: {label} (confidence: {conf:.4f})")
        return

    if args.interactive:
        print("Interactive mode. Enter 'q' to quit.")
        while True:
            try:
                text = input("Text: ").strip()
            except EOFError:
                break
            if text.lower() == "q":
                break
            if not text:
                continue
            label, conf = predict_once(tokenizer, model, device, text, args.max_length)
            print(f"Prediction: {label} (confidence: {conf:.4f})")
        return

    print("Neither --text nor --interactive was provided; nothing to do.")


if __name__ == "__main__":
    main()
BertTopicDetection_Finetuned/train.py
0 → 100644
import os
import sys
import json
import re
import argparse
import math
import inspect
from typing import Dict, List, Tuple

# ========== Pin a single GPU (must run before torch/transformers are imported) ==========
def _extract_gpu_arg(argv: List[str], default: str = "0") -> str:
    for i, arg in enumerate(argv):
        if arg.startswith("--gpu="):
            return arg.split("=", 1)[1]
        if arg == "--gpu" and i + 1 < len(argv):
            return argv[i + 1]
    return default

env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
try:
    gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
except Exception:
    gpu_to_use = "0"
# If no GPU is pinned, or several are exposed, force a single visible GPU (default 0)
# so that running the script directly is stable.
if (not env_vis) or ("," in env_vis):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

# Clear distributed-launch variables that an external launcher may have injected,
# to avoid accidentally triggering multi-GPU/distributed mode.
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
    os.environ.pop(_k, None)

import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import pandas as pd

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoConfig,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
try:
    from transformers import EarlyStoppingCallback  # type: ignore
except Exception:  # pragma: no cover
    EarlyStoppingCallback = None  # type: ignore


def preprocess_text(text: str) -> str:
    """Lightweight Chinese text cleaning, shared with predict.py."""
    if text is None:
        return ""
    text = str(text)
    text = re.sub(r"\{%.+?%\}", " ", text)  # {%...%} template fragments
    text = re.sub(r"@.+?( |$)", " ", text)  # @mentions up to the next space
    text = re.sub(r"【.+?】", " ", text)  # 【...】 bracketed tags
    text = re.sub(r"\u200b", " ", text)  # zero-width spaces
    text = re.sub(  # emoji and pictographs
        r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
        "",
        text,
    )
    text = re.sub(r"\s+", " ", text)  # collapse whitespace
    return text.strip()


def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
    """Resolve the base model in three stages: local directory, local cache, remote download."""
    os.makedirs(local_model_root, exist_ok=True)
    base_dir = os.path.join(local_model_root, "bert-base-chinese")

    def is_ready(path: str) -> bool:
        return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))

    # 1) Already saved locally
    if is_ready(base_dir):
        tokenizer = AutoTokenizer.from_pretrained(base_dir)
        return base_dir, tokenizer

    # 2) Local Hugging Face cache
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
        base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
        os.makedirs(base_dir, exist_ok=True)
        tokenizer.save_pretrained(base_dir)
        base.save_pretrained(base_dir)
        return base_dir, tokenizer
    except Exception:
        pass

    # 3) Remote download
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    base = AutoModel.from_pretrained(model_name_or_path)
    os.makedirs(base_dir, exist_ok=True)
    tokenizer.save_pretrained(base_dir)
    base.save_pretrained(base_dir)
    return base_dir, tokenizer


class TextClassificationDataset(Dataset):
    def __init__(
        self,
        dataframe: pd.DataFrame,
        tokenizer: AutoTokenizer,
        text_column: str,
        label_column: str,
        label2id: Dict[str, int],
        max_length: int,
    ) -> None:
        self.dataframe = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.label_column = label_column
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        row = self.dataframe.iloc[idx]
        text = preprocess_text(row[self.text_column])
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        item = {k: v.squeeze(0) for k, v in encoding.items()}
        if self.label_column in row and pd.notna(row[self.label_column]):
            label_str = str(row[self.label_column])
            item["labels"] = torch.tensor(self.label2id[label_str], dtype=torch.long)
        return item


def build_label_mappings(train_df: pd.DataFrame, label_column: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    labels: List[str] = [str(x) for x in train_df[label_column].dropna().astype(str).tolist()]
    unique_sorted = sorted(set(labels))
    label2id = {label: i for i, label in enumerate(unique_sorted)}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label


def compute_metrics_fn(eval_pred) -> Dict[str, float]:
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted", zero_division=0)
    acc = accuracy_score(labels, preds)
    return {
        "accuracy": float(acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }


def autodetect_columns(df: pd.DataFrame, text_col: str, label_col: str) -> Tuple[str, str]:
    if text_col != "auto" and label_col != "auto":
        return text_col, label_col
    candidates_text = ["text", "content", "sentence", "title", "desc", "question"]
    candidates_label = ["label", "labels", "category", "topic", "class"]
    detected_text = text_col
    detected_label = label_col
    if text_col == "auto":
        for name in candidates_text:
            if name in df.columns:
                detected_text = name
                break
    if label_col == "auto":
        for name in candidates_label:
            if name in df.columns:
                detected_label = name
                break
    if detected_text == "auto" or detected_label == "auto":
        raise ValueError(
            f"Could not auto-detect column names; pass --text_col and --label_col explicitly. Available columns: {list(df.columns)}"
        )
    return detected_text, detected_label


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fine-tune google-bert/bert-base-chinese for text classification on the dataset in this directory")
    parser.add_argument("--train_file", type=str, default="./dataset/web_text_zh_train.csv")
    parser.add_argument("--valid_file", type=str, default="./dataset/web_text_zh_valid.csv")
    parser.add_argument("--text_col", type=str, default="auto", help="Text column name; auto-detected by default")
    parser.add_argument("--label_col", type=str, default="auto", help="Label column name; auto-detected by default")
    parser.add_argument("--model_root", type=str, default="./model", help="Local model root directory")
    parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
    parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="Single GPU to use, e.g. 0 or 1")
    parser.add_argument("--eval_fraction", type=float, default=0.25, help="Evaluate and save every this fraction of an epoch, e.g. 0.25 = every quarter epoch")
    parser.add_argument("--early_stop_patience", type=int, default=5, help="Early-stopping patience, counted in evaluation rounds")
    parser.add_argument("--early_stop_threshold", type=float, default=0.0, help="Minimum improvement for early stopping (same unit as metric_for_best_model)")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    set_seed(args.seed)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
    os.makedirs(model_root, exist_ok=True)

    # Make sure the base model is ready
    base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
    print(f"[Info] Base model directory: {base_dir}")

    # Load the data
    train_path = args.train_file if os.path.isabs(args.train_file) else os.path.join(script_dir, args.train_file)
    valid_path = args.valid_file if os.path.isabs(args.valid_file) else os.path.join(script_dir, args.valid_file)
    if not os.path.isfile(train_path):
        raise FileNotFoundError(f"Training set not found: {train_path}")
    train_df = pd.read_csv(train_path)
    if not os.path.isfile(valid_path):
        # No validation set provided (or it does not exist): split one off the training set
        shuffled = train_df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)
        split_idx = int(len(shuffled) * 0.9)
        valid_df = shuffled.iloc[split_idx:].reset_index(drop=True)
        train_df = shuffled.iloc[:split_idx].reset_index(drop=True)
    else:
        valid_df = pd.read_csv(valid_path)
    print(f"[Info] Train set: {train_path} | samples: {len(train_df)}")
    print(f"[Info] Valid set: {valid_path if os.path.isfile(valid_path) else '(split from train set)'} | samples: {len(valid_df)}")

    # Auto-detect column names
    text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
    print(f"[Info] Text column: {text_col} | label column: {label_col}")

    # Label mappings
    label2id, id2label = build_label_mappings(train_df, label_col)
    if len(label2id) < 2:
        raise ValueError("Fewer than 2 label classes; cannot train a classifier.")
    print(f"[Info] Number of label classes: {len(label2id)}")

    # Datasets
    train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)
    eval_dataset = TextClassificationDataset(valid_df, tokenizer, text_col, label_col, label2id, args.max_length)
    collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Model
    config = AutoConfig.from_pretrained(
        base_dir,
        num_labels=len(label2id),
        id2label={int(i): str(l) for i, l in id2label.items()},
        label2id={str(l): int(i) for l, i in label2id.items()},
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        base_dir,
        config=config,
        ignore_mismatched_sizes=True,
    )

    # Training arguments
    output_dir = os.path.join(model_root, args.save_subdir)
    os.makedirs(output_dir, exist_ok=True)
    # Training arguments (kept compatible across transformers versions)
    args_dict = {
        "output_dir": output_dir,
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.batch_size,
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "num_train_epochs": args.num_epochs,
        "logging_steps": 100,
        "fp16": args.fp16,
        "seed": args.seed,
    }

    sig = inspect.signature(TrainingArguments.__init__)
    allowed = set(sig.parameters.keys())

    # Optional arguments (added only when supported, to maximize compatibility)
    if "warmup_ratio" in allowed:
        args_dict["warmup_ratio"] = args.warmup_ratio
    if "report_to" in allowed:
        args_dict["report_to"] = []
    # Evaluation/save interval: convert eval_fraction into a number of steps
    steps_per_epoch = max(1, math.ceil(len(train_dataset) / max(1, args.batch_size)))
    eval_every_steps = max(1, math.ceil(steps_per_epoch * max(0.01, min(1.0, args.eval_fraction))))
    # Strategy-style arguments (the field was renamed between transformers versions)
    key_eval = "evaluation_strategy" if "evaluation_strategy" in allowed else ("eval_strategy" if "eval_strategy" in allowed else None)
    if key_eval:
        args_dict[key_eval] = "steps"
    if "save_strategy" in allowed:
        args_dict["save_strategy"] = "steps"
    if "eval_steps" in allowed:
        args_dict["eval_steps"] = eval_every_steps
    if "save_steps" in allowed:
        args_dict["save_steps"] = eval_every_steps
    if "save_total_limit" in allowed:
        args_dict["save_total_limit"] = 5
    # Align logging with the evaluation/save interval to reduce log spam
    if "logging_steps" in allowed:
        args_dict["logging_steps"] = eval_every_steps
    # Restore the best model at the end (only when evaluation and save strategies match)
    if "metric_for_best_model" in allowed:
        args_dict["metric_for_best_model"] = "f1"
    if "greater_is_better" in allowed:
        args_dict["greater_is_better"] = True
    if "load_best_model_at_end" in allowed:
        eval_strat = args_dict.get("evaluation_strategy", args_dict.get("eval_strategy"))
        save_strat = args_dict.get("save_strategy")
        if eval_strat == save_strat and eval_strat in ("steps", "epoch"):
            args_dict["load_best_model_at_end"] = True

    # Versions without warmup_ratio: if warmup_steps is supported, ignore the ratio
    if "warmup_ratio" not in allowed and "warmup_steps" in allowed:
        # Do not compute total steps; default to 0
        args_dict["warmup_steps"] = 0

    # Versions without strategy-style arguments: fall back to saving/evaluating every eval_every_steps steps
    if "save_strategy" not in allowed and "save_steps" in allowed:
        args_dict["save_steps"] = eval_every_steps
    if ("evaluation_strategy" not in allowed and "eval_strategy" not in allowed) and "eval_steps" in allowed:
        args_dict["eval_steps"] = eval_every_steps

    # If load_best_model_at_end is supported but the evaluation/save strategies
    # cannot be aligned, disable it to avoid an error
    if "load_best_model_at_end" in allowed:
        want_load_best = args_dict.get("load_best_model_at_end", False)
        eval_set = args_dict.get("evaluation_strategy", None)
        save_set = args_dict.get("save_strategy", None)
        if want_load_best and (eval_set is None or save_set is None or eval_set != save_set):
            args_dict["load_best_model_at_end"] = False

    training_args = TrainingArguments(**args_dict)
    print("[Info] Key training arguments:")
    print(f"  epochs={args.num_epochs}, batch_size={args.batch_size}, lr={args.learning_rate}, weight_decay={args.weight_decay}")
    print(f"  max_length={args.max_length}, seed={args.seed}, fp16={args.fp16}")
    if "warmup_ratio" in allowed and "warmup_ratio" in args_dict:
        print(f"  warmup_ratio={args_dict['warmup_ratio']}")
    elif "warmup_steps" in allowed and "warmup_steps" in args_dict:
        print(f"  warmup_steps={args_dict['warmup_steps']}")
    print(f"  steps_per_epoch={steps_per_epoch}, eval_every_steps={eval_every_steps}")
    print(f"  eval_strategy={args_dict.get('evaluation_strategy', args_dict.get('eval_strategy'))}, save_strategy={args_dict.get('save_strategy')}, logging_steps={args_dict.get('logging_steps')}")
    print(f"  save_total_limit={args_dict.get('save_total_limit', 'n/a')}, load_best_model_at_end={args_dict.get('load_best_model_at_end', False)}")

    callbacks = []
    # Check both strategy key spellings so early stopping is attached on all versions
    eval_strategy_set = args_dict.get("evaluation_strategy", args_dict.get("eval_strategy"))
    if EarlyStoppingCallback is not None and (eval_strategy_set in ("steps", "epoch") or "eval_steps" in allowed):
        try:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=args.early_stop_patience,
                    early_stopping_threshold=args.early_stop_threshold,
                )
            )
        except Exception:
            pass

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks,
    )
    # Device / GPU info
    try:
        device_cnt = torch.cuda.device_count()
        dev_name = torch.cuda.get_device_name(0) if device_cnt > 0 else "cpu"
        print(f"[Info] Visible CUDA devices: {device_cnt}, current device: {dev_name}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
    except Exception:
        pass

    print("[Info] Starting training ...")

    trainer.train()

    # Save artifacts
    tokenizer.save_pretrained(output_dir)
    trainer.model.config.id2label = {int(i): str(l) for i, l in id2label.items()}
    trainer.model.config.label2id = {str(l): int(i) for l, i in label2id.items()}
    trainer.save_model(output_dir)
    try:
        best_metric = getattr(trainer.state, "best_metric", None)
        best_ckpt = getattr(trainer.state, "best_model_checkpoint", None)
        if best_metric is not None and best_ckpt is not None:
            print(f"[Info] Best model: metric={best_metric:.6f} | checkpoint={best_ckpt}")
    except Exception:
        pass

    with open(os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8") as f:
        json.dump(
            {"label2id": trainer.model.config.label2id, "id2label": trainer.model.config.id2label},
            f,
            ensure_ascii=False,
            indent=2,
        )

    # Training curves: optionally plot train and eval loss
    try:
        import matplotlib  # type: ignore
        matplotlib.use("Agg")  # headless-safe backend; no display required
        import matplotlib.pyplot as plt  # type: ignore
        logs = trainer.state.log_history
        t_steps, t_losses, e_steps, e_losses = [], [], [], []
        step_counter = 0
        for rec in logs:
            if "loss" in rec and "epoch" in rec:
                step_counter += 1
                t_steps.append(step_counter)
                t_losses.append(rec["loss"])
            if "eval_loss" in rec:
                e_steps.append(step_counter)
                e_losses.append(rec["eval_loss"])
        if t_losses or e_losses:
            plt.figure(figsize=(8, 4))
            if t_losses:
                plt.plot(t_steps, t_losses, label="train_loss")
            if e_losses:
                plt.plot(e_steps, e_losses, label="eval_loss")
            plt.xlabel("training step (logged)")
            plt.ylabel("loss")
            plt.legend()
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "training_curve.png"))
    except Exception:
        pass

    print(f"Fine-tuning finished; model saved to: {output_dir}")


if __name__ == "__main__":
    main()