Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2025-08-09 19:42:09 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
d726941d9597580df7783a96ab6a6bf883ab0f39
d726941d
1 parent
4373f31d
Fix label mapping and boost training with batch_size=64.
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
2 deletions
BertTopicDetection_Finetuned/train.py
BertTopicDetection_Finetuned/train.py
View file @
d726941
...
...
@@ -29,6 +29,7 @@ os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
# 清理可能由外部启动器注入的分布式环境变量,避免误触多卡/分布式
for
_k
in
[
"RANK"
,
"LOCAL_RANK"
,
"WORLD_SIZE"
]:
os
.
environ
.
pop
(
_k
,
None
)
os
.
environ
.
setdefault
(
"TOKENIZERS_PARALLELISM"
,
"false"
)
import
numpy
as
np
import
torch
...
...
@@ -240,11 +241,22 @@ def main() -> None:
text_col
,
label_col
=
autodetect_columns
(
train_df
,
args
.
text_col
,
args
.
label_col
)
print
(
f
"[Info] 文本列: {text_col} | 标签列: {label_col}"
)
# 标签映射
label2id
,
id2label
=
build_label_mappings
(
train_df
,
label_col
)
# 标签映射(使用 训练集∪验证集 的并集,避免验证集中出现新标签导致报错)
combined_labels_df
=
pd
.
concat
([
train_df
[[
label_col
]],
valid_df
[[
label_col
]]],
ignore_index
=
True
)
label2id
,
id2label
=
build_label_mappings
(
combined_labels_df
,
label_col
)
if
len
(
label2id
)
<
2
:
raise
ValueError
(
"标签类别数少于 2,无法训练分类模型。"
)
print
(
f
"[Info] 标签类别数: {len(label2id)}"
)
# 提示验证集中未出现在训练集的标签数量
try
:
train_label_set
=
set
(
str
(
x
)
for
x
in
train_df
[
label_col
]
.
dropna
()
.
astype
(
str
)
.
tolist
())
valid_label_set
=
set
(
str
(
x
)
for
x
in
valid_df
[
label_col
]
.
dropna
()
.
astype
(
str
)
.
tolist
())
unseen_in_train
=
sorted
(
valid_label_set
-
train_label_set
)
if
unseen_in_train
:
preview
=
", "
.
join
(
unseen_in_train
[:
10
])
print
(
f
"[Warn] 验证集中存在 {len(unseen_in_train)} 个训练未出现的标签(已纳入映射以避免报错)。示例: {preview} ..."
)
except
Exception
:
pass
# 数据集
train_dataset
=
TextClassificationDataset
(
train_df
,
tokenizer
,
text_col
,
label_col
,
label2id
,
args
.
max_length
)
...
...
Please
register
or
login
to post a comment