Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2024-10-01 15:50:57 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
f2ebb56509e0a4fd99412dbaa843cf4ed9856e3e
f2ebb565
1 parent
91a192e3
Added inference function for the model
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
1 deletions
model_pro/BERT_CTM.py
model_pro/BERT_CTM.py
View file @
f2ebb56
...
...
@@ -35,7 +35,7 @@ class BERT_CTM_Model:
inputs
=
self
.
tokenizer
(
text
,
return_tensors
=
"pt"
,
padding
=
"max_length"
,
truncation
=
True
,
max_length
=
80
)
.
to
(
self
.
device
)
with
torch
.
no_grad
():
outputs
=
self
.
model
(
**
inputs
)
embeddings
.
append
(
outputs
.
last_hidden_state
.
cpu
()
.
numpy
())
# [batch_size, sequence_length
, hidden_size]
embeddings
.
append
(
outputs
.
last_hidden_state
[:,
0
,
:]
.
cpu
()
.
numpy
())
# [batch_size
, hidden_size]
return
np
.
vstack
(
embeddings
)
def
chinese_tokenize
(
self
,
text
):
...
...
@@ -57,6 +57,20 @@ class BERT_CTM_Model:
except
Exception
as
e
:
print
(
f
"训练CTM模型时发生错误: {e}"
)
def
predict
(
self
,
texts
):
"""使用训练好的CTM模型预测新文本的主题分布"""
if
not
self
.
ctm_model
:
raise
ValueError
(
"模型尚未训练或加载,无法进行预测"
)
try
:
bow_texts
=
[
self
.
chinese_tokenize
(
text
)
for
text
in
texts
]
testing_dataset
=
self
.
tp
.
transform
(
text_for_contextual
=
texts
,
text_for_bow
=
bow_texts
)
topic_distributions
=
self
.
ctm_model
.
get_doc_topic_distribution
(
testing_dataset
)
return
topic_distributions
except
Exception
as
e
:
print
(
f
"预测主题时发生错误: {e}"
)
return
None
def
save_model
(
self
,
path
):
"""保存训练后的CTM模型"""
if
self
.
ctm_model
:
...
...
@@ -92,3 +106,12 @@ if __name__ == "__main__":
# 加载CTM模型
model
.
load_model
(
'./trained_ctm_model'
)
# 预测新文本的主题分布
new_texts
=
[
"这是一个新的文本"
,
"另外一个新文本"
]
topic_distributions
=
model
.
predict
(
new_texts
)
# 输出预测结果
if
topic_distributions
is
not
None
:
for
idx
,
distribution
in
enumerate
(
topic_distributions
):
print
(
f
"文本 {idx+1} 的主题分布: {distribution}"
)
...
...
Please
register
or
login
to post a comment