Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2024-09-30 13:38:56 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
91a192e3790aed11c37ea1be78b089c8bd34013f
91a192e3
1 parent
f0f43c8e
Add BERT-CTM model with save/load and error handling
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
11 deletions
model_pro/BERT_CTM.py
model_pro/BERT_CTM.py
View file @
91a192e
...
...
@@ -8,41 +8,87 @@ from contextualized_topic_models.utils.data_preparation import TopicModelDataPre
from
contextualized_topic_models.models.ctm
import
CombinedTM
class BERT_CTM_Model:
    """Combine BERT contextual embeddings with a Combined Topic Model (CTM).

    A local BERT model supplies contextual sentence embeddings, while jieba
    segmentation builds the bag-of-words view; both feed a `CombinedTM`
    topic model that can be trained, saved and reloaded.
    """

    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12,
                 num_epochs=50, device=None):
        """Load BERT and prepare the CTM pipeline.

        Args:
            bert_model_path: directory containing the BERT model/tokenizer.
            ctm_tokenizer_path: directory for the CTM's sentence embedder,
                passed to `TopicModelDataPreparation`.
            n_components: number of topics for the CTM.
            num_epochs: CTM training epochs.
            device: explicit torch device string; auto-detects CUDA when None.

        Raises:
            ValueError: if either model path does not exist.
        """
        # `is not None` (not plain truthiness) so an explicitly passed device
        # string is always honored and never silently replaced by auto-detect.
        self.device = device if device is not None else (
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Fail fast on missing model directories instead of surfacing a less
        # obvious error from the HuggingFace loaders later on.
        if not os.path.exists(bert_model_path):
            raise ValueError(f"BERT模型路径不存在: {bert_model_path}")
        if not os.path.exists(ctm_tokenizer_path):
            raise ValueError(f"CTM分词器路径不存在: {ctm_tokenizer_path}")

        # Load the BERT tokenizer and model; move the model to the device.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path).to(self.device)

        # Helper that pairs the BOW vocabulary with contextual embeddings.
        self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
        self.n_components = n_components
        self.num_epochs = num_epochs
        self.ctm_model = None  # populated by train_ctm() or load_model()

    def get_bert_embeddings(self, texts):
        """Return stacked BERT hidden states for `texts`.

        Each text is tokenized/padded to a fixed length of 80 tokens and run
        through BERT without gradients; the per-text `last_hidden_state`
        arrays are stacked into shape (len(texts), 80, hidden_size).
        """
        embeddings = []
        for text in tqdm(texts, desc="Processing texts with BERT"):
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=80,
            ).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Move to CPU before numpy conversion (tensor may live on GPU).
            embeddings.append(outputs.last_hidden_state.cpu().numpy())
        return np.vstack(embeddings)

    def chinese_tokenize(self, text):
        """Segment Chinese `text` with jieba, joining tokens with spaces."""
        return " ".join(jieba.cut(text))

    def train_ctm(self, texts):
        """Train the CTM on `texts`.

        Returns:
            True on success, False on failure. On failure the error is
            printed and `self.ctm_model` stays None (preserving the
            original best-effort behavior) rather than raising.
        """
        try:
            # Bag-of-words view of every document.
            bow_texts = [self.chinese_tokenize(text) for text in texts]
            training_dataset = self.tp.fit(
                text_for_contextual=texts, text_for_bow=bow_texts
            )
            # contextual_size=768 matches BERT-base's hidden size.
            self.ctm_model = CombinedTM(
                bow_size=len(self.tp.vocab),
                contextual_size=768,
                n_components=self.n_components,
                num_epochs=self.num_epochs,
            )
            self.ctm_model.fit(training_dataset)
            print("CTM模型训练完成")
            return True
        except Exception as e:
            # Best-effort: report the problem and leave ctm_model unset so
            # save_model() can detect that training never completed.
            print(f"训练CTM模型时发生错误: {e}")
            return False

    def save_model(self, path):
        """Save the trained CTM to `path`; print a notice if untrained."""
        # Identity test, not truthiness: a model object defining
        # __bool__/__len__ must not be mistaken for "untrained".
        if self.ctm_model is not None:
            self.ctm_model.save(path)
            print(f"CTM模型已保存至: {path}")
        else:
            print("未找到已训练的CTM模型,无法保存")

    def load_model(self, path):
        """Load a previously saved CTM from `path` if it exists."""
        if os.path.exists(path):
            self.ctm_model = CombinedTM.load(path)
            print(f"CTM模型已加载自: {path}")
        else:
            print(f"无法加载模型,路径不存在: {path}")
if __name__ == "__main__":
    def _demo():
        """Train a CTM on a tiny sample corpus, then save and reload it."""
        ctm_path = './trained_ctm_model'
        # BERT weights and the sentence-BERT embedder used by the CTM.
        model = BERT_CTM_Model('./bert_model', './sentence_bert_model')
        model.train_ctm(["这是第一个文本", "这是第二个文本"])
        model.save_model(ctm_path)
        model.load_model(ctm_path)

    _demo()
...
...
Please
register
or
login
to post a comment