戒酒的李白

Fix label mapping and boost training with batch_size=64.

... ... @@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
# 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式
for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]:
os.environ.pop(_k, None)
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import numpy as np
import torch
... ... @@ -240,11 +241,22 @@ def main() -> None:
text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col)
print(f"[Info] 文本列: {text_col} | 标签列: {label_col}")
# 标签映射
label2id, id2label = build_label_mappings(train_df, label_col)
# 标签映射(使用 训练集∪验证集 的并集,避免验证集中出现新标签导致报错)
combined_labels_df = pd.concat([train_df[[label_col]], valid_df[[label_col]]], ignore_index=True)
label2id, id2label = build_label_mappings(combined_labels_df, label_col)
if len(label2id) < 2:
raise ValueError("标签类别数少于 2,无法训练分类模型。")
print(f"[Info] 标签类别数: {len(label2id)}")
# 提示验证集中未出现在训练集的标签数量
try:
train_label_set = set(str(x) for x in train_df[label_col].dropna().astype(str).tolist())
valid_label_set = set(str(x) for x in valid_df[label_col].dropna().astype(str).tolist())
unseen_in_train = sorted(valid_label_set - train_label_set)
if unseen_in_train:
preview = ", ".join(unseen_in_train[:10])
print(f"[Warn] 验证集中存在 {len(unseen_in_train)} 个训练未出现的标签(已纳入映射以避免报错)。示例: {preview} ...")
except Exception:
pass
# 数据集
train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length)
... ...