Showing 1 changed file with 91 additions and 0 deletions

model_pro/using_example.py (new file, mode 0 → 100644)
import torch
import numpy as np
from transformers.models.bert import BertTokenizer, BertModel
from MHA import MultiHeadAttentionLayer
from classifier import FinalClassifier  # must be importable so torch.load can unpickle the saved model


# Load the BERT model and generate embeddings
def get_sentence_embeddings(sentences, bert_model_path, max_length=80):
    """Generate BERT embeddings for a list of sentences."""
    tokenizer = BertTokenizer.from_pretrained(bert_model_path)
    model = BertModel.from_pretrained(bert_model_path)

    embeddings = []
    for sentence in sentences:
        inputs = tokenizer(sentence, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length)
        with torch.no_grad():
            outputs = model(**inputs)
        embedding = outputs.last_hidden_state.cpu().numpy()  # shape: (1, max_length, hidden_size)
        embeddings.append(embedding)

    # Stack into a single array of shape (num_sentences, max_length, hidden_size)
    return np.vstack(embeddings)


# Load the trained model
def load_model(model_path):
    print(f"Loading model {model_path}...")
    model = torch.load(model_path)
    model.eval()  # switch to evaluation mode
    return model


# Prediction function for one or more sentences
def predict_sentences(sentences, model, bert_model_path, max_length=80):
    # Accept a single sentence as well by wrapping it in a list
    if isinstance(sentences, str):
        sentences = [sentences]

    # Generate BERT embeddings for the sentences
    embeddings = get_sentence_embeddings(sentences, bert_model_path, max_length)

    # Convert to a tensor of shape (num_sentences, max_length, hidden_size)
    embedding_tensors = torch.tensor(embeddings, dtype=torch.float32)

    # Check that the embedding dimension is compatible with the attention layer
    embed_size = embedding_tensors.size(-1)
    num_heads = 12
    if embed_size % num_heads != 0:
        raise ValueError(f"Embedding dimension {embed_size} is not divisible by the number of attention heads ({num_heads})")

    # Build the multi-head attention layer
    attention_model = MultiHeadAttentionLayer(embed_size=embed_size, num_heads=num_heads)
    attention_model.eval()  # inference only

    predictions = []
    with torch.no_grad():
        for embedding_tensor in embedding_tensors:
            # Self-attention: the sentence embedding is used as query, key and value
            attention_output = attention_model(embedding_tensor.unsqueeze(0), embedding_tensor.unsqueeze(0),
                                               embedding_tensor.unsqueeze(0))
            outputs = model(attention_output)
            outputs = torch.mean(outputs, dim=1)  # average over the sequence dimension
            _, predicted = torch.max(outputs, 1)  # take the class with the highest score
            predictions.append(predicted.item())

    return predictions


if __name__ == "__main__":
    # Load the trained model
    model_path = './final_model.pt'
    model = load_model(model_path)

    # Sentences to classify (Chinese examples); a single sentence or a list of sentences is accepted
    sentences = ["这是一条待预测的句子",
                 "他在你面前骂黑鬼 印度屎屁尿背后就会根人家骂你中国猴子,这可能不是种族歧视这是素质太低",
                 "完美女朋友",
                 "在美国的亚裔就是一盘散沙。日裔看不起韩裔 韩裔仇视日裔 港澳台裔看不起大陆裔,大陆裔里面又歧视福建裔"]

    # Path to the BERT model
    bert_model_path = './bert_model'

    # Run prediction on the sentences
    predicted_labels = predict_sentences(sentences, model, bert_model_path)

    # Print a human-readable result for each predicted label
    for i, label in enumerate(predicted_labels):
        if label == 1:
            print(f"Sentence: '{sentences[i]}' prediction: harmful speech")
        elif label == 0:
            print(f"Sentence: '{sentences[i]}' prediction: normal speech")
        else:
            print(f"Sentence: '{sentences[i]}' unknown label: {label}")
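Note: the script imports MultiHeadAttentionLayer from a local MHA module that is not part of this commit. Judging only from how it is called here (constructed with embed_size and num_heads, then invoked with separate query/key/value tensors of shape (1, max_length, hidden_size)), a minimal stand-in could look like the sketch below. The class body and the use of torch.nn.MultiheadAttention are assumptions; the real MHA.py may be implemented differently.

# Hypothetical sketch of the interface using_example.py expects from MHA.MultiHeadAttentionLayer.
import torch
import torch.nn as nn


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        # batch_first=True so inputs are (batch, seq_len, embed_size),
        # matching the (1, max_length, hidden_size) tensors the script passes in
        self.attention = nn.MultiheadAttention(embed_dim=embed_size,
                                               num_heads=num_heads,
                                               batch_first=True)

    def forward(self, query, key, value):
        # Return only the attended output, shape (batch, seq_len, embed_size);
        # the attention weights are discarded
        output, _ = self.attention(query, key, value)
        return output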
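Likewise, FinalClassifier lives in a classifier module outside this diff, and ./final_model.pt stores the trained model. The script only requires that the loaded model map a (batch, seq_len, embed_size) tensor to per-token class logits, since its output is averaged over dim=1 before torch.max. A hypothetical, shape-compatible stand-in is sketched below; the layer layout, hidden size, and class count are assumptions, not the repository's actual classifier.

# Hypothetical sketch of a classifier consistent with how using_example.py calls the loaded model.
import torch.nn as nn


class FinalClassifier(nn.Module):
    def __init__(self, embed_size=768, num_classes=2):
        super().__init__()
        # Token-wise classification head: logits per position,
        # which the script then averages over the sequence dimension
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, embed_size) -> (batch, seq_len, num_classes)
        return self.fc(x)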