Showing 1 changed file with 21 additions and 5 deletions
| @@ -4,13 +4,20 @@ import torch | @@ -4,13 +4,20 @@ import torch | ||
| 4 | from tqdm import tqdm | 4 | from tqdm import tqdm |
| 5 | import numpy as np | 5 | import numpy as np |
| 6 | import jieba | 6 | import jieba |
| 7 | +from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation | ||
| 8 | +from contextualized_topic_models.models.ctm import CombinedTM | ||
| 7 | 9 | ||
class BERT_CTM_Model:
    """Couples a BERT encoder with a contextualized topic model (CTM) pipeline."""

    def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50):
        """Load the BERT checkpoint and set up CTM data preparation.

        Args:
            bert_model_path: path or name of the pretrained BERT checkpoint
                (used for both the tokenizer and the encoder).
            ctm_tokenizer_path: contextual-embedding model identifier handed to
                ``TopicModelDataPreparation``.
            n_components: number of topics the CTM will learn.
            num_epochs: number of CTM training epochs.
        """
        # Tokenizer and encoder come from the same BERT checkpoint.
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.model = BertModel.from_pretrained(bert_model_path)

        # Helper that builds the paired BoW / contextual inputs for the CTM.
        self.tp = TopicModelDataPreparation(ctm_tokenizer_path)
        self.n_components = n_components
        self.num_epochs = num_epochs
| 14 | def get_bert_embeddings(self, texts): | 21 | def get_bert_embeddings(self, texts): |
| 15 | """使用BERT模型批量生成文本的嵌入向量""" | 22 | """使用BERT模型批量生成文本的嵌入向量""" |
| 16 | embeddings = [] | 23 | embeddings = [] |
| @@ -25,8 +32,17 @@ class BERT_CTM_Model: | @@ -25,8 +32,17 @@ class BERT_CTM_Model: | ||
| 25 | """使用jieba对中文文本进行分词""" | 32 | """使用jieba对中文文本进行分词""" |
| 26 | return " ".join(jieba.cut(text)) | 33 | return " ".join(jieba.cut(text)) |
| 27 | 34 | ||
| 35 | + def train_ctm(self, texts): | ||
| 36 | + """训练CTM模型""" | ||
| 37 | + bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 38 | + training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 39 | + | ||
| 40 | + # 训练CTM | ||
| 41 | + ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs) | ||
| 42 | + ctm.fit(training_dataset) | ||
| 43 | + print("CTM模型训练完成") | ||
| 44 | + | ||
if __name__ == "__main__":
    # Small smoke-test: train the CTM pipeline on two toy documents.
    demo_texts = ["这是第一个文本", "这是第二个文本"]
    model = BERT_CTM_Model('./bert_model', './sentence_bert_model')
    model.train_ctm(demo_texts)
-
Please register or login to post a comment