戒酒的李白

Implement the get_bert_ctm_embeddings function and embedding generation and loading logic

  1 +import os
  2 +import numpy as np
  3 +from BERT_CTM import BERT_CTM_Model # 假设BERT_CTM模型在这个文件中
  4 +
  5 +# BERT_CTM 嵌入生成和加载函数
def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20, save_path=None):
    """
    Load cached BERT+CTM embeddings, or generate them and cache the result.

    If ``save_path`` is given and a cache file already exists, the embeddings
    are read from disk with ``np.load`` and no model is built. Otherwise a
    ``BERT_CTM_Model`` is constructed, ``train(texts)`` is called, and the
    result is written with ``np.save`` when ``save_path`` is given.

    ``np.save`` appends ``.npy`` to file names that lack that extension, so
    the cache lookup checks both ``save_path`` itself and ``save_path`` with
    ``.npy`` appended. Without the second check a path such as ``"cache"``
    would be written as ``cache.npy`` but never found again, forcing a
    retrain on every call.

    :param texts: texts to embed; passed unchanged to ``BERT_CTM_Model.train``
    :param bert_model_path: path of the BERT model, forwarded to the model
    :param ctm_tokenizer_path: path of the CTM tokenizer, forwarded to the model
    :param n_components: number of topics, forwarded to the model
    :param num_epochs: number of training epochs, forwarded to the model
    :param save_path: optional cache file path; a falsy value disables caching
    :return: the array loaded from the cache, or the value returned by
        ``BERT_CTM_Model.train`` (presumably an array as well -- confirm
        against the model implementation)
    """
    if save_path:
        cache_name = os.fsdecode(save_path)
        # Probe the literal path first, then the name np.save actually writes.
        candidates = [cache_name]
        if not cache_name.endswith(".npy"):
            candidates.append(cache_name + ".npy")
        for candidate in candidates:
            if os.path.exists(candidate):
                print(f"从文件 {candidate} 加载嵌入...")
                return np.load(candidate)

    print("生成 BERT+CTM 嵌入...")
    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path,
        n_components=n_components,
        num_epochs=num_epochs,
    )
    embeddings = bert_ctm_model.train(texts)

    if save_path:
        # Create the target directory up front so a finished training run is
        # not lost to a missing-directory error at save time.
        parent_dir = os.path.dirname(os.fsdecode(save_path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        print(f"保存嵌入到文件 {save_path}...")
        np.save(save_path, embeddings)

    return embeddings
  38 +
  39 +
if __name__ == "__main__":
    # Demo run: embed two short sentences, caching to sample_embeddings.npy.
    demo_texts = [
        "This is a test text.",
        "Another example of text data.",
    ]

    demo_embeddings = get_bert_ctm_embeddings(
        demo_texts,
        "./bert_model",
        "./sentence_bert_model",
        save_path="sample_embeddings.npy",
    )

    # Report the shape of the generated or loaded embeddings.
    print(f"嵌入形状: {demo_embeddings.shape}")