Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2024-09-30 09:24:08 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
f0f43c8e985236887bfdd3a5cda4f2abd5bf6aa6
f0f43c8e
1 parent
5108ae12
Integrated training function of CTM model
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
5 deletions
model_pro/BERT_CTM.py
model_pro/BERT_CTM.py
View file @
f0f43c8
...
...
@@ -4,13 +4,20 @@ import torch
from
tqdm
import
tqdm
import
numpy
as
np
import
jieba
from
contextualized_topic_models.utils.data_preparation
import
TopicModelDataPreparation
from
contextualized_topic_models.models.ctm
import
CombinedTM
class BERT_CTM_Model:
    """Wraps a BERT encoder together with a Combined Topic Model (CTM) pipeline.

    Holds a BERT tokenizer/encoder pair for contextual embeddings and a
    `TopicModelDataPreparation` instance for building CTM training data.
    """

    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50):
        """Set up the BERT encoder and the CTM data-preparation pipeline.

        Args:
            bert_model_path: Path to a pretrained BERT checkpoint; used for
                both the tokenizer and the encoder.
            ctm_tokenizer_path: Sentence-embedding model path handed to
                `TopicModelDataPreparation` (contextual embeddings for CTM).
            n_components: Number of topics the CTM should learn. Default 12.
            num_epochs: Number of CTM training epochs. Default 50.
        """
        # Load the BERT tokenizer and encoder from the same checkpoint.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path)
        # Create the CTM data-preparation object.
        self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
        self.n_components = n_components
        self.num_epochs = num_epochs
def
get_bert_embeddings
(
self
,
texts
):
"""使用BERT模型批量生成文本的嵌入向量"""
embeddings
=
[]
...
...
@@ -25,8 +32,17 @@ class BERT_CTM_Model:
"""使用jieba对中文文本进行分词"""
return
" "
.
join
(
jieba
.
cut
(
text
))
def train_ctm(self, texts):
    """Train the CTM topic model on the given corpus.

    Args:
        texts: List of raw (untokenized) Chinese strings.

    Returns:
        The fitted `CombinedTM` instance. It is also stored on ``self.ctm``
        so callers can retrieve topics after training. (Previously the
        trained model was a discarded local, making the training useless.)
    """
    # Bag-of-words view: segment each text with jieba via chinese_tokenize.
    bow_texts = [self.chinese_tokenize(text) for text in texts]
    training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)

    # Train the CTM. contextual_size=768 matches the hidden size of
    # BERT-base encoders — assumes a base-sized model; confirm if a
    # large checkpoint is ever used.
    ctm = CombinedTM(
        bow_size=len(self.tp.vocab),
        contextual_size=768,
        n_components=self.n_components,
        num_epochs=self.num_epochs,
    )
    ctm.fit(training_dataset)
    # Keep the trained model so it survives this call.
    self.ctm = ctm
    print("CTM模型训练完成")
    return ctm
if __name__ == "__main__":
    # Demo entry point. Note: the constructor now requires both the BERT
    # checkpoint and the CTM sentence-embedding model path — the old
    # one-argument call (`BERT_CTM_Model('./bert_model')`) would raise
    # TypeError against the current __init__ signature.
    model = BERT_CTM_Model('./bert_model', './sentence_bert_model')

    # Tokenization demo: segment a single sentence with jieba.
    text = "这是一个测试文本"
    tokenized_text = model.chinese_tokenize(text)
    print(tokenized_text)

    # Training demo on a tiny two-document corpus.
    texts = ["这是第一个文本", "这是第二个文本"]
    model.train_ctm(texts)
...
...
Please
register
or
login
to post a comment