Showing 1 changed file with 14 additions and 2 deletions
| @@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") | @@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") | ||
| 29 | # 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式 | 29 | # 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式 |
| 30 | for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]: | 30 | for _k in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]: |
| 31 | os.environ.pop(_k, None) | 31 | os.environ.pop(_k, None) |
| 32 | +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") | ||
| 32 | 33 | ||
| 33 | import numpy as np | 34 | import numpy as np |
| 34 | import torch | 35 | import torch |
| @@ -240,11 +241,22 @@ def main() -> None: | @@ -240,11 +241,22 @@ def main() -> None: | ||
| 240 | text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col) | 241 | text_col, label_col = autodetect_columns(train_df, args.text_col, args.label_col) |
| 241 | print(f"[Info] 文本列: {text_col} | 标签列: {label_col}") | 242 | print(f"[Info] 文本列: {text_col} | 标签列: {label_col}") |
| 242 | 243 | ||
| 243 | - # 标签映射 | ||
| 244 | - label2id, id2label = build_label_mappings(train_df, label_col) | 244 | + # 标签映射(使用 训练集∪验证集 的并集,避免验证集中出现新标签导致报错) |
| 245 | + combined_labels_df = pd.concat([train_df[[label_col]], valid_df[[label_col]]], ignore_index=True) | ||
| 246 | + label2id, id2label = build_label_mappings(combined_labels_df, label_col) | ||
| 245 | if len(label2id) < 2: | 247 | if len(label2id) < 2: |
| 246 | raise ValueError("标签类别数少于 2,无法训练分类模型。") | 248 | raise ValueError("标签类别数少于 2,无法训练分类模型。") |
| 247 | print(f"[Info] 标签类别数: {len(label2id)}") | 249 | print(f"[Info] 标签类别数: {len(label2id)}") |
| 250 | + # 提示验证集中未出现在训练集的标签数量 | ||
| 251 | + try: | ||
| 252 | + train_label_set = set(str(x) for x in train_df[label_col].dropna().astype(str).tolist()) | ||
| 253 | + valid_label_set = set(str(x) for x in valid_df[label_col].dropna().astype(str).tolist()) | ||
| 254 | + unseen_in_train = sorted(valid_label_set - train_label_set) | ||
| 255 | + if unseen_in_train: | ||
| 256 | + preview = ", ".join(unseen_in_train[:10]) | ||
| 257 | + print(f"[Warn] 验证集中存在 {len(unseen_in_train)} 个训练未出现的标签(已纳入映射以避免报错)。示例: {preview} ...") | ||
| 258 | + except Exception: | ||
| 259 | + pass | ||
| 248 | 260 | ||
| 249 | # 数据集 | 261 | # 数据集 |
| 250 | train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length) | 262 | train_dataset = TextClassificationDataset(train_df, tokenizer, text_col, label_col, label2id, args.max_length) |
-
Please register or login to post a comment