万朱浩 / Venue-Ops
Authored by 戒酒的李白, 2024-09-30 00:14:40 +0800
Commit 8c0479a978c62c98560df0689e8ec865a62930e1 (1 parent: 701926ba)
Test the BERT model for Chinese simulation embedding
Showing 1 changed file with 22 additions and 0 deletions
model_pro/BERT_CTM.py
0 → 100644
import os
from transformers.models.bert import BertTokenizer, BertModel
import torch


class BERT_CTM_Model:
    def __init__(self, bert_model_path):
        # Load the BERT model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path)

    def get_bert_embeddings(self, text):
        """Generate embedding vectors for the text using the BERT model."""
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=80,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Shape: [batch_size, sequence_length, hidden_size]
        return outputs.last_hidden_state.cpu().numpy()


if __name__ == "__main__":
    model = BERT_CTM_Model('./bert_model')
    text = "这是一个测试文本"  # Chinese for "This is a test text"
    embedding = model.get_bert_embeddings(text)
    print(embedding.shape)
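Note on the output shape: because the tokenizer is called with padding="max_length" and max_length=80, get_bert_embeddings returns one vector per position, padding included, so a short test sentence still yields 80 token vectors. If a single vector per text is wanted later, one common option is masked mean pooling. The sketch below is an assumption about a possible follow-up, not part of this commit: masked_mean_pool and the toy inputs are hypothetical names, and it is written in plain Python so it runs without any model files. In real use the mask would come from the tokenizer output (inputs["attention_mask"]), which get_bert_embeddings does not currently return.

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, skipping padded positions.

    hidden_states: nested sequence shaped [batch_size, sequence_length, hidden_size]
    attention_mask: nested sequence shaped [batch_size, sequence_length], 1 = real token
    """
    pooled = []
    for token_vectors, mask in zip(hidden_states, attention_mask):
        # Keep only the vectors whose mask entry is 1 (non-padding tokens).
        kept = [vec for vec, m in zip(token_vectors, mask) if m]
        if not kept:
            raise ValueError("attention mask selects no tokens")
        hidden_size = len(kept[0])
        # Average each hidden dimension over the kept tokens.
        pooled.append(
            [sum(vec[i] for vec in kept) / len(kept) for i in range(hidden_size)]
        )
    return pooled


# Toy input: one sequence of 4 positions (the last 2 are padding), hidden_size = 2.
toy_hidden = [[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]]]
toy_mask = [[1, 1, 0, 0]]
sentence_vectors = masked_mean_pool(toy_hidden, toy_mask)
print(sentence_vectors)  # [[2.0, 3.0]]
```

The padded positions (the 9.0 vectors) do not affect the result, which is the point of applying the mask before averaging.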