# BERT_CTM.py
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import pandas as pd
from tqdm import tqdm
from transformers.models.bert import BertTokenizer, BertModel
from contextualized_topic_models.models.ctm import CombinedTM
from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation
from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing
import numpy as np
import torch
import jieba
import pickle # 用于保存和加载模型
from utils.logger import model_logger as logging
class BERT_CTM:
    """Combined topic model (CTM) over BERT [CLS] embeddings for Chinese text.

    Pipeline: extract per-document BERT embeddings, build a bag-of-words
    dataset with contextualized_topic_models, train a CombinedTM, and
    persist the model plus vocabulary/vectorizer via pickle.
    """

    def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'):
        # Path where the trained CTM model and bag-of-words data are pickled.
        self.model_save_path = model_save_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model = None   # lazily initialised in train()
        self.tokenizer = None    # lazily initialised in train()
        self.ctm_model = None    # set by train() or load_model()
        self.vocab = None        # vocabulary list from data preparation
        self.vectorizer = None   # CountVectorizer from data preparation

    def save_model(self):
        """Persist the trained CTM model, vocabulary and vectorizer to disk."""
        try:
            # Create the target directory if missing so open() cannot fail on it.
            target_dir = os.path.dirname(self.model_save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            with open(self.model_save_path, 'wb') as f:
                pickle.dump({
                    'ctm_model': self.ctm_model,
                    'vocab': self.vocab,
                    'vectorizer': self.vectorizer
                }, f)
            logging.info(f"CTM模型和词袋保存到: {self.model_save_path}")
        except Exception as e:
            # Best-effort save: log and continue, matching original behavior.
            logging.error(f"保存模型时发生错误: {e}")

    def load_model(self):
        """Load a previously saved CTM model, vocabulary and vectorizer.

        Raises:
            Exception: re-raised after logging if the file is missing/corrupt.
        """
        try:
            # NOTE(review): pickle.load is unsafe on untrusted input — only
            # load files produced by save_model().
            with open(self.model_save_path, 'rb') as f:
                saved_data = pickle.load(f)
            self.ctm_model = saved_data['ctm_model']
            self.vocab = saved_data['vocab']
            self.vectorizer = saved_data['vectorizer']
            logging.info("CTM模型、词袋和vectorizer加载成功")
        except Exception as e:
            logging.error(f"加载模型时发生错误: {e}")
            raise

    def train(self, texts, num_topics=10, num_epochs=100):
        """Train the combined topic model on an iterable of document strings.

        Args:
            texts: list of raw document strings (one per document).
            num_topics: number of latent topics to learn.
            num_epochs: training epochs for the CTM.

        Raises:
            Exception: re-raised after logging on any failure.
        """
        try:
            # Lazily initialise BERT so repeated train() calls reuse the model.
            if not self.bert_model:
                self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
                self.bert_model = BertModel.from_pretrained('bert-base-chinese').to(self.device)
            logging.info("正在提取BERT嵌入...")
            embeddings = self._get_bert_embeddings(texts)
            logging.info("正在准备CTM训练数据...")
            # BUGFIX: the original never called preprocess(), so
            # preprocessor.vocab/vectorizer did not exist, and it passed the
            # raw embeddings array to TopicModelDataPreparation's constructor
            # (which expects a model name) and then fit() the CTM on the
            # preparation object instead of a prepared dataset.
            preprocessor = WhiteSpacePreprocessing(texts)
            preprocessed_docs, unpreprocessed_docs, vocab = preprocessor.preprocess()
            qt = TopicModelDataPreparation()
            # Supply the precomputed BERT embeddings instead of letting the
            # library embed the texts itself.
            training_dataset = qt.fit(
                text_for_contextual=unpreprocessed_docs,
                text_for_bow=preprocessed_docs,
                custom_embeddings=embeddings
            )
            logging.info("正在训练CTM模型...")
            self.ctm_model = CombinedTM(
                bow_size=len(qt.vocab),
                contextual_size=768,  # hidden size of bert-base-chinese
                n_components=num_topics,
                num_epochs=num_epochs
            )
            self.ctm_model.fit(training_dataset)
            # Keep the data-preparation vocabulary/vectorizer for later reuse.
            self.vocab = qt.vocab
            self.vectorizer = qt.vectorizer
            self.save_model()
            logging.info("模型训练完成并保存")
        except Exception as e:
            logging.error(f"训练模型时发生错误: {e}")
            raise

    def _get_bert_embeddings(self, texts):
        """Return a (num_docs, 768) array of [CLS] embeddings for *texts*."""
        embeddings = []
        try:
            for text in texts:
                inputs = self.tokenizer(text, return_tensors='pt', padding=True,
                                        truncation=True, max_length=512)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                with torch.no_grad():
                    outputs = self.bert_model(**inputs)
                # Use the [CLS] token's hidden state as the document vector.
                embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                embeddings.append(embedding[0])
            return np.array(embeddings)
        except Exception as e:
            logging.error(f"获取BERT嵌入时发生错误: {e}")
            raise

    def get_topics(self, num_words=10):
        """Return a list of topics, each a list of the top *num_words* words.

        Raises:
            ValueError: if the model has not been trained or loaded.
        """
        try:
            # `is None` instead of truthiness so an empty vocab list would not
            # be mistaken for "not loaded".
            if self.ctm_model is None or self.vocab is None:
                raise ValueError("模型未训练或未加载")
            # get_topic_lists returns all topics at once; the original called
            # it redundantly inside a per-topic loop.
            return self.ctm_model.get_topic_lists(top_n=num_words)
        except Exception as e:
            logging.error(f"获取主题词时发生错误: {e}")
            raise
if __name__ == "__main__":
    # Create a BERT_CTM instance with an explicit save path.
    model = BERT_CTM(
        model_save_path='model_pro/saved_models/ctm_model.pkl',
    )
    # BUGFIX: train() expects an iterable of document strings, not a file
    # path — iterating a path string would feed single characters to BERT.
    # Load the CSV and pass its first column as the corpus.
    # NOTE(review): assumes documents are in the first column — confirm the
    # column name against the actual train.csv schema.
    df = pd.read_csv("./train.csv")
    texts = df.iloc[:, 0].dropna().astype(str).tolist()
    model.train(texts)