Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2024-09-30 00:38:42 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
48af69dace269653fbdbe589ef3bbdb9f7a15a55
48af69da
1 parent
8c0479a9
Batch processing text embedding tests
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
6 deletions
model_pro/BERT_CTM.py
model_pro/BERT_CTM.py
View file @
48af69d
import
os
from
transformers.models.bert
import
BertTokenizer
,
BertModel
import
torch
from
tqdm
import
tqdm
import
numpy
as
np
class
BERT_CTM_Model
:
def
__init__
(
self
,
bert_model_path
):
...
...
@@ -8,15 +10,18 @@ class BERT_CTM_Model:
self
.
tokenizer
=
BertTokenizer
.
from_pretrained
(
bert_model_path
)
self
.
model
=
BertModel
.
from_pretrained
(
bert_model_path
)
def get_bert_embeddings(self, text):
    """Generate an embedding vector for a single text with the BERT model."""
def get_bert_embeddings(self, texts):
    """Generate BERT embeddings for a batch of texts.

    Args:
        texts: An iterable of strings to embed. A single string is also
            accepted and treated as a one-element batch, for backward
            compatibility with callers of the pre-batch API (otherwise a
            bare string would be iterated character by character).

    Returns:
        np.ndarray of shape [num_texts, max_length, hidden_size]: the full
        last-hidden-state sequence of each text, moved to CPU.
        NOTE(review): max_length is hard-coded to 80 below.
    """
    # Guard: a lone string must become a one-element batch, not a
    # character sequence.
    if isinstance(texts, str):
        texts = [texts]
    embeddings = []
    for text in tqdm(texts, desc="Processing texts with BERT"):
        # Pad/truncate every text to a fixed length so the per-text
        # results have identical shapes and can be stacked.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=80,
        )
        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Each item: [1, sequence_length, hidden_size]
        embeddings.append(outputs.last_hidden_state.cpu().numpy())
    # Stack along the batch axis -> [num_texts, sequence_length, hidden_size]
    return np.vstack(embeddings)
if __name__ == "__main__":
    # Smoke test — requires a local BERT checkpoint at ./bert_model.
    model = BERT_CTM_Model('./bert_model')

    # Batch of texts -> expect shape [num_texts, max_length, hidden_size].
    # (The old single-string call was removed: under the batched signature
    # a bare string would be iterated character by character.)
    texts = ["这是第一个文本", "这是第二个文本"]
    embeddings = model.get_bert_embeddings(texts)
    print(embeddings.shape)
...
...
Please
register
or
login
to post a comment