戒酒的李白

Add BERT-CTM model with save/load and error handling

@@ -8,41 +8,87 @@ from contextualized_topic_models.utils.data_preparation import TopicModelDataPre @@ -8,41 +8,87 @@ from contextualized_topic_models.utils.data_preparation import TopicModelDataPre
8 from contextualized_topic_models.models.ctm import CombinedTM 8 from contextualized_topic_models.models.ctm import CombinedTM
9 9
class BERT_CTM_Model:
    """BERT encoder + Combined Topic Model (CTM) pipeline for Chinese text.

    Loads a local BERT checkpoint for contextual embeddings and a
    TopicModelDataPreparation helper for the CTM bag-of-words side, then
    trains / saves / loads a ``CombinedTM`` topic model.
    """

    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, device=None):
        """Initialize the pipeline.

        Args:
            bert_model_path: local directory of the BERT model/tokenizer.
            ctm_tokenizer_path: local path handed to TopicModelDataPreparation.
            n_components: number of topics for the CTM.
            num_epochs: CTM training epochs.
            device: explicit torch device string; defaults to CUDA when available.

        Raises:
            ValueError: if either model path does not exist.
        """
        # Resolve compute device: honor an explicit choice, else prefer CUDA.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")

        # Fail fast on missing model assets instead of a cryptic load error.
        if not os.path.exists(bert_model_path):
            raise ValueError(f"BERT模型路径不存在: {bert_model_path}")
        if not os.path.exists(ctm_tokenizer_path):
            raise ValueError(f"CTM分词器路径不存在: {ctm_tokenizer_path}")

        # Load the BERT tokenizer and encoder; move the encoder to the device.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path).to(self.device)

        # CTM data-preparation object (contextual-embedding side).
        self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
        self.n_components = n_components
        self.num_epochs = num_epochs
        # Holds the trained CombinedTM; None until train_ctm/load_model runs.
        self.ctm_model = None

    def get_bert_embeddings(self, texts):
        """Return BERT token-level embeddings for each text.

        Returns a numpy array of shape [len(texts), 80, hidden_size]
        (sequences padded/truncated to max_length=80).

        NOTE(review): these are token-level embeddings; CombinedTM below is
        built with contextual_size=768 sentence vectors — confirm pooling
        (e.g. CLS / mean) before feeding this output into the CTM.
        """
        embeddings = []
        for text in tqdm(texts, desc="Processing texts with BERT"):
            inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
            with torch.no_grad():  # inference only — no gradients needed
                outputs = self.model(**inputs)
            embeddings.append(outputs.last_hidden_state.cpu().numpy())  # [1, seq_len, hidden]
        return np.vstack(embeddings)

    def chinese_tokenize(self, text):
        """Segment Chinese text with jieba, joined by single spaces (BOW input)."""
        return " ".join(jieba.cut(text))

    def train_ctm(self, texts):
        """Train the CTM on `texts`; stores the model on self.ctm_model.

        Re-raises any training error after reporting it, so callers never
        silently proceed with an untrained model.
        """
        try:
            # Word-segment each document for the bag-of-words channel.
            bow_texts = [self.chinese_tokenize(text) for text in texts]
            training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)

            # Build and fit the combined topic model.
            self.ctm_model = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768,
                                        n_components=self.n_components, num_epochs=self.num_epochs)
            self.ctm_model.fit(training_dataset)
            print("CTM模型训练完成")
        except Exception as e:
            # Report, then propagate — swallowing the error would make a
            # failed run look like success (and save_model would then fail).
            print(f"训练CTM模型时发生错误: {e}")
            raise

    def save_model(self, path):
        """Persist the trained CTM to `path` (no-op with a warning if untrained)."""
        # Explicit None check — avoid relying on model truthiness semantics.
        if self.ctm_model is not None:
            self.ctm_model.save(path)
            print(f"CTM模型已保存至: {path}")
        else:
            print("未找到已训练的CTM模型,无法保存")

    def load_model(self, path):
        """Load a previously saved CTM from `path` into self.ctm_model."""
        if os.path.exists(path):
            self.ctm_model = CombinedTM.load(path)
            print(f"CTM模型已加载自: {path}")
        else:
            print(f"无法加载模型,路径不存在: {path}")
44 75
def main():
    """Demo: build the pipeline, train on two sample texts, save, reload."""
    # Paths to the local BERT checkpoint and the CTM tokenizer model.
    bert_model_path = './bert_model'
    ctm_tokenizer_path = './sentence_bert_model'

    # Build the pipeline.
    model = BERT_CTM_Model(bert_model_path, ctm_tokenizer_path)

    # Example corpus.
    texts = ["这是第一个文本", "这是第二个文本"]

    # Train, persist, and reload the topic model.
    model.train_ctm(texts)
    model.save_model('./trained_ctm_model')
    model.load_model('./trained_ctm_model')


if __name__ == "__main__":
    main()