Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2024-10-03 00:48:10 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
80aa0cfa9c8974f9cab92e166749c44e6aa84fe1
80aa0cfa
1 parent
4d91f30d
Implement the get_bert_ctm_embeddings function and embedding generation and loading logic
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
0 deletions
model_pro/BCAT.py
model_pro/BCAT.py
0 → 100644
View file @
80aa0cf
import
os
import
numpy
as
np
from
BERT_CTM
import
BERT_CTM_Model
# Assumes BERT_CTM_Model is defined in the BERT_CTM module
# Load-or-generate helper for BERT+CTM embeddings
def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path,
                            n_components=12, num_epochs=20, save_path=None):
    """Load cached BERT+CTM embeddings, or generate them and cache the result.

    When ``save_path`` is given and a cache file is found, the embeddings are
    read back with ``np.load`` and no model is built. Otherwise a
    ``BERT_CTM_Model`` is constructed, its ``train(texts)`` result is used as
    the embeddings and, when ``save_path`` is given, written with ``np.save``.

    NOTE(review): the cache is keyed on ``save_path`` only; a cache file
    produced from different ``texts`` or hyper-parameters is returned as-is.

    :param texts: texts to embed; passed unchanged to ``BERT_CTM_Model.train``.
    :param bert_model_path: forwarded to ``BERT_CTM_Model`` as ``bert_model_path``.
    :param ctm_tokenizer_path: forwarded to ``BERT_CTM_Model`` as
        ``ctm_tokenizer_path``.
    :param n_components: forwarded to ``BERT_CTM_Model`` (number of topics,
        per the original docstring).
    :param num_epochs: forwarded to ``BERT_CTM_Model`` (training epochs).
    :param save_path: optional cache file. A falsy value disables caching.
        A name without the ``.npy`` suffix is resolved to ``<name>.npy``,
        which is the file ``np.save`` actually writes.
    :return: the array loaded from the cache, or whatever
        ``BERT_CTM_Model.train`` returned.
    """
    if save_path:
        # np.save appends ".npy" to names lacking it, so looking only for the
        # raw name would miss the file written by a previous run and retrain
        # every time. An existing file at the exact path still takes priority
        # so callers relying on the old lookup keep working.
        npy_path = os.fspath(save_path)
        if not npy_path.endswith(".npy"):
            npy_path += ".npy"
        if not os.path.exists(save_path):
            save_path = npy_path

        if os.path.exists(save_path):
            print(f"从文件 {save_path} 加载嵌入...")
            return np.load(save_path)

    print("生成 BERT+CTM 嵌入...")
    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path,
        n_components=n_components,
        num_epochs=num_epochs,
    )
    embeddings = bert_ctm_model.train(texts)

    if save_path:
        # Create the target directory first: a missing directory would
        # otherwise make np.save fail and discard the freshly trained result.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(f"保存嵌入到文件 {save_path}...")
        np.save(save_path, embeddings)

    return embeddings
if __name__ == "__main__":
    # Example run: embed two sample sentences, reusing the cached
    # 'sample_embeddings.npy' file when it is already present.
    embeddings = get_bert_ctm_embeddings(
        ["This is a test text.", "Another example of text data."],
        './bert_model',
        './sentence_bert_model',
        save_path='sample_embeddings.npy',
    )

    # Report the shape of the resulting embeddings.
    print(f"嵌入形状: {embeddings.shape}")
...
...
Please
register
or
login
to post a comment