戒酒的李白

Test script for batch text embedding with BERT: embeds a list of texts and prints the shape of the stacked embedding array.

1 import os 1 import os
2 from transformers.models.bert import BertTokenizer, BertModel 2 from transformers.models.bert import BertTokenizer, BertModel
3 import torch 3 import torch
  4 +from tqdm import tqdm
  5 +import numpy as np
4 6
5 class BERT_CTM_Model: 7 class BERT_CTM_Model:
6 def __init__(self, bert_model_path): 8 def __init__(self, bert_model_path):
@@ -8,15 +10,18 @@ class BERT_CTM_Model: @@ -8,15 +10,18 @@ class BERT_CTM_Model:
8 self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) 10 self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
9 self.model = BertModel.from_pretrained(bert_model_path) 11 self.model = BertModel.from_pretrained(bert_model_path)
10 12
11 - def get_bert_embeddings(self, text):  
12 - """使用BERT模型生成文本的嵌入向量"""  
13 - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)  
14 - with torch.no_grad():  
15 - outputs = self.model(**inputs)  
16 - return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size] 13 + def get_bert_embeddings(self, texts):
  14 + """使用BERT模型批量生成文本的嵌入向量"""
  15 + embeddings = []
  16 + for text in tqdm(texts, desc="Processing texts with BERT"):
  17 + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
  18 + with torch.no_grad():
  19 + outputs = self.model(**inputs)
  20 + embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size]
  21 + return np.vstack(embeddings)
17 22
if __name__ == "__main__":
    # Smoke test: embed two sample sentences and report the stacked shape.
    bert_ctm = BERT_CTM_Model('./bert_model')
    sample_texts = ["这是第一个文本", "这是第二个文本"]
    batch_embeddings = bert_ctm.get_bert_embeddings(sample_texts)
    print(batch_embeddings.shape)