戒酒的李白

Fix label mapping and boost training with batch_size=64.

@@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") @@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
29 # 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式 29 # 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式
30 for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]: 30 for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
31 os.environ.pop(_k, None) 31 os.environ.pop(_k, None)
  32 +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
32 33
33 import numpy as np 34 import numpy as np
34 import torch 35 import torch
@@ -240,11 +241,22 @@ def main() -> None: @@ -240,11 +241,22 @@ def main() -> None:
240 text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col) 241 text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
241 print(f"[Info] 文本列: {text_col} | 标签列: {label_col}") 242 print(f"[Info] 文本列: {text_col} | 标签列: {label_col}")
242 243
243 - # 标签映射  
244 - label2id, id2label = build_label_mappings(train_df, label_col) 244 + # 标签映射(使用 训练集∪验证集 的并集,避免验证集中出现新标签导致报错)
  245 + combined_labels_df = pd.concat([train_df[[label_col]], valid_df[[label_col]]], ignore_index=True)
  246 + label2id, id2label = build_label_mappings(combined_labels_df, label_col)
245 if len(label2id) < 2: 247 if len(label2id) < 2:
246 raise ValueError("标签类别数少于 2,无法训练分类模型。") 248 raise ValueError("标签类别数少于 2,无法训练分类模型。")
247 print(f"[Info] 标签类别数: {len(label2id)}") 249 print(f"[Info] 标签类别数: {len(label2id)}")
  250 + # 提示验证集中未出现在训练集的标签数量
  251 + try:
  252 + train_label_set = set(str(x) for x in train_df[label_col].dropna().astype(str).tolist())
  253 + valid_label_set = set(str(x) for x in valid_df[label_col].dropna().astype(str).tolist())
  254 + unseen_in_train = sorted(valid_label_set - train_label_set)
  255 + if unseen_in_train:
  256 + preview = ", ".join(unseen_in_train[:10])
  257 + print(f"[Warn] 验证集中存在 {len(unseen_in_train)} 个训练未出现的标签(已纳入映射以避免报错)。示例: {preview} ...")
  258 + except Exception:
  259 + pass
248 260
249 # 数据集 261 # 数据集
250 train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length) 262 train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)