Implement the get_bert_ctm_embeddings function and embedding generation and loading logic
Showing
1 changed file
with
51 additions
and
0 deletions
model_pro/BCAT.py
0 → 100644
| 1 | +import os | ||
| 2 | +import numpy as np | ||
| 3 | +from BERT_CTM import BERT_CTM_Model # 假设BERT_CTM模型在这个文件中 | ||
| 4 | + | ||
| 5 | +# BERT_CTM 嵌入生成和加载函数 | ||
| 6 | +def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None): | ||
| 7 | + """ | ||
| 8 | + 获取或生成 BERT+CTM 嵌入,并保存到文件。 | ||
| 9 | + | ||
| 10 | + :param texts: 需要嵌入的文本 | ||
| 11 | + :param bert_model_path: BERT 模型的路径 | ||
| 12 | + :param ctm_tokenizer_path: CTM tokenizer 的路径 | ||
| 13 | + :param n_components: 生成的主题数量 | ||
| 14 | + :param num_epochs: 训练的epoch数 | ||
| 15 | + :param save_path: 嵌入保存路径 | ||
| 16 | + :return: 生成或加载的嵌入 | ||
| 17 | + """ | ||
| 18 | + # 检查是否已经存在保存的嵌入文件 | ||
| 19 | + if save_path and os.path.exists(save_path): | ||
| 20 | + print(f"从文件 {save_path} 加载嵌入...") | ||
| 21 | + embeddings = np.load(save_path) | ||
| 22 | + else: | ||
| 23 | + print("生成 BERT+CTM 嵌入...") | ||
| 24 | + bert_ctm_model = BERT_CTM_Model( | ||
| 25 | + bert_model_path=bert_model_path, | ||
| 26 | + ctm_tokenizer_path=ctm_tokenizer_path, | ||
| 27 | + n_components=n_components, | ||
| 28 | + num_epochs=num_epochs | ||
| 29 | + ) | ||
| 30 | + embeddings = bert_ctm_model.train(texts) # 生成嵌入 | ||
| 31 | + | ||
| 32 | + # 保存嵌入到文件 | ||
| 33 | + if save_path: | ||
| 34 | + print(f"保存嵌入到文件 {save_path}...") | ||
| 35 | + np.save(save_path, embeddings) | ||
| 36 | + | ||
| 37 | + return embeddings | ||
| 38 | + | ||
| 39 | + | ||
| 40 | +if __name__ == "__main__": | ||
| 41 | + # 示例调用 | ||
| 42 | + sample_texts = ["This is a test text.", "Another example of text data."] | ||
| 43 | + bert_model_path = './bert_model' | ||
| 44 | + ctm_tokenizer_path = './sentence_bert_model' | ||
| 45 | + save_path = 'sample_embeddings.npy' | ||
| 46 | + | ||
| 47 | + # 生成或加载 BERT+CTM 嵌入 | ||
| 48 | + embeddings = get_bert_ctm_embeddings(sample_texts, bert_model_path, ctm_tokenizer_path, save_path=save_path) | ||
| 49 | + | ||
| 50 | + # 打印嵌入形状 | ||
| 51 | + print(f"嵌入形状: {embeddings.shape}") |
-
Please register or login to post a comment