Showing
1 changed file
with
22 additions
and
0 deletions
model_pro/BERT_CTM.py
0 → 100644
| 1 | +import os | ||
| 2 | +from transformers.models.bert import BertTokenizer, BertModel | ||
| 3 | +import torch | ||
| 4 | + | ||
| 5 | +class BERT_CTM_Model: | ||
| 6 | + def __init__(self, bert_model_path): | ||
| 7 | + # 加载BERT模型和tokenizer | ||
| 8 | + self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) | ||
| 9 | + self.model = BertModel.from_pretrained(bert_model_path) | ||
| 10 | + | ||
| 11 | + def get_bert_embeddings(self, text): | ||
| 12 | + """使用BERT模型生成文本的嵌入向量""" | ||
| 13 | + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) | ||
| 14 | + with torch.no_grad(): | ||
| 15 | + outputs = self.model(**inputs) | ||
| 16 | + return outputs.last_hidden_state.cpu().numpy() # [batch_size, sequence_length, hidden_size] | ||
| 17 | + | ||
| 18 | +if __name__ == "__main__": | ||
| 19 | + model = BERT_CTM_Model('./bert_model') | ||
| 20 | + text = "这是一个测试文本" | ||
| 21 | + embedding = model.get_bert_embeddings(text) | ||
| 22 | + print(embedding.shape) |
-
Please register or login to post a comment