Showing
1 changed file
with
24 additions
and
1 deletions
| @@ -35,7 +35,7 @@ class BERT_CTM_Model: | @@ -35,7 +35,7 @@ class BERT_CTM_Model: | ||
| 35 | inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) | 35 | inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) |
| 36 | with torch.no_grad(): | 36 | with torch.no_grad(): |
| 37 | outputs = self.model(**inputs) | 37 | outputs = self.model(**inputs) |
| 38 | - embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] | 38 | + embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size] |
| 39 | return np.vstack(embeddings) | 39 | return np.vstack(embeddings) |
| 40 | 40 | ||
| 41 | def chinese_tokenize(self, text): | 41 | def chinese_tokenize(self, text): |
| @@ -57,6 +57,20 @@ class BERT_CTM_Model: | @@ -57,6 +57,20 @@ class BERT_CTM_Model: | ||
| 57 | except Exception as e: | 57 | except Exception as e: |
| 58 | print(f"训练CTM模型时发生错误: {e}") | 58 | print(f"训练CTM模型时发生错误: {e}") |
| 59 | 59 | ||
| 60 | + def predict(self, texts): | ||
| 61 | + """使用训练好的CTM模型预测新文本的主题分布""" | ||
| 62 | + if not self.ctm_model: | ||
| 63 | + raise ValueError("模型尚未训练或加载,无法进行预测") | ||
| 64 | + | ||
| 65 | + try: | ||
| 66 | + bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 67 | + testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 68 | + topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset) | ||
| 69 | + return topic_distributions | ||
| 70 | + except Exception as e: | ||
| 71 | + print(f"预测主题时发生错误: {e}") | ||
| 72 | + return None | ||
| 73 | + | ||
| 60 | def save_model(self, path): | 74 | def save_model(self, path): |
| 61 | """保存训练后的CTM模型""" | 75 | """保存训练后的CTM模型""" |
| 62 | if self.ctm_model: | 76 | if self.ctm_model: |
| @@ -92,3 +106,12 @@ if __name__ == "__main__": | @@ -92,3 +106,12 @@ if __name__ == "__main__": | ||
| 92 | 106 | ||
| 93 | # 加载CTM模型 | 107 | # 加载CTM模型 |
| 94 | model.load_model('./trained_ctm_model') | 108 | model.load_model('./trained_ctm_model') |
| 109 | + | ||
| 110 | + # 预测新文本的主题分布 | ||
| 111 | + new_texts = ["这是一个新的文本", "另外一个新文本"] | ||
| 112 | + topic_distributions = model.predict(new_texts) | ||
| 113 | + | ||
| 114 | + # 输出预测结果 | ||
| 115 | + if topic_distributions is not None: | ||
| 116 | + for idx, distribution in enumerate(topic_distributions): | ||
| 117 | + print(f"文本 {idx+1} 的主题分布: {distribution}") |
-
Please register or login to post a comment