戒酒的李白

Test the BERT model for Chinese simulation embedding

  1 +import os
  2 +from transformers.models.bert import BertTokenizer, BertModel
  3 +import torch
  4 +
  5 +class BERT_CTM_Model:
  6 + def __init__(self, bert_model_path):
  7 + # 加载BERT模型和tokenizer
  8 + self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
  9 + self.model = BertModel.from_pretrained(bert_model_path)
  10 +
  11 + def get_bert_embeddings(self, text):
  12 + """使用BERT模型生成文本的嵌入向量"""
  13 + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
  14 + with torch.no_grad():
  15 + outputs = self.model(**inputs)
  16 + return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size]
  17 +
  18 +if __name__ == "__main__":
  19 + model = BERT_CTM_Model('./bert_model')
  20 + text = "这是一个测试文本"
  21 + embedding = model.get_bert_embeddings(text)
  22 + print(embedding.shape)