Showing 1 changed file with 14 additions and 9 deletions.
| 1 | import os | 1 | import os |
| 2 | from transformers.models.bert import BertTokenizer, BertModel | 2 | from transformers.models.bert import BertTokenizer, BertModel |
| 3 | import torch | 3 | import torch |
| 4 | +from tqdm import tqdm | ||
| 5 | +import numpy as np | ||
| 4 | 6 | ||
| 5 | class BERT_CTM_Model: | 7 | class BERT_CTM_Model: |
| 6 | def __init__(self, bert_model_path): | 8 | def __init__(self, bert_model_path): |
| @@ -8,15 +10,18 @@ class BERT_CTM_Model: | @@ -8,15 +10,18 @@ class BERT_CTM_Model: | ||
| 8 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) | 10 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) |
| 9 | self.model = BertModel.from_pretrained(bert_model_path) | 11 | self.model = BertModel.from_pretrained(bert_model_path) |
| 10 | 12 | ||
| 11 | - def get_bert_embeddings(self, text): | ||
| 12 | - """使用BERT模型生成文本的嵌入向量""" | ||
| 13 | - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) | ||
| 14 | - with torch.no_grad(): | ||
| 15 | - outputs = self.model(**inputs) | ||
| 16 | - return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size] | 13 | + def get_bert_embeddings(self, texts): |
| 14 | + """使用BERT模型批量生成文本的嵌入向量""" | ||
| 15 | + embeddings = [] | ||
| 16 | + for text in tqdm(texts, desc="Processing texts with BERT"): | ||
| 17 | + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) | ||
| 18 | + with torch.no_grad(): | ||
| 19 | + outputs = self.model(**inputs) | ||
| 20 | + embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] | ||
| 21 | + return np.vstack(embeddings) | ||
| 17 | 22 | ||
if __name__ == "__main__":
    # Quick smoke test: embed two sample sentences and report the array shape.
    bert = BERT_CTM_Model('./bert_model')
    sample_texts = ["这是第一个文本", "这是第二个文本"]
    vectors = bert.get_bert_embeddings(sample_texts)
    print(vectors.shape)
Please register or login to post a comment.