戒酒的李白

Integrated training function of CTM model

@@ -4,13 +4,20 @@ import torch
4 from tqdm import tqdm
5 import numpy as np
6 import jieba
  7 +from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
  8 +from contextualized_topic_models.models.ctm import CombinedTM
7 9
class BERT_CTM_Model:
    """Bundles a BERT encoder with a CombinedTM topic model for Chinese text."""

    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50):
        """
        Args:
            bert_model_path: local path or hub name of the BERT model/tokenizer.
            ctm_tokenizer_path: sentence-embedding model identifier handed to
                TopicModelDataPreparation for the contextual input of the CTM.
            n_components: number of topics the CTM will learn.
            num_epochs: number of CTM training epochs.
        """
        # Load the BERT tokenizer and encoder.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path)
        # The encoder is only used for inference (embedding extraction in
        # get_bert_embeddings), so disable dropout; without eval() the
        # embeddings are stochastic and noisy.
        self.model.eval()

        # CTM data-preparation helper and training hyper-parameters.
        self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
        self.n_components = n_components
        self.num_epochs = num_epochs
14 def get_bert_embeddings(self, texts): 21 def get_bert_embeddings(self, texts):
15 """使用BERT模型批量生成文本的嵌入向量""" 22 """使用BERT模型批量生成文本的嵌入向量"""
16 embeddings = [] 23 embeddings = []
@@ -25,8 +32,17 @@ class BERT_CTM_Model:
25 """使用jieba对中文文本进行分词"""
26 return " ".join(jieba.cut(text))
27
  35 + def train_ctm(self, texts):
  36 + """训练CTM模型"""
  37 + bow_texts = [self.chinese_tokenize(text) for text in texts]
  38 + training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
  39 +
  40 + # 训练CTM
  41 + ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
  42 + ctm.fit(training_dataset)
  43 + print("CTM模型训练完成")
  44 +
if __name__ == "__main__":
    # Smoke-test: build the pipeline and fit the CTM on a tiny corpus.
    sample_texts = ["这是第一个文本", "这是第二个文本"]
    pipeline = BERT_CTM_Model('./bert_model', './sentence_bert_model')
    pipeline.train_ctm(sample_texts)