戒酒的李白

Training and prediction scripts for a topic-classification model based on bert-base-chinese.

  1 +## Topic Classification (Chinese BERT Base)
  2 +
  3 +This directory provides a Chinese topic-classification implementation built on `google-bert/bert-base-chinese`:
  4 +- three-stage loading logic handled automatically (local directory → local cache → remote download);
  5 +- `train.py` fine-tunes the model; `predict.py` runs single-text or interactive prediction;
  6 +- all models and weights are saved under this directory's `model/`.
  7 +
  8 +Reference model card: [google-bert/bert-base-chinese](https://huggingface.co/google-bert/bert-base-chinese)
  9 +
  10 +### Dataset Highlights
  11 +
  12 +- **4.1 million** pre-filtered, high-quality questions and replies;
  13 +- each question carries one "【话题】" (topic) tag, covering roughly **28,000** distinct topics;
  14 +- filtered down from **14 million** raw Q&A pairs, keeping only answers with **3 or more upvotes** to ensure quality and interest;
  15 +- besides the question, topic, and one or more replies, each reply also carries its upvote count, reply ID, and replier tags;
  16 +- after cleaning and deduplication the data is split into three parts: roughly **4.12 million** training examples, with validation/test splits sized as needed.
  17 +
  18 +> For actual training, the CSVs under `dataset/` are authoritative; the scripts auto-detect common column names, or you can specify them explicitly via command-line arguments.
  19 +
  20 +### Directory Layout
  21 +
  22 +```
  23 +BertTopicDetection_Finetuned/
  24 + ├─ dataset/ # data already in place
  25 + ├─ model/ # created by training; also caches the base BERT
  26 + ├─ train.py
  27 + ├─ predict.py
  28 + └─ README.md
  29 +```
  30 +
  31 +### Environment
  32 +
  33 +```
  34 +pip install torch transformers scikit-learn pandas
  35 +```
  36 +
  37 +Or use an existing Conda environment.
  38 +
  39 +### Data Format
  40 +
  41 +The CSV must contain at least a text column and a label column; the scripts try to detect them automatically:
  42 +- text column candidates: `text`/`content`/`sentence`/`title`/`desc`/`question`
  43 +- label column candidates: `label`/`labels`/`category`/`topic`/`class`
  44 +
  45 +To specify them explicitly, use `--text_col` and `--label_col`.
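As a sketch, a minimal CSV with auto-detectable column names (hypothetical example rows; real data lives under `dataset/`) parses cleanly with the standard library:

```python
import csv
import io

# Two hypothetical rows using the auto-detectable column names `text` / `label`.
csv_text = (
    "text,label\n"
    "梅西在昨晚的比赛中梅开二度,体育-足球\n"
    "新款折叠屏手机将于下周发布,科技-数码\n"
)
rows = list(csv.DictReader(io.StringIO(csv_text)))
print(list(rows[0].keys()))   # ['text', 'label']
print(rows[0]["label"])       # 体育-足球
```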
  46 +
  47 +### Training
  48 +
  49 +```
  50 +python train.py \
  51 + --train_file ./dataset/web_text_zh_train.csv \
  52 + --valid_file ./dataset/web_text_zh_valid.csv \
  53 + --text_col auto \
  54 + --label_col auto \
  55 + --model_root ./model \
  56 + --save_subdir bert-chinese-classifier \
  57 + --num_epochs 10 --batch_size 16 --learning_rate 2e-5 --fp16
  58 +```
  59 +
  60 +Key points:
  61 +- The first run checks `model/bert-base-chinese`; if absent it tries the local cache, and failing that downloads the model and saves it;
  62 +- Evaluation and checkpointing run by steps (every 1/4 epoch by default), keeping at most 5 recent checkpoints (`save_total_limit=5` in the script);
  63 +- Early stopping is supported (default patience: 5 evaluations), and when the evaluation and save strategies match, the best model is restored at the end;
  64 +- The tokenizer, weights, and `label_map.json` are saved to `model/bert-chinese-classifier/`.
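The mapping stored in `label_map.json` is built the way the training script does it: labels are deduplicated, sorted, then enumerated. A sketch with hypothetical labels:

```python
# Mirrors build_label_mappings in train.py: unique labels sorted, then enumerated.
labels = ["体育-足球", "科技-数码", "体育-足球"]
unique_sorted = sorted(set(labels))  # code-point order
label2id = {label: i for i, label in enumerate(unique_sorted)}
id2label = {i: label for label, i in label2id.items()}
print(label2id)  # {'体育-足球': 0, '科技-数码': 1}
```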
  65 +
  66 +### Prediction
  67 +
  68 +Single text:
  69 +```
  70 +python predict.py --text "这条微博讨论的是哪个话题?" --model_root ./model --finetuned_subdir bert-chinese-classifier
  71 +```
  72 +
  73 +Interactive:
  74 +```
  75 +python predict.py --interactive --model_root ./model --finetuned_subdir bert-chinese-classifier
  76 +```
  77 +
  78 +Example output (i.e. "Predicted: 体育-足球, confidence 0.9412"):
  79 +```
  80 +预测结果: 体育-足球 (置信度: 0.9412)
  81 +```
  82 +
  83 +### Notes
  84 +
  85 +- Both training and prediction apply a simple built-in Chinese text-cleaning step.
  86 +- The label set is taken from the training data; the script generates and saves `label_map.json` automatically.
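A condensed sketch of that cleaning step (the emoji-stripping rule from the scripts is omitted here for brevity): drop `{%...%}` template blocks, @-mentions, and `【...】` tags, then collapse whitespace.

```python
import re

def preprocess_text(text: str) -> str:
    # Subset of the cleaning rules used by train.py / predict.py.
    text = re.sub(r"\{%.+?%\}", " ", str(text))   # template blocks
    text = re.sub(r"@.+?( |$)", " ", text)        # @-mentions
    text = re.sub(r"【.+?】", " ", text)           # 【...】 topic tags
    text = re.sub(r"\u200b", " ", text)           # zero-width spaces
    text = re.sub(r"\s+", " ", text)              # collapse whitespace
    return text.strip()

print(preprocess_text("【话题】@小明 这场比赛真精彩!"))  # -> 这场比赛真精彩!
```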
  87 +
  88 +### Training Strategy (Brief)
  89 +
  90 +- Base model: `google-bert/bert-base-chinese`; the classification head's dimension equals the number of unique training labels.
  91 +- Learning rate and regularization: `lr=2e-5`, `weight_decay=0.01`; on large datasets, tune within `1e-5~3e-5`.
  92 +- Sequence length and batch size: `max_length=128`, `batch_size=16`; raise `max_length` to 256 if truncation is severe (at higher cost).
  93 +- Warmup: `warmup_ratio=0.1` when the installed transformers supports it; otherwise fall back to `warmup_steps=0`.
  94 +- Evaluation/saving: step counts derived from `--eval_fraction` (default 0.25); `save_total_limit=5` caps disk usage.
  95 +- Early stopping: monitors weighted F1 (higher is better); default patience 5, improvement threshold 0.0.
  96 +- Stable single-GPU runs: only one GPU is used by default, selectable via `--gpu`; the script also clears distributed-launch environment variables.
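Under those defaults, the eval/save cadence works out as in this sketch (dataset size is hypothetical):

```python
import math

# Mirrors the step computation in train.py.
num_train_examples = 100_000   # hypothetical dataset size
batch_size = 16
eval_fraction = 0.25           # evaluate every quarter epoch

steps_per_epoch = max(1, math.ceil(num_train_examples / max(1, batch_size)))
eval_every_steps = max(1, math.ceil(steps_per_epoch * max(0.01, min(1.0, eval_fraction))))
print(steps_per_epoch, eval_every_steps)  # 6250 1563
```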
  97 +
  98 +
  1 +{
  2 + "architectures": [
  3 + "BertModel"
  4 + ],
  5 + "attention_probs_dropout_prob": 0.1,
  6 + "classifier_dropout": null,
  7 + "directionality": "bidi",
  8 + "hidden_act": "gelu",
  9 + "hidden_dropout_prob": 0.1,
  10 + "hidden_size": 768,
  11 + "initializer_range": 0.02,
  12 + "intermediate_size": 3072,
  13 + "layer_norm_eps": 1e-12,
  14 + "max_position_embeddings": 512,
  15 + "model_type": "bert",
  16 + "num_attention_heads": 12,
  17 + "num_hidden_layers": 12,
  18 + "pad_token_id": 0,
  19 + "pooler_fc_size": 768,
  20 + "pooler_num_attention_heads": 12,
  21 + "pooler_num_fc_layers": 3,
  22 + "pooler_size_per_head": 128,
  23 + "pooler_type": "first_token_transform",
  24 + "position_embedding_type": "absolute",
  25 + "torch_dtype": "float32",
  26 + "transformers_version": "4.51.3",
  27 + "type_vocab_size": 2,
  28 + "use_cache": true,
  29 + "vocab_size": 21128
  30 +}
  1 +{
  2 + "cls_token": "[CLS]",
  3 + "mask_token": "[MASK]",
  4 + "pad_token": "[PAD]",
  5 + "sep_token": "[SEP]",
  6 + "unk_token": "[UNK]"
  7 +}
(diff omitted by the viewer: file too large to display)
  1 +{
  2 + "added_tokens_decoder": {
  3 + "0": {
  4 + "content": "[PAD]",
  5 + "lstrip": false,
  6 + "normalized": false,
  7 + "rstrip": false,
  8 + "single_word": false,
  9 + "special": true
  10 + },
  11 + "100": {
  12 + "content": "[UNK]",
  13 + "lstrip": false,
  14 + "normalized": false,
  15 + "rstrip": false,
  16 + "single_word": false,
  17 + "special": true
  18 + },
  19 + "101": {
  20 + "content": "[CLS]",
  21 + "lstrip": false,
  22 + "normalized": false,
  23 + "rstrip": false,
  24 + "single_word": false,
  25 + "special": true
  26 + },
  27 + "102": {
  28 + "content": "[SEP]",
  29 + "lstrip": false,
  30 + "normalized": false,
  31 + "rstrip": false,
  32 + "single_word": false,
  33 + "special": true
  34 + },
  35 + "103": {
  36 + "content": "[MASK]",
  37 + "lstrip": false,
  38 + "normalized": false,
  39 + "rstrip": false,
  40 + "single_word": false,
  41 + "special": true
  42 + }
  43 + },
  44 + "clean_up_tokenization_spaces": false,
  45 + "cls_token": "[CLS]",
  46 + "do_lower_case": false,
  47 + "extra_special_tokens": {},
  48 + "mask_token": "[MASK]",
  49 + "model_max_length": 512,
  50 + "pad_token": "[PAD]",
  51 + "sep_token": "[SEP]",
  52 + "strip_accents": null,
  53 + "tokenize_chinese_chars": true,
  54 + "tokenizer_class": "BertTokenizer",
  55 + "unk_token": "[UNK]"
  56 +}
(diff omitted by the viewer: file too large to display)
  1 +import os
  2 +import sys
  3 +import json
  4 +import re
  5 +import argparse
  6 +from typing import Dict, Tuple
  7 +
  8 +# ========== Pin to a single GPU (must run before importing torch/transformers) ==========
  9 +def _extract_gpu_arg(argv, default: str = "0") -> str:
  10 + for i, arg in enumerate(argv):
  11 + if arg.startswith("--gpu="):
  12 + return arg.split("=", 1)[1]
  13 + if arg == "--gpu" and i + 1 < len(argv):
  14 + return argv[i + 1]
  15 + return default
  16 +
  17 +env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
  18 +try:
  19 + gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
  20 +except Exception:
  21 + gpu_to_use = "0"
  22 +if (not env_vis) or ("," in env_vis):
  23 + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
  24 +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
  25 +
  26 +for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
  27 + os.environ.pop(_k, None)
  28 +
  29 +import torch
  30 +from transformers import (
  31 + AutoTokenizer,
  32 + AutoModel,
  33 + AutoModelForSequenceClassification,
  34 +)
  35 +
  36 +
  37 +def preprocess_text(text: str) -> str:
  38 + if text is None:
  39 + return ""
  40 + text = str(text)
  41 + text = re.sub(r"\{%.+?%\}", " ", text)
  42 + text = re.sub(r"@.+?( |$)", " ", text)
  43 + text = re.sub(r"【.+?】", " ", text)
  44 + text = re.sub(r"\u200b", " ", text)
  45 + text = re.sub(
  46 + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
  47 + "",
  48 + text,
  49 + )
  50 + text = re.sub(r"\s+", " ", text)
  51 + return text.strip()
  52 +
  53 +
  54 +def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
  55 + os.makedirs(local_model_root, exist_ok=True)
  56 + base_dir = os.path.join(local_model_root, "bert-base-chinese")
  57 +
  58 + def is_ready(path: str) -> bool:
  59 + return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
  60 +
  61 + if is_ready(base_dir):
  62 + tokenizer = AutoTokenizer.from_pretrained(base_dir)
  63 + return base_dir, tokenizer
  64 +
  65 + # Try the local HF cache
  66 + try:
  67 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
  68 + base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
  69 + os.makedirs(base_dir, exist_ok=True)
  70 + tokenizer.save_pretrained(base_dir)
  71 + base.save_pretrained(base_dir)
  72 + return base_dir, tokenizer
  73 + except Exception:
  74 + pass
  75 +
  76 + # Fall back to remote download
  77 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
  78 + base = AutoModel.from_pretrained(model_name_or_path)
  79 + os.makedirs(base_dir, exist_ok=True)
  80 + tokenizer.save_pretrained(base_dir)
  81 + base.save_pretrained(base_dir)
  82 + return base_dir, tokenizer
  83 +
  84 +
  85 +def parse_args() -> argparse.Namespace:
  86 + parser = argparse.ArgumentParser(description="使用本地/缓存/远程加载的中文 BERT 分类模型进行预测")
  87 + parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
  88 + parser.add_argument("--finetuned_subdir", type=str, default="bert-chinese-classifier", help="微调结果子目录")
  89 + parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese", help="预训练模型名称或路径")
  90 + parser.add_argument("--text", type=str, default=None, help="直接输入一条要预测的文本")
  91 + parser.add_argument("--interactive", action="store_true", help="进入交互式预测模式")
  92 + parser.add_argument("--max_length", type=int, default=128)
  93 + parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
  94 + return parser.parse_args()
  95 +
  96 +
  97 +def load_finetuned(model_root: str, subdir: str) -> Tuple[str, Dict[int, str]]:
  98 + finetuned_path = os.path.join(model_root, subdir)
  99 + if not os.path.isdir(finetuned_path):
  100 + raise FileNotFoundError(
  101 + f"未找到微调模型目录: {finetuned_path},请先运行训练脚本。"
  102 + )
  103 + label_map_path = os.path.join(finetuned_path, "label_map.json")
  104 + id2label = None
  105 + if os.path.isfile(label_map_path):
  106 + with open(label_map_path, "r", encoding="utf-8") as f:
  107 + data = json.load(f)
  108 + id2label = {int(k): str(v) for k, v in data.get("id2label", {}).items()}
  109 + return finetuned_path, id2label
  110 +
  111 +
  112 +def predict_once(model_dir: str, text: str, max_length: int = 128) -> Tuple[str, float]:
  113 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  114 + tokenizer = AutoTokenizer.from_pretrained(model_dir)
  115 + model = AutoModelForSequenceClassification.from_pretrained(model_dir)
  116 + model.to(device)
  117 + model.eval()
  118 +
  119 + processed = preprocess_text(text)
  120 + encoded = tokenizer(
  121 + processed,
  122 + max_length=max_length,
  123 + truncation=True,
  124 + padding="max_length",
  125 + return_tensors="pt",
  126 + )
  127 + input_ids = encoded["input_ids"].to(device)
  128 + attention_mask = encoded["attention_mask"].to(device)
  129 +
  130 + with torch.no_grad():
  131 + outputs = model(input_ids=input_ids, attention_mask=attention_mask)
  132 + logits = outputs.logits
  133 + probs = torch.softmax(logits, dim=-1)
  134 + pred = int(torch.argmax(probs, dim=-1).item())
  135 + conf = float(probs[0, pred].item())
  136 + id2label = getattr(model.config, "id2label", None)
  137 + label_name = id2label.get(pred, str(pred)) if isinstance(id2label, dict) else str(pred)
  138 + return label_name, conf
  139 +
  140 +
  141 +def main() -> None:
  142 + args = parse_args()
  143 +
  144 + script_dir = os.path.dirname(os.path.abspath(__file__))
  145 + model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
  146 + os.makedirs(model_root, exist_ok=True)
  147 +
  148 + # Ensure the base model exists locally
  149 + ensure_base_model_local(args.pretrained_name, model_root)
  150 +
  151 + finetuned_dir, _ = load_finetuned(model_root, args.finetuned_subdir)
  152 +
  153 + if args.text is not None:
  154 + label, conf = predict_once(finetuned_dir, args.text, args.max_length)
  155 + print(f"预测结果: {label} (置信度: {conf:.4f})")
  156 + return
  157 +
  158 + if args.interactive:
  159 + print("进入交互模式。输入 'q' 退出。")
  160 + while True:
  161 + try:
  162 + text = input("请输入文本: ").strip()
  163 + except EOFError:
  164 + break
  165 + if text.lower() == "q":
  166 + break
  167 + if not text:
  168 + continue
  169 + label, conf = predict_once(finetuned_dir, text, args.max_length)
  170 + print(f"预测结果: {label} (置信度: {conf:.4f})")
  171 + return
  172 +
  173 + print("未提供 --text 或 --interactive,什么也没有发生。")
  174 +
  175 +
  176 +if __name__ == "__main__":
  177 + main()
  178 +
  179 +
  1 +import os
  2 +import sys
  3 +import json
  4 +import re
  5 +import argparse
  6 +import math
  7 +import inspect
  8 +from typing import Dict, List, Optional, Tuple
  9 +
  10 +# ========== Pin to a single GPU (must run before importing torch/transformers) ==========
  11 +def _extract_gpu_arg(argv: List[str], default: str = "0") -> str:
  12 + for i, arg in enumerate(argv):
  13 + if arg.startswith("--gpu="):
  14 + return arg.split("=", 1)[1]
  15 + if arg == "--gpu" and i + 1 < len(argv):
  16 + return argv[i + 1]
  17 + return default
  18 +
  19 +env_vis = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
  20 +try:
  21 + gpu_to_use = _extract_gpu_arg(sys.argv, default="0")
  22 +except Exception:
  23 + gpu_to_use = "0"
  24 +# If unset, or multiple GPUs are exposed, force a single visible GPU (default 0) so direct runs stay stable
  25 +if (not env_vis) or ("," in env_vis):
  26 + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_to_use
  27 +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
  28 +
  29 +# Remove distributed-training env vars possibly injected by external launchers, to avoid accidental multi-GPU/distributed mode
  30 +for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
  31 + os.environ.pop(_k, None)
  32 +
  33 +import numpy as np
  34 +import torch
  35 +from torch.utils.data import Dataset
  36 +from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
  37 +import pandas as pd
  38 +
  39 +from transformers import (
  40 + AutoTokenizer,
  41 + AutoModel,
  42 + AutoModelForSequenceClassification,
  43 + AutoConfig,
  44 + DataCollatorWithPadding,
  45 + Trainer,
  46 + TrainingArguments,
  47 + set_seed,
  48 +)
  49 +try:
  50 + from transformers import EarlyStoppingCallback # type: ignore
  51 +except Exception: # pragma: no cover
  52 + EarlyStoppingCallback = None # type: ignore
  53 +
  54 +
  55 +def preprocess_text(text: str) -> str:
  56 + if text is None:
  57 + return ""
  58 + text = str(text)
  59 + text = re.sub(r"\{%.+?%\}", " ", text)
  60 + text = re.sub(r"@.+?( |$)", " ", text)
  61 + text = re.sub(r"【.+?】", " ", text)
  62 + text = re.sub(r"\u200b", " ", text)
  63 + text = re.sub(
  64 + r"[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001f900-\U0001f9ff\U0001f018-\U0001f270\U0000231a-\U0000231b\U0000238d-\U0000238d\U000024c2-\U0001f251]+",
  65 + "",
  66 + text,
  67 + )
  68 + text = re.sub(r"\s+", " ", text)
  69 + return text.strip()
  70 +
  71 +
  72 +def ensure_base_model_local(model_name_or_path: str, local_model_root: str) -> Tuple[str, AutoTokenizer]:
  73 + os.makedirs(local_model_root, exist_ok=True)
  74 + base_dir = os.path.join(local_model_root, "bert-base-chinese")
  75 +
  76 + def is_ready(path: str) -> bool:
  77 + return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
  78 +
  79 + # 1) Already present locally
  80 + if is_ready(base_dir):
  81 + tokenizer = AutoTokenizer.from_pretrained(base_dir)
  82 + return base_dir, tokenizer
  83 +
  84 + # 2) Local HF cache
  85 + try:
  86 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, local_files_only=True)
  87 + base = AutoModel.from_pretrained(model_name_or_path, local_files_only=True)
  88 + os.makedirs(base_dir, exist_ok=True)
  89 + tokenizer.save_pretrained(base_dir)
  90 + base.save_pretrained(base_dir)
  91 + return base_dir, tokenizer
  92 + except Exception:
  93 + pass
  94 +
  95 + # 3) Remote download
  96 + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
  97 + base = AutoModel.from_pretrained(model_name_or_path)
  98 + os.makedirs(base_dir, exist_ok=True)
  99 + tokenizer.save_pretrained(base_dir)
  100 + base.save_pretrained(base_dir)
  101 + return base_dir, tokenizer
  102 +
  103 +
  104 +class TextClassificationDataset(Dataset):
  105 + def __init__(
  106 + self,
  107 + dataframe: pd.DataFrame,
  108 + tokenizer: AutoTokenizer,
  109 + text_column: str,
  110 + label_column: str,
  111 + label2id: Dict[str, int],
  112 + max_length: int,
  113 + ) -> None:
  114 + self.dataframe = dataframe.reset_index(drop=True)
  115 + self.tokenizer = tokenizer
  116 + self.text_column = text_column
  117 + self.label_column = label_column
  118 + self.label2id = label2id
  119 + self.max_length = max_length
  120 +
  121 + def __len__(self) -> int:
  122 + return len(self.dataframe)
  123 +
  124 + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
  125 + row = self.dataframe.iloc[idx]
  126 + text = preprocess_text(row[self.text_column])
  127 + encoding = self.tokenizer(
  128 + text,
  129 + max_length=self.max_length,
  130 + truncation=True,
  131 + padding=False,
  132 + return_tensors="pt",
  133 + )
  134 + item = {k: v.squeeze(0) for k, v in encoding.items()}
  135 + if self.label_column in row and pd.notna(row[self.label_column]):
  136 + label_str = str(row[self.label_column])
  137 + item["labels"] = torch.tensor(self.label2id[label_str], dtype=torch.long)
  138 + return item
  139 +
  140 +
  141 +def build_label_mappings(train_df: pd.DataFrame, label_column: str) -> Tuple[Dict[str, int], Dict[int, str]]:
  142 + labels: List[str] = [str(x) for x in train_df[label_column].dropna().astype(str).tolist()]
  143 + unique_sorted = sorted(set(labels))
  144 + label2id = {label: i for i, label in enumerate(unique_sorted)}
  145 + id2label = {i: label for label, i in label2id.items()}
  146 + return label2id, id2label
  147 +
  148 +
  149 +def compute_metrics_fn(eval_pred) -> Dict[str, float]:
  150 + logits, labels = eval_pred
  151 + preds = np.argmax(logits, axis=-1)
  152 + precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted", zero_division=0)
  153 + acc = accuracy_score(labels, preds)
  154 + return {
  155 + "accuracy": float(acc),
  156 + "precision": float(precision),
  157 + "recall": float(recall),
  158 + "f1": float(f1),
  159 + }
  160 +
  161 +
  162 +def autodetect_columns(df: pd.DataFrame, text_col: str, label_col: str) -> Tuple[str, str]:
  163 + if text_col != "auto" and label_col != "auto":
  164 + return text_col, label_col
  165 + candidates_text = ["text", "content", "sentence", "title", "desc", "question"]
  166 + candidates_label = ["label", "labels", "category", "topic", "class"]
  167 + t = text_col
  168 + l = label_col
  169 + if text_col == "auto":
  170 + for name in candidates_text:
  171 + if name in df.columns:
  172 + t = name
  173 + break
  174 + if label_col == "auto":
  175 + for name in candidates_label:
  176 + if name in df.columns:
  177 + l = name
  178 + break
  179 + if t == "auto" or l == "auto":
  180 + raise ValueError(
  181 + f"无法自动识别列名,请显式传入 --text_col 与 --label_col。现有列: {list(df.columns)}"
  182 + )
  183 + return t, l
  184 +
  185 +
  186 +def parse_args() -> argparse.Namespace:
  187 + parser = argparse.ArgumentParser(description="使用 google-bert/bert-base-chinese 在本目录数据集上进行文本分类微调")
  188 + parser.add_argument("--train_file", type=str, default="./dataset/web_text_zh_train.csv")
  189 + parser.add_argument("--valid_file", type=str, default="./dataset/web_text_zh_valid.csv")
  190 + parser.add_argument("--text_col", type=str, default="auto", help="文本列名,默认自动识别")
  191 + parser.add_argument("--label_col", type=str, default="auto", help="标签列名,默认自动识别")
  192 + parser.add_argument("--model_root", type=str, default="./model", help="本地模型根目录")
  193 + parser.add_argument("--pretrained_name", type=str, default="google-bert/bert-base-chinese")
  194 + parser.add_argument("--save_subdir", type=str, default="bert-chinese-classifier")
  195 + parser.add_argument("--max_length", type=int, default=128)
  196 + parser.add_argument("--batch_size", type=int, default=64)
  197 + parser.add_argument("--num_epochs", type=int, default=10)
  198 + parser.add_argument("--learning_rate", type=float, default=2e-5)
  199 + parser.add_argument("--weight_decay", type=float, default=0.01)
  200 + parser.add_argument("--warmup_ratio", type=float, default=0.1)
  201 + parser.add_argument("--seed", type=int, default=42)
  202 + parser.add_argument("--fp16", action="store_true")
  203 + parser.add_argument("--gpu", type=str, default=os.environ.get("CUDA_VISIBLE_DEVICES", "0"), help="指定单卡 GPU,如 0 或 1")
  204 + parser.add_argument("--eval_fraction", type=float, default=0.25, help="每多少个 epoch 做一次评估与保存,例如 0.25 表示每四分之一个 epoch")
  205 + parser.add_argument("--early_stop_patience", type=int, default=5, help="早停耐心(以评估轮次计)")
  206 + parser.add_argument("--early_stop_threshold", type=float, default=0.0, help="早停最小改善阈值(与 metric_for_best_model 同单位)")
  207 + return parser.parse_args()
  208 +
  209 +
  210 +def main() -> None:
  211 + args = parse_args()
  212 + set_seed(args.seed)
  213 +
  214 + script_dir = os.path.dirname(os.path.abspath(__file__))
  215 + model_root = args.model_root if os.path.isabs(args.model_root) else os.path.join(script_dir, args.model_root)
  216 + os.makedirs(model_root, exist_ok=True)
  217 +
  218 + # Ensure the base model is ready locally
  219 + base_dir, tokenizer = ensure_base_model_local(args.pretrained_name, model_root)
  220 + print(f"[Info] 使用基础模型目录: {base_dir}")
  221 +
  222 + # Load data
  223 + train_path = args.train_file if os.path.isabs(args.train_file) else os.path.join(script_dir, args.train_file)
  224 + valid_path = args.valid_file if os.path.isabs(args.valid_file) else os.path.join(script_dir, args.valid_file)
  225 + if not os.path.isfile(train_path):
  226 + raise FileNotFoundError(f"训练集不存在: {train_path}")
  227 + train_df = pd.read_csv(train_path)
  228 + if not os.path.isfile(valid_path):
  229 + # If no validation file exists, carve one out of the training set (90/10 split)
  230 + shuffled = train_df.sample(frac=1.0, random_state=args.seed).reset_index(drop=True)
  231 + split_idx = int(len(shuffled) * 0.9)
  232 + valid_df = shuffled.iloc[split_idx:].reset_index(drop=True)
  233 + train_df = shuffled.iloc[:split_idx].reset_index(drop=True)
  234 + else:
  235 + valid_df = pd.read_csv(valid_path)
  236 + print(f"[Info] 训练集: {train_path} | 样本数: {len(train_df)}")
  237 + print(f"[Info] 验证集: {valid_path if os.path.isfile(valid_path) else '(从训练集切分)'} | 样本数: {len(valid_df)}")
  238 +
  239 + # Auto-detect column names
  240 + text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
  241 + print(f"[Info] 文本列: {text_col} | 标签列: {label_col}")
  242 +
  243 + # Build label mappings
  244 + label2id, id2label = build_label_mappings(train_df, label_col)
  245 + if len(label2id) < 2:
  246 + raise ValueError("标签类别数少于 2,无法训练分类模型。")
  247 + print(f"[Info] 标签类别数: {len(label2id)}")
  248 +
  249 + # Datasets
  250 + train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)
  251 + eval_dataset = TextClassificationDataset(valid_df, tokenizer, text_col, label_col, label2id, args.max_length)
  252 + collator = DataCollatorWithPadding(tokenizer=tokenizer)
  253 +
  254 + # Model
  255 + config = AutoConfig.from_pretrained(
  256 + base_dir,
  257 + num_labels=len(label2id),
  258 + id2label={int(i): str(l) for i, l in id2label.items()},
  259 + label2id={str(l): int(i) for l, i in label2id.items()},
  260 + )
  261 + model = AutoModelForSequenceClassification.from_pretrained(
  262 + base_dir,
  263 + config=config,
  264 + ignore_mismatched_sizes=True,
  265 + )
  266 +
  267 + # Output directory for the fine-tuned model
  268 + output_dir = os.path.join(model_root, args.save_subdir)
  269 + os.makedirs(output_dir, exist_ok=True)
  270 + # Training arguments (kept compatible across transformers versions)
  271 + args_dict = {
  272 + "output_dir": output_dir,
  273 + "per_device_train_batch_size": args.batch_size,
  274 + "per_device_eval_batch_size": args.batch_size,
  275 + "learning_rate": args.learning_rate,
  276 + "weight_decay": args.weight_decay,
  277 + "num_train_epochs": args.num_epochs,
  278 + "logging_steps": 100,
  279 + "fp16": args.fp16,
  280 + "seed": args.seed,
  281 + }
  282 +
  283 + sig = inspect.signature(TrainingArguments.__init__)
  284 + allowed = set(sig.parameters.keys())
  285 +
  286 + # Optional arguments (added only when supported; kept minimal, close to the reference implementation, for compatibility)
  287 + if "warmup_ratio" in allowed:
  288 + args_dict["warmup_ratio"] = args.warmup_ratio
  289 + if "report_to" in allowed:
  290 + args_dict["report_to"] = []
  291 + # Eval/save cadence: convert eval_fraction of an epoch into a step count
  292 + steps_per_epoch = max(1, math.ceil(len(train_dataset) / max(1, args.batch_size)))
  293 + eval_every_steps = max(1, math.ceil(steps_per_epoch * max(0.01, min(1.0, args.eval_fraction))))
  294 + # Strategy-style fields (field name differs across new/old versions)
  295 + key_eval = "evaluation_strategy" if "evaluation_strategy" in allowed else ("eval_strategy" if "eval_strategy" in allowed else None)
  296 + if key_eval:
  297 + args_dict[key_eval] = "steps"
  298 + if "save_strategy" in allowed:
  299 + args_dict["save_strategy"] = "steps"
  300 + if "eval_steps" in allowed:
  301 + args_dict["eval_steps"] = eval_every_steps
  302 + if "save_steps" in allowed:
  303 + args_dict["save_steps"] = eval_every_steps
  304 + if "save_total_limit" in allowed:
  305 + args_dict["save_total_limit"] = 5
  306 + # Align logging steps with eval/save steps to cut log noise
  307 + if "logging_steps" in allowed:
  308 + args_dict["logging_steps"] = eval_every_steps
  309 + # Restore the best model at the end (only when eval and save strategies match)
  310 + if "metric_for_best_model" in allowed:
  311 + args_dict["metric_for_best_model"] = "f1"
  312 + if "greater_is_better" in allowed:
  313 + args_dict["greater_is_better"] = True
  314 + if "load_best_model_at_end" in allowed:
  315 + eval_strat = args_dict.get("evaluation_strategy", args_dict.get("eval_strategy"))
  316 + save_strat = args_dict.get("save_strategy")
  317 + if eval_strat == save_strat and eval_strat in ("steps", "epoch"):
  318 + args_dict["load_best_model_at_end"] = True
  319 +
  320 + # For versions without warmup_ratio: if warmup_steps is supported, skip the ratio
  321 + if "warmup_ratio" not in allowed and "warmup_steps" in allowed:
  322 + # Don't compute total steps; default to 0 warmup steps
  323 + args_dict["warmup_steps"] = 0
  324 +
  325 + # If strategy-style args are unsupported: fall back to eval/save every eval_every_steps steps
  326 + if "save_strategy" not in allowed and "save_steps" in allowed:
  327 + args_dict["save_steps"] = eval_every_steps
  328 + if ("evaluation_strategy" not in allowed and "eval_strategy" not in allowed) and "eval_steps" in allowed:
  329 + args_dict["eval_steps"] = eval_every_steps
  330 +
  331 + # If load_best_model_at_end is supported but the eval/save strategies can't be aligned, disable it to avoid errors
  332 + if "load_best_model_at_end" in allowed:
  333 + want_load_best = args_dict.get("load_best_model_at_end", False)
  334 + eval_set = args_dict.get("evaluation_strategy", None)
  335 + save_set = args_dict.get("save_strategy", None)
  336 + if want_load_best and (eval_set is None or save_set is None or eval_set != save_set):
  337 + args_dict["load_best_model_at_end"] = False
  338 +
  339 + training_args = TrainingArguments(**args_dict)
  340 + print("[Info] 训练参数要点:")
  341 + print(f" epochs={args.num_epochs}, batch_size={args.batch_size}, lr={args.learning_rate}, weight_decay={args.weight_decay}")
  342 + print(f" max_length={args.max_length}, seed={args.seed}, fp16={args.fp16}")
  343 + if "warmup_ratio" in allowed and "warmup_ratio" in args_dict:
  344 + print(f" warmup_ratio={args_dict['warmup_ratio']}")
  345 + elif "warmup_steps" in allowed and "warmup_steps" in args_dict:
  346 + print(f" warmup_steps={args_dict['warmup_steps']}")
  347 + print(f" steps_per_epoch={steps_per_epoch}, eval_every_steps={eval_every_steps}")
  348 + print(f" eval_strategy={args_dict.get('evaluation_strategy', args_dict.get('eval_strategy'))}, save_strategy={args_dict.get('save_strategy')}, logging_steps={args_dict.get('logging_steps')}")
  349 + print(f" save_total_limit={args_dict.get('save_total_limit', 'n/a')}, load_best_model_at_end={args_dict.get('load_best_model_at_end', False)}")
  350 +
  351 + callbacks = []
  352 + # Recent transformers asserts that EarlyStoppingCallback is used together with
  353 + # load_best_model_at_end=True, so only attach it when best-model tracking is enabled.
  354 + if EarlyStoppingCallback is not None and args_dict.get("load_best_model_at_end", False):
  353 + try:
  354 + callbacks.append(
  355 + EarlyStoppingCallback(
  356 + early_stopping_patience=args.early_stop_patience,
  357 + early_stopping_threshold=args.early_stop_threshold,
  358 + )
  359 + )
  360 + except Exception:
  361 + pass
  362 +
  363 + trainer = Trainer(
  364 + model=model,
  365 + args=training_args,
  366 + train_dataset=train_dataset,
  367 + eval_dataset=eval_dataset,
  368 + tokenizer=tokenizer,
  369 + data_collator=collator,
  370 + compute_metrics=compute_metrics_fn,
  371 + callbacks=callbacks,
  372 + )
  373 + # Device / GPU info
  374 + try:
  375 + device_cnt = torch.cuda.device_count()
  376 + dev_name = torch.cuda.get_device_name(0) if device_cnt > 0 else "cpu"
  377 + print(f"[Info] CUDA 可见设备数: {device_cnt}, 当前设备: {dev_name}, CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")
  378 + except Exception:
  379 + pass
  380 +
  381 + print("[Info] 开始训练 ...")
  382 +
  383 + trainer.train()
  384 +
  385 + # Save tokenizer, weights, and label maps
  386 + tokenizer.save_pretrained(output_dir)
  387 + trainer.model.config.id2label = {int(i): str(l) for i, l in id2label.items()}
  388 + trainer.model.config.label2id = {str(l): int(i) for l, i in label2id.items()}
  389 + trainer.save_model(output_dir)
  390 + try:
  391 + best_metric = getattr(trainer.state, "best_metric", None)
  392 + best_ckpt = getattr(trainer.state, "best_model_checkpoint", None)
  393 + if best_metric is not None and best_ckpt is not None:
  394 + print(f"[Info] 最优模型: metric={best_metric:.6f} | checkpoint={best_ckpt}")
  395 + except Exception:
  396 + pass
  397 +
  398 + with open(os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8") as f:
  399 + json.dump(
  400 + {"label2id": trainer.model.config.label2id, "id2label": trainer.model.config.id2label},
  401 + f,
  402 + ensure_ascii=False,
  403 + indent=2,
  404 + )
  405 +
  406 + # Training curves: optionally plot train/eval loss
  407 + try:
  408 + import matplotlib.pyplot as plt # type: ignore
  409 + logs = trainer.state.log_history
  410 + t_steps, t_losses, e_steps, e_losses = [], [], [], []
  411 + step_counter = 0
  412 + for rec in logs:
  413 + if "loss" in rec and "epoch" in rec:
  414 + step_counter += 1
  415 + t_steps.append(step_counter)
  416 + t_losses.append(rec["loss"])
  417 + if "eval_loss" in rec:
  418 + e_steps.append(step_counter)
  419 + e_losses.append(rec["eval_loss"])
  420 + if t_losses or e_losses:
  421 + plt.figure(figsize=(8,4))
  422 + if t_losses:
  423 + plt.plot(t_steps, t_losses, label="train_loss")
  424 + if e_losses:
  425 + plt.plot(e_steps, e_losses, label="eval_loss")
  426 + plt.xlabel("training step (logged)")
  427 + plt.ylabel("loss")
  428 + plt.legend()
  429 + plt.tight_layout()
  430 + plt.savefig(os.path.join(output_dir, "training_curve.png"))
  431 + except Exception:
  432 + pass
  433 +
  434 + print(f"微调完成,模型已保存到: {output_dir}")
  435 +
  436 +
  437 +if __name__ == "__main__":
  438 + main()
  439 +
  440 +