juanboy

predict.demo built

  1 +import torch
  2 +import pandas as pd
  3 +import numpy as np
  4 +from torch.utils.data import DataLoader, TensorDataset
  5 +from tqdm import tqdm
  6 +import os
  7 +import sys
  8 +import json
  9 +import chardet # 导入 chardet
  10 +
  11 +# 导入您定义的模型和模块
  12 +from MHA import MultiHeadAttentionLayer
  13 +from classifier import FinalClassifier
  14 +from BERT_CTM import BERT_CTM_Model
  15 +
  16 +# 设置设备
  17 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  18 +
  19 +
  20 +def detect_file_encoding(file_path, num_bytes=10000):
  21 + """
  22 + 使用 chardet 检测文件的编码。
  23 +
  24 + :param file_path: 文件路径
  25 + :param num_bytes: 用于检测的字节数
  26 + :return: 检测到的编码
  27 + """
  28 + with open(file_path, 'rb') as f:
  29 + rawdata = f.read(num_bytes)
  30 + result = chardet.detect(rawdata)
  31 + encoding = result['encoding']
  32 + confidence = result['confidence']
  33 + print(f"Detected encoding: {encoding} with confidence {confidence}")
  34 + return encoding
  35 +
  36 +
  37 +def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=20):
  38 + # 创建BERT_CTM_Model实例
  39 + bert_ctm_model = BERT_CTM_Model(
  40 + bert_model_path=bert_model_path,
  41 + ctm_tokenizer_path=ctm_tokenizer_path,
  42 + n_components=n_components,
  43 + num_epochs=num_epochs
  44 + )
  45 + # 加载已保存的CTM模型
  46 + bert_ctm_model.load_model()
  47 + # 获取嵌入
  48 + embeddings = bert_ctm_model.get_bert_embeddings(texts)
  49 + return embeddings
  50 +
  51 +
  52 +def prepare_dataloader(features, batch_size):
  53 + tensor_x = torch.tensor(features, dtype=torch.float32)
  54 + dataset = TensorDataset(tensor_x)
  55 + return DataLoader(dataset, batch_size=batch_size, shuffle=False)
  56 +
  57 +
  58 +def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, stats_output_path,
  59 + batch_size=128,
  60 + num_classes=2):
  61 + try:
  62 + # 加载模型
  63 + # 修改这里,设置 weights_only=True 以消除 FutureWarning
  64 + checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)
  65 + classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)
  66 + classifier_model.load_state_dict(checkpoint['classifier_model_state_dict'])
  67 + classifier_model.to(device)
  68 + classifier_model.eval()
  69 +
  70 + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
  71 + attention_model.load_state_dict(checkpoint['attention_model_state_dict'])
  72 + attention_model.to(device)
  73 + attention_model.eval()
  74 +
  75 + # 检测文件编码
  76 + encoding = detect_file_encoding(input_data_path)
  77 +
  78 + # 读取输入数据
  79 + data = pd.read_csv(input_data_path, encoding=encoding)
  80 + texts = data['TEXT'].tolist()
  81 +
  82 + # 生成嵌入
  83 + print("Generating embeddings...")
  84 + embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path)
  85 +
  86 + # 准备DataLoader
  87 + data_loader = prepare_dataloader(embeddings, batch_size)
  88 +
  89 + # 存储预测结果
  90 + all_predictions = []
  91 +
  92 + with torch.no_grad():
  93 + for batch in tqdm(data_loader, desc="Predicting"):
  94 + batch_x = batch[0].to(device)
  95 + batch_x = torch.mean(batch_x, dim=1)
  96 + attention_output = attention_model(batch_x, batch_x, batch_x)
  97 + outputs = classifier_model(attention_output)
  98 + outputs = torch.mean(outputs, dim=1)
  99 + _, predicted = torch.max(outputs, 1)
  100 + all_predictions.extend(predicted.cpu().numpy())
  101 +
  102 + # 保存预测结果
  103 + data['Predicted_Label'] = all_predictions
  104 + data.to_csv(output_path, index=False, encoding='utf-8')
  105 + print(f"Predictions saved to {output_path}")
  106 +
  107 + # 统计标签的个数和占比
  108 + label_counts = data['Predicted_Label'].value_counts()
  109 + total_count = len(data)
  110 + stats = {}
  111 + for label, count in label_counts.items():
  112 + label_name = "良好" if label == 0 else "不良"
  113 + percentage = (count / total_count) * 100
  114 + stats[label_name] = {
  115 + 'count': count,
  116 + 'percentage': f"{percentage:.2f}%"
  117 + }
  118 + print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%")
  119 +
  120 + # 将统计信息保存到 JSON 文件
  121 + with open(stats_output_path, 'w', encoding='utf-8') as f:
  122 + json.dump(stats, f, ensure_ascii=False)
  123 +
  124 + return True # 成功执行
  125 + except Exception as e:
  126 + print(f"Error during prediction: {e}")
  127 + return False # 执行失败
  128 +
  129 +
  130 +if __name__ == "__main__":
  131 + if len(sys.argv) != 3:
  132 + print("Usage: python using_example.py <input_data_path> <stats_output_path>")
  133 + sys.exit(1)
  134 +
  135 + input_data_path = sys.argv[1]
  136 + stats_output_path = sys.argv[2]
  137 + # 定义路径
  138 + model_save_path = 'BCAT/final_model.pt'
  139 + output_path = 'BCAT/predictions.csv' # 保存预测结果的文件
  140 + bert_model_path = 'BCAT/bert_model'
  141 + ctm_tokenizer_path = 'BCAT/sentence_bert_model'
  142 +
  143 + # 执行预测
  144 + success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path,
  145 + stats_output_path)
  146 +
  147 + if success:
  148 + sys.exit(0) # 成功
  149 + else:
  150 + sys.exit(1) # 失败