戒酒的李白

The BERT_CTM module is finally complete. It extracts BERT token embeddings for every document, trains a CombinedTM topic model on the same corpus, and concatenates each document's topic distribution onto its BERT embeddings to form the fused feature vectors.

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
from tqdm import tqdm
from transformers.models.bert import BertTokenizer, BertModel
from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
import numpy as np
import torch
import jieba
import pickle  # used to save and load the trained model


class BERT_CTM_Model:
    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'):
        self.bert_model_path = bert_model_path
        self.ctm_tokenizer_path = ctm_tokenizer_path
        self.n_components = n_components
        self.num_epochs = num_epochs
        self.model_save_path = model_save_path

        # Load the BERT model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path)
        self.model = BertModel.from_pretrained(self.bert_model_path)

        # Create the CTM data-preparation object
        self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path)

    def chinese_tokenize(self, text):
        """Segment Chinese text with jieba."""
        return " ".join(jieba.cut(text))

    def get_bert_embeddings(self, texts):
        """Generate BERT embeddings for the given texts."""
        embeddings = []
        for text in tqdm(texts, desc="Processing texts with BERT"):
            inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80)
            with torch.no_grad():
                outputs = self.model(**inputs)
            embeddings.append(outputs.last_hidden_state.cpu().numpy())  # [1, sequence_length, hidden_size]
        return np.vstack(embeddings)  # [num_texts, sequence_length, hidden_size]

    def save_model(self, ctm):
        """Save the CTM model, the vocabulary and the BoW vectorizer."""
        os.makedirs(self.model_save_path, exist_ok=True)
        with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f:
            pickle.dump(ctm, f)
        with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f:
            pickle.dump(self.tp.vocab, f)
        with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f:  # BoW vectorizer
            pickle.dump(self.tp.vectorizer, f)
        print(f"CTM model and vocabulary saved to: {self.model_save_path}")

    def load_model(self):
        """Load the CTM model, the vocabulary and the BoW vectorizer."""
        with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f:
            ctm = pickle.load(f)
        with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f:
            self.tp.vocab = pickle.load(f)
        with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f:  # BoW vectorizer
            self.tp.vectorizer = pickle.load(f)
        print("CTM model, vocabulary and vectorizer loaded successfully")
        return ctm

    def train(self, csv_file):
        """Train the BERT + CTM model and return the fused feature vectors and labels."""
        # Read the texts and labels from the CSV file
        data = pd.read_csv(csv_file)
        texts = data['TEXT'].tolist()
        labels = data['label'].tolist()

        # Step 1: extract the BERT embeddings
        print("Extracting BERT embeddings...")
        bert_embeddings = self.get_bert_embeddings(texts)  # [num_texts, sequence_length, hidden_size]

        # Step 2: prepare the CTM training data
        print("Preparing data for CTM using training set...")
        bow_texts = [self.chinese_tokenize(text) for text in texts]
        training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts)

        # Step 3: replace the dataset's contextual embeddings with the BERT [CLS] vectors
        training_dataset._X = bert_embeddings[:, 0, :]  # only the first token's vector is used by the CTM

        # Step 4: train the CTM model
        print("Training CTM model...")
        ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs)
        ctm.fit(train_dataset=training_dataset, verbose=True)

        # Step 5: save the CTM model and vocabulary
        self.save_model(ctm)

        # Step 6: get the CTM topic features
        print("Generating CTM features...")
        ctm_features = ctm.get_doc_topic_distribution(training_dataset)  # [num_texts, n_components]

        # Step 7: repeat the CTM features along the BERT sequence dimension
        sequence_length = bert_embeddings.shape[1]
        ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1)  # [num_texts, sequence_length, n_components]

        # Step 8: concatenate the BERT embeddings and the CTM features
        final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1)  # [num_texts, sequence_length, hidden_size + n_components]

        return final_embeddings, labels


if __name__ == "__main__":
    # Create a BERT_CTM_Model instance
    model = BERT_CTM_Model(
        bert_model_path='./bert_model',              # path to the BERT model
        ctm_tokenizer_path='./sentence_bert_model',  # path to the CTM sentence-transformer
        n_components=12,                             # number of topics
        num_epochs=50,                               # number of training epochs
        model_save_path='./ctm_model',               # where to save the trained CTM
    )

    # Train on the texts and labels in the CSV file
    model.train("./train.csv")
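
For scoring unseen documents without retraining, the saved artifacts can be reloaded and combined with the same [CLS]-embedding override that train() applies. The snippet below is a minimal sketch, not part of the module itself: the module file name bert_ctm_model.py is an assumption, the paths mirror the __main__ block above, and the flow follows the predict logic from the earlier version of the module (tp.transform on the new texts, then get_doc_topic_distribution). Depending on the contextualized_topic_models version, tp.transform may require the fully fitted preparation object from the training run, so treat this as a starting point rather than a guaranteed recipe.

# inference_sketch.py -- hypothetical usage; module name and paths are assumptions
from bert_ctm_model import BERT_CTM_Model

model = BERT_CTM_Model(
    bert_model_path='./bert_model',
    ctm_tokenizer_path='./sentence_bert_model',
    model_save_path='./ctm_model',
)

# Restore the pickled CTM, vocabulary and BoW vectorizer saved by train()
ctm = model.load_model()

new_texts = ["这是一个新的文本", "另外一个新文本"]
bow_texts = [model.chinese_tokenize(t) for t in new_texts]

# Build a test dataset, then mirror the [CLS]-embedding override used in train()
test_dataset = model.tp.transform(text_for_contextual=new_texts, text_for_bow=bow_texts)
test_dataset._X = model.get_bert_embeddings(new_texts)[:, 0, :]

topic_distributions = ctm.get_doc_topic_distribution(test_dataset)  # [num_texts, n_components]
for idx, distribution in enumerate(topic_distributions):
    print(f"Topic distribution for text {idx + 1}: {distribution}")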