戒酒的李白

Added inference function for the model

@@ -35,7 +35,7 @@ class BERT_CTM_Model: @@ -35,7 +35,7 @@ class BERT_CTM_Model:
35 inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) 35 inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
36 with torch.no_grad(): 36 with torch.no_grad():
37 outputs = self.model(**inputs) 37 outputs = self.model(**inputs)
38 - embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] 38 + embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size]
39 return np.vstack(embeddings) 39 return np.vstack(embeddings)
40 40
41 def chinese_tokenize(self, text): 41 def chinese_tokenize(self, text):
@@ -57,6 +57,20 @@ class BERT_CTM_Model: @@ -57,6 +57,20 @@ class BERT_CTM_Model:
57 except Exception as e: 57 except Exception as e:
58 print(f"训练CTM模型时发生错误: {e}") 58 print(f"训练CTM模型时发生错误: {e}")
59 59
  60 + def predict(self, texts):
  61 + """使用训练好的CTM模型预测新文本的主题分布"""
  62 + if not self.ctm_model:
  63 + raise ValueError("模型尚未训练或加载,无法进行预测")
  64 +
  65 + try:
  66 + bow_texts = [self.chinese_tokenize(text) for text in texts]
  67 + testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts)
  68 + topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset)
  69 + return topic_distributions
  70 + except Exception as e:
  71 + print(f"预测主题时发生错误: {e}")
  72 + return None
  73 +
60 def save_model(self, path): 74 def save_model(self, path):
61 """保存训练后的CTM模型""" 75 """保存训练后的CTM模型"""
62 if self.ctm_model: 76 if self.ctm_model:
@@ -92,3 +106,12 @@ if __name__ == "__main__": @@ -92,3 +106,12 @@ if __name__ == "__main__":
92 106
93 # 加载CTM模型 107 # 加载CTM模型
94 model.load_model('./trained_ctm_model') 108 model.load_model('./trained_ctm_model')
  109 +
  110 + # 预测新文本的主题分布
  111 + new_texts = ["这是一个新的文本", "另外一个新文本"]
  112 + topic_distributions = model.predict(new_texts)
  113 +
  114 + # 输出预测结果
  115 + if topic_distributions is not None:
  116 + for idx, distribution in enumerate(topic_distributions):
  117 + print(f"文本 {idx+1} 的主题分布: {distribution}")