戒酒的李白

The integration process and a complete usage example are given below.

  1 +import torch
  2 +import numpy as np
  3 +from transformers.models.bert import BertTokenizer, BertModel
  4 +from MHA import MultiHeadAttentionLayer
  5 +from classifier import FinalClassifier
  6 +
  7 +
# Load BERT and generate embeddings
def get_sentence_embeddings(sentences, bert_model_path, max_length=80):
    """Generate BERT embeddings for a list of sentences.

    Args:
        sentences: iterable of strings to embed.
        bert_model_path: path (or hub name) of the pretrained BERT model.
        max_length: every sentence is padded/truncated to this token length.

    Returns:
        np.ndarray of shape (num_sentences, max_length, hidden_size),
        matching the layout the original per-sentence np.vstack produced.
    """
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)
    model = BertModel.from_pretrained(bert_model_path)
    model.eval()  # freeze dropout/batch-norm so inference is deterministic

    # Tokenize and embed all sentences in ONE batched forward pass instead of
    # one model invocation per sentence -- same result (fixed max_length
    # padding plus the attention mask), far less overhead.
    inputs = tokenizer(
        list(sentences),
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.cpu().numpy()
  23 +
  24 +
# Load an already-trained model from disk
def load_model(model_path):
    """Load a fully pickled torch model and switch it to evaluation mode.

    Args:
        model_path: path to a file produced by ``torch.save(model, path)``.

    Returns:
        The deserialized model with dropout/batch-norm frozen for inference.
    """
    print(f"加载模型 {model_path}...")
    # NOTE(review): torch.load unpickles arbitrary code -- only ever load
    # trusted checkpoints. weights_only=False is stated explicitly because
    # newer torch versions default it to True, which cannot restore a whole
    # pickled model object (only state dicts).
    model = torch.load(model_path, weights_only=False)
    model.eval()  # evaluation mode: disable dropout etc.
    return model
  31 +
  32 +
# Prediction over one or more sentences
def predict_sentences(sentences, model, bert_model_path, max_length=80, num_heads=12):
    """Predict a class label for each input sentence.

    Args:
        sentences: a single string or a list of strings.
        model: trained classifier applied on top of the attention output.
        bert_model_path: path of the pretrained BERT model used for embeddings.
        max_length: token length each sentence is padded/truncated to.
        num_heads: number of attention heads; must divide the embedding size
            (kept as a parameter instead of a hard-coded constant; default 12
            preserves the original behavior).

    Returns:
        list[int] of predicted label indices, one per sentence.

    Raises:
        ValueError: if the embedding size is not divisible by ``num_heads``.
    """
    # Accept a bare string for caller convenience.
    if isinstance(sentences, str):
        sentences = [sentences]

    # BERT embeddings: (num_sentences, seq_len, hidden)
    embeddings = get_sentence_embeddings(sentences, bert_model_path, max_length)

    # squeeze(1) drops a possible singleton axis left over from per-sentence
    # batching; it is a no-op when seq_len > 1.
    embedding_tensors = torch.tensor(embeddings, dtype=torch.float32).squeeze(1)

    # The attention layer requires the embedding size to split evenly
    # across heads.
    embed_size = embedding_tensors.size(-1)
    if embed_size % num_heads != 0:
        raise ValueError(f"嵌入维度 {embed_size} 无法被注意力头数量 {num_heads} 整除")

    # NOTE(review): the attention layer is constructed fresh (randomly
    # initialized) on every call -- confirm this matches how the classifier
    # was trained; loading trained attention weights may be intended.
    attention_model = MultiHeadAttentionLayer(embed_size=embed_size, num_heads=num_heads)
    attention_model.eval()  # disable any dropout so inference is deterministic

    predictions = []
    with torch.no_grad():
        for emb in embedding_tensors:
            # Self-attention: the sentence attends to itself (q = k = v),
            # with a batch dimension prepended: (1, seq_len, hidden).
            query = emb.unsqueeze(0)
            attention_output = attention_model(query, query, query)
            outputs = model(attention_output)
            outputs = torch.mean(outputs, dim=1)  # pool over sequence positions
            _, predicted = torch.max(outputs, 1)  # argmax class index
            predictions.append(predicted.item())

    return predictions
  65 +
  66 +
if __name__ == "__main__":
    # Load the previously trained classifier.
    trained_model = load_model('./final_model.pt')

    # Sentences to classify -- a single string or a list of strings both work.
    sentences = ["这是一条待预测的句子",
                 "他在你面前骂黑鬼 印度屎屁尿背后就会根人家骂你中国猴子,这可能不是种族歧视这是素质太低",
                 "完美女朋友",
                 "在美国的亚裔就是一盘散沙。日裔看不起韩裔 韩裔仇视日裔 港澳台裔看不起大陆裔,大陆裔里面又歧视福建裔"]

    # Path to the pretrained BERT model used for embeddings.
    bert_model_path = './bert_model'

    # Run prediction on every sentence.
    predicted_labels = predict_sentences(sentences, trained_model, bert_model_path)

    # Report a human-readable verdict for each predicted label.
    for sentence, label in zip(sentences, predicted_labels):
        if label == 1:
            print(f"句子: '{sentence}' 预测结果: 不良言论")
        elif label == 0:
            print(f"句子: '{sentence}' 预测结果: 正常言论")
        else:
            print(f"句子: '{sentence}' 未知标签: {label}")