Showing
1 changed file
with
95 additions
and
98 deletions
| 1 | import os | 1 | import os |
| 2 | -from transformers.models.bert import BertTokenizer, BertModel | ||
| 3 | -import torch | 2 | +os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| 3 | +import pandas as pd | ||
| 4 | from tqdm import tqdm | 4 | from tqdm import tqdm |
| 5 | +from transformers.models.bert import BertTokenizer, BertModel | ||
| 6 | +from contextualized_topic_models.models.ctm import CombinedTM | ||
| 7 | +from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation | ||
| 5 | import numpy as np | 8 | import numpy as np |
| 9 | +import torch | ||
| 6 | import jieba | 10 | import jieba |
| 7 | -from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation | ||
| 8 | -from contextualized_topic_models.models.ctm import CombinedTM | 11 | +import pickle # 用于保存和加载模型 |
| 9 | 12 | ||
| 10 | class BERT_CTM_Model: | 13 | class BERT_CTM_Model: |
| 11 | - def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, device=None): | ||
| 12 | - # 确定设备 (CPU/GPU) | ||
| 13 | - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| 14 | - | ||
| 15 | - # 检查模型路径是否存在 | ||
| 16 | - if not os.path.exists(bert_model_path): | ||
| 17 | - raise ValueError(f"BERT模型路径不存在: {bert_model_path}") | ||
| 18 | - if not os.path.exists(ctm_tokenizer_path): | ||
| 19 | - raise ValueError(f"CTM分词器路径不存在: {ctm_tokenizer_path}") | ||
| 20 | - | 14 | + def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'): |
| 15 | + self.bert_model_path = bert_model_path | ||
| 16 | + self.ctm_tokenizer_path = ctm_tokenizer_path | ||
| 17 | + self.n_components = n_components | ||
| 18 | + self.num_epochs = num_epochs | ||
| 19 | + self.model_save_path = model_save_path | ||
| 21 | # 加载BERT模型和tokenizer | 20 | # 加载BERT模型和tokenizer |
| 22 | - self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) | ||
| 23 | - self.model = BertModel.from_pretrained(bert_model_path).to(self.device) | 21 | + self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path) |
| 22 | + self.model = BertModel.from_pretrained(self.bert_model_path) | ||
| 24 | 23 | ||
| 25 | # 创建CTM数据预处理对象 | 24 | # 创建CTM数据预处理对象 |
| 26 | - self.tp = TopicModelDataPreparation(ctm_tokenizer_path) | ||
| 27 | - self.n_components = n_components | ||
| 28 | - self.num_epochs = num_epochs | ||
| 29 | - self.ctm_model = None | 25 | + self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path) |
| 30 | 26 | ||
| 27 | + def chinese_tokenize(self, text): | ||
| 28 | + """使用jieba对中文文本进行分词""" | ||
| 29 | + return " ".join(jieba.cut(text)) | ||
| 30 | + | ||
| 31 | def get_bert_embeddings(self, texts): | 31 | def get_bert_embeddings(self, texts): |
| 32 | - """使用BERT模型批量生成文本的嵌入向量""" | 32 | + """使用BERT模型生成文本的嵌入向量""" |
| 33 | embeddings = [] | 33 | embeddings = [] |
| 34 | for text in tqdm(texts, desc="Processing texts with BERT"): | 34 | for text in tqdm(texts, desc="Processing texts with BERT"): |
| 35 | - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80).to(self.device) | 35 | + inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) |
| 36 | with torch.no_grad(): | 36 | with torch.no_grad(): |
| 37 | outputs = self.model(**inputs) | 37 | outputs = self.model(**inputs) |
| 38 | - embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy()) # [batch_size, hidden_size] | 38 | + embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] |
| 39 | return np.vstack(embeddings) | 39 | return np.vstack(embeddings) |
| 40 | 40 | ||
| 41 | - def chinese_tokenize(self, text): | ||
| 42 | - """使用jieba对中文文本进行分词""" | ||
| 43 | - return " ".join(jieba.cut(text)) | ||
| 44 | - | ||
| 45 | - def train_ctm(self, texts): | ||
| 46 | - """训练CTM模型""" | ||
| 47 | - try: | ||
| 48 | - # 分词并准备BOW文本 | ||
| 49 | - bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 50 | - training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 51 | - | ||
| 52 | - # 训练CTM | ||
| 53 | - self.ctm_model = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, | ||
| 54 | - n_components=self.n_components, num_epochs=self.num_epochs) | ||
| 55 | - self.ctm_model.fit(training_dataset) | ||
| 56 | - print("CTM模型训练完成") | ||
| 57 | - except Exception as e: | ||
| 58 | - print(f"训练CTM模型时发生错误: {e}") | ||
| 59 | - | ||
| 60 | - def predict(self, texts): | ||
| 61 | - """使用训练好的CTM模型预测新文本的主题分布""" | ||
| 62 | - if not self.ctm_model: | ||
| 63 | - raise ValueError("模型尚未训练或加载,无法进行预测") | ||
| 64 | - | ||
| 65 | - try: | ||
| 66 | - bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 67 | - testing_dataset = self.tp.transform(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 68 | - topic_distributions = self.ctm_model.get_doc_topic_distribution(testing_dataset) | ||
| 69 | - return topic_distributions | ||
| 70 | - except Exception as e: | ||
| 71 | - print(f"预测主题时发生错误: {e}") | ||
| 72 | - return None | ||
| 73 | - | ||
| 74 | - def save_model(self, path): | ||
| 75 | - """保存训练后的CTM模型""" | ||
| 76 | - if self.ctm_model: | ||
| 77 | - self.ctm_model.save(path) | ||
| 78 | - print(f"CTM模型已保存至: {path}") | ||
| 79 | - else: | ||
| 80 | - print("未找到已训练的CTM模型,无法保存") | ||
| 81 | - | ||
| 82 | - def load_model(self, path): | ||
| 83 | - """加载已保存的CTM模型""" | ||
| 84 | - if os.path.exists(path): | ||
| 85 | - self.ctm_model = CombinedTM.load(path) | ||
| 86 | - print(f"CTM模型已加载自: {path}") | ||
| 87 | - else: | ||
| 88 | - print(f"无法加载模型,路径不存在: {path}") | 41 | + def save_model(self, ctm): |
| 42 | + """保存CTM模型、词袋和BoW的vectorizer""" | ||
| 43 | + os.makedirs(self.model_save_path, exist_ok=True) | ||
| 44 | + with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f: | ||
| 45 | + pickle.dump(ctm, f) | ||
| 46 | + with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f: | ||
| 47 | + pickle.dump(self.tp.vocab, f) | ||
| 48 | + with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f: # 保存BoW的vectorizer | ||
| 49 | + pickle.dump(self.tp.vectorizer, f) | ||
| 50 | + print(f"CTM模型和词袋保存到: {self.model_save_path}") | ||
| 51 | + | ||
| 52 | + def load_model(self): | ||
| 53 | + """加载CTM模型、词袋和BoW的vectorizer""" | ||
| 54 | + with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f: | ||
| 55 | + ctm = pickle.load(f) | ||
| 56 | + with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f: | ||
| 57 | + self.tp.vocab = pickle.load(f) | ||
| 58 | + with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f: # 加载BoW的vectorizer | ||
| 59 | + self.tp.vectorizer = pickle.load(f) | ||
| 60 | + print(f"CTM模型、词袋和vectorizer加载成功") | ||
| 61 | + return ctm | ||
| 62 | + | ||
| 63 | + def train(self, csv_file): | ||
| 64 | + """训练BERT + CTM模型并保存最终的特征向量和标签""" | ||
| 65 | + # 读取CSV文件中的文本和标签 | ||
| 66 | + data = pd.read_csv(csv_file) | ||
| 67 | + texts = data['TEXT'].tolist() | ||
| 68 | + labels = data['label'].tolist() | ||
| 69 | + | ||
| 70 | + # Step 1: 获取BERT的嵌入向量 | ||
| 71 | + print("Extracting BERT embeddings...") | ||
| 72 | + bert_embeddings = self.get_bert_embeddings(texts) # [batch_size, sequence_length, hidden_size] | ||
| 73 | + | ||
| 74 | + # Step 2: 准备CTM数据 | ||
| 75 | + print("Preparing data for CTM using training set...") | ||
| 76 | + bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 77 | + training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 78 | + | ||
| 79 | + # Step 3: 替换BERT嵌入 | ||
| 80 | + training_dataset._X = bert_embeddings[:, 0, :] # 只使用第一个token的向量用于CTM | ||
| 81 | + | ||
| 82 | + # Step 4: 训练CTM模型 | ||
| 83 | + print("Training CTM model...") | ||
| 84 | + ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs) | ||
| 85 | + ctm.fit(train_dataset=training_dataset, verbose=True) | ||
| 86 | + | ||
| 87 | + # Step 5: 保存CTM模型和词袋 | ||
| 88 | + self.save_model(ctm) | ||
| 89 | + | ||
| 90 | + # Step 6: 获取CTM的特征向量 | ||
| 91 | + print("Generating CTM features...") | ||
| 92 | + ctm_features = ctm.get_doc_topic_distribution(training_dataset) # [batch_size, n_components] | ||
| 93 | + | ||
| 94 | + # Step 7: 将CTM特征扩展为与BERT的sequence长度一致 | ||
| 95 | + sequence_length = bert_embeddings.shape[1] | ||
| 96 | + ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1) # [batch_size, sequence_length, n_components] | ||
| 97 | + | ||
| 98 | + # Step 8: 拼接BERT嵌入和CTM特征 | ||
| 99 | + final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1) # [batch_size, sequence_length, hidden_size + n_components] | ||
| 100 | + | ||
| 101 | + return bert_embeddings | ||
| 89 | 102 | ||
| 90 | if __name__ == "__main__": | 103 | if __name__ == "__main__": |
| 91 | - # 设定BERT和CTM模型的路径 | ||
| 92 | - bert_model_path = './bert_model' | ||
| 93 | - ctm_tokenizer_path = './sentence_bert_model' | ||
| 94 | - | ||
| 95 | - # 初始化模型 | ||
| 96 | - model = BERT_CTM_Model(bert_model_path, ctm_tokenizer_path) | ||
| 97 | - | ||
| 98 | - # 示例文本 | ||
| 99 | - texts = ["这是第一个文本", "这是第二个文本"] | ||
| 100 | - | ||
| 101 | - # 训练CTM模型 | ||
| 102 | - model.train_ctm(texts) | ||
| 103 | - | ||
| 104 | - # 保存CTM模型 | ||
| 105 | - model.save_model('./trained_ctm_model') | ||
| 106 | - | ||
| 107 | - # 加载CTM模型 | ||
| 108 | - model.load_model('./trained_ctm_model') | ||
| 109 | - | ||
| 110 | - # 预测新文本的主题分布 | ||
| 111 | - new_texts = ["这是一个新的文本", "另外一个新文本"] | ||
| 112 | - topic_distributions = model.predict(new_texts) | ||
| 113 | - | ||
| 114 | - # 输出预测结果 | ||
| 115 | - if topic_distributions is not None: | ||
| 116 | - for idx, distribution in enumerate(topic_distributions): | ||
| 117 | - print(f"文本 {idx+1} 的主题分布: {distribution}") | 104 | + # 创建BERT_CTM_Model实例 |
| 105 | + model = BERT_CTM_Model( | ||
| 106 | + bert_model_path='./bert_model', # BERT模型的路径 | ||
| 107 | + ctm_tokenizer_path='./sentence_bert_model', # CTM分词器的路径 | ||
| 108 | + n_components=12, # 主题数量 | ||
| 109 | + num_epochs=50, # 训练轮次 | ||
| 110 | + model_save_path='./ctm_model', # 保存路径 | ||
| 111 | + ) | ||
| 112 | + | ||
| 113 | + # 传入CSV文件路径进行训练 | ||
| 114 | + model.train("./train.csv") |
-
Please register or login to post a comment