Showing 1 changed file with 52 additions and 6 deletions
@@ -8,21 +8,31 @@ from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
 from contextualized_topic_models.models.ctm import CombinedTM
 
 class BERT_CTM_Model:
-    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50):
+    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, device=None):
+        # Determine the device (CPU/GPU)
+        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Verify that the model paths exist
+        if not os.path.exists(bert_model_path):
+            raise ValueError(f"BERT model path does not exist: {bert_model_path}")
+        if not os.path.exists(ctm_tokenizer_path):
+            raise ValueError(f"CTM tokenizer path does not exist: {ctm_tokenizer_path}")
+
         # Load the BERT model and tokenizer
         self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
-        self.model = BertModel.from_pretrained(bert_model_path)
+        self.model = BertModel.from_pretrained(bert_model_path).to(self.device)
 
         # Create the CTM data preparation object
         self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
         self.n_components = n_components
         self.num_epochs = num_epochs
+        self.ctm_model = None
 
     def get_bert_embeddings(self, texts):
         """Generate embeddings for a batch of texts with BERT"""
         embeddings = []
         for text in tqdm(texts, desc="Processing texts with BERT"):
-            inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
+            inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device)
             with torch.no_grad():
                 outputs = self.model(**inputs)
                 embeddings.append(outputs.last_hidden_state.cpu().numpy())  # [batch_size, sequence_length, hidden_size]
@@ -34,15 +44,51 @@ class BERT_CTM_Model:
 
     def train_ctm(self, texts):
         """Train the CTM model"""
+        try:
+            # Tokenize the texts and prepare the bag-of-words input
             bow_texts = [self.chinese_tokenize(text) for text in texts]
             training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)
 
             # Train the CTM
-        ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
-        ctm.fit(training_dataset)
+            self.ctm_model = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768,
+                                        n_components=self.n_components, num_epochs=self.num_epochs)
+            self.ctm_model.fit(training_dataset)
             print("CTM training finished")
+        except Exception as e:
+            print(f"Error while training the CTM model: {e}")
+
+    def save_model(self, path):
+        """Save the trained CTM model"""
+        if self.ctm_model:
+            self.ctm_model.save(path)
+            print(f"CTM model saved to: {path}")
+        else:
+            print("No trained CTM model found, nothing to save")
+
+    def load_model(self, path):
+        """Load a previously saved CTM model"""
+        if os.path.exists(path):
+            self.ctm_model = CombinedTM.load(path)
+            print(f"CTM model loaded from: {path}")
+        else:
+            print(f"Cannot load model, path does not exist: {path}")
 
 if __name__ == "__main__":
-    model = BERT_CTM_Model('./bert_model', './sentence_bert_model')
+    # Paths to the local BERT model and the CTM (sentence-BERT) tokenizer
+    bert_model_path = './bert_model'
+    ctm_tokenizer_path = './sentence_bert_model'
+
+    # Initialize the model
+    model = BERT_CTM_Model(bert_model_path, ctm_tokenizer_path)
+
+    # Example texts
     texts = ["这是第一个文本", "这是第二个文本"]
+
+    # Train the CTM model
     model.train_ctm(texts)
+
+    # Save the trained CTM model
+    model.save_model('./trained_ctm_model')
+
+    # Load the saved CTM model
+    model.load_model('./trained_ctm_model')
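Note that train_ctm relies on self.chinese_tokenize, which is defined outside the changed hunks, and that the new device and path checks assume os and torch are imported near the top of the file. For context, a minimal sketch of what such a tokenizer helper might look like, assuming jieba is used for Chinese word segmentation (the actual method in this file may differ, e.g. by filtering stopwords or punctuation):

import jieba

def chinese_tokenize(text):
    # Hypothetical helper: segment Chinese text into whitespace-separated tokens
    # so it can serve as the bag-of-words input expected by TopicModelDataPreparation.
    return " ".join(jieba.cut(text))

print(chinese_tokenize("这是第一个文本"))  # e.g. "这是 第一个 文本"

Once training succeeds, the topics can be inspected with the library's get_topic_lists helper, for example model.ctm_model.get_topic_lists(10) to list the top ten words of each topic.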