戒酒的李白

Optimize model loading and prediction performance, implement the singleton pattern, and provide comprehensive error handling and error messages, along with confidence level display.
@@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer @@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer
13 from model_pro.classifier import FinalClassifier 13 from model_pro.classifier import FinalClassifier
14 from model_pro.BERT_CTM import BERT_CTM_Model 14 from model_pro.BERT_CTM import BERT_CTM_Model
15 15
16 -# 设置设备  
17 -device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  
18 - 16 +class ModelManager:
  17 + _instance = None
  18 + _initialized = False
  19 +
  20 + def __new__(cls):
  21 + if cls._instance is None:
  22 + cls._instance = super(ModelManager, cls).__new__(cls)
  23 + return cls._instance
  24 +
  25 + def __init__(self):
  26 + if not self._initialized:
  27 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  28 + self.classifier_model = None
  29 + self.attention_model = None
  30 + self.bert_ctm_model = None
  31 + self._initialized = True
  32 +
  33 + def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path):
  34 + """加载所有需要的模型"""
  35 + try:
  36 + if self.classifier_model is None:
  37 + self.classifier_model = torch.load(model_save_path, map_location=self.device)
  38 + self.classifier_model.eval()
  39 +
  40 + if self.attention_model is None:
  41 + self.attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
  42 + self.attention_model.to(self.device)
  43 + self.attention_model.eval()
  44 +
  45 + if self.bert_ctm_model is None:
  46 + self.bert_ctm_model = BERT_CTM_Model(
  47 + bert_model_path=bert_model_path,
  48 + ctm_tokenizer_path=ctm_tokenizer_path
  49 + )
  50 + return True
  51 + except Exception as e:
  52 + print(f"模型加载失败: {e}")
  53 + return False
  54 +
  55 + def predict_batch(self, texts, batch_size=32):
  56 + """批量预测文本情感"""
  57 + try:
  58 + all_predictions = []
  59 + all_probabilities = []
  60 +
  61 + # 分批处理文本
  62 + for i in range(0, len(texts), batch_size):
  63 + batch_texts = texts[i:i + batch_size]
  64 +
  65 + # 获取文本嵌入
  66 + embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts)
  67 +
  68 + # 转换为tensor
  69 + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device)
  70 + batch_x = torch.mean(batch_x, dim=1)
  71 +
  72 + with torch.no_grad():
  73 + # 使用注意力机制
  74 + attention_output = self.attention_model(batch_x, batch_x, batch_x)
  75 + # 获取分类结果
  76 + outputs = self.classifier_model(attention_output)
  77 + outputs = torch.mean(outputs, dim=1)
  78 + # 获取预测概率
  79 + probabilities = torch.softmax(outputs, dim=1)
  80 + # 获取预测标签
  81 + _, predicted = torch.max(outputs, 1)
  82 +
  83 + all_predictions.extend(predicted.cpu().numpy())
  84 + all_probabilities.extend(probabilities.cpu().numpy())
  85 +
  86 + return all_predictions, all_probabilities
  87 + except Exception as e:
  88 + print(f"预测过程中出现错误: {e}")
  89 + return None, None
  90 +
  91 +# 创建全局的模型管理器实例
  92 +model_manager = ModelManager()
19 93
20 def detect_file_encoding(file_path, num_bytes=10000): 94 def detect_file_encoding(file_path, num_bytes=10000):
21 """ 95 """
@@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ @@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
59 try: 133 try:
60 # 加载模型 134 # 加载模型
61 print("加载模型...") 135 print("加载模型...")
62 - classifier_model = torch.load(model_save_path, map_location=device)  
63 - classifier_model.eval()  
64 -  
65 - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)  
66 - attention_model.to(device)  
67 - attention_model.eval() 136 + if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path):
  137 + return False
68 138
69 # 检测文件编码 139 # 检测文件编码
70 encoding = detect_file_encoding(input_data_path) 140 encoding = detect_file_encoding(input_data_path)
@@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ @@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
88 print("开始预测...") 158 print("开始预测...")
89 with torch.no_grad(): 159 with torch.no_grad():
90 for batch in tqdm(data_loader, desc="预测进度"): 160 for batch in tqdm(data_loader, desc="预测进度"):
91 - batch_x = batch[0].to(device) 161 + batch_x = batch[0].to(model_manager.device)
92 batch_x = torch.mean(batch_x, dim=1) 162 batch_x = torch.mean(batch_x, dim=1)
93 163
94 # 使用注意力机制 164 # 使用注意力机制
95 - attention_output = attention_model(batch_x, batch_x, batch_x) 165 + attention_output = model_manager.attention_model(batch_x, batch_x, batch_x)
96 166
97 # 获取分类结果 167 # 获取分类结果
98 - outputs = classifier_model(attention_output) 168 + outputs = model_manager.classifier_model(attention_output)
99 outputs = torch.mean(outputs, dim=1) 169 outputs = torch.mean(outputs, dim=1)
100 170
101 # 获取预测概率 171 # 获取预测概率
@@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval @@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval
2 from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis 2 from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
3 from collections import Counter # Import Counter for counting occurrences 3 from collections import Counter # Import Counter for counting occurrences
4 import torch 4 import torch
5 -from model_pro.MHA import MultiHeadAttentionLayer  
6 -from model_pro.classifier import FinalClassifier  
7 -from model_pro.BERT_CTM import BERT_CTM_Model 5 +from BCAT_front.predict import model_manager
8 6
9 articleList = getAllArticleData() # Retrieve all article data 7 articleList = getAllArticleData() # Retrieve all article data
10 commentList = getAllCommentsData() # Retrieve all comment data 8 commentList = getAllCommentsData() # Retrieve all comment data
@@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data @@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data
12 # 设置设备 10 # 设置设备
13 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14 12
15 -# 加载模型(全局变量,避免重复加载) 13 +# 设置模型路径
16 model_save_path = 'model_pro/final_model.pt' 14 model_save_path = 'model_pro/final_model.pt'
17 bert_model_path = 'model_pro/bert_model' 15 bert_model_path = 'model_pro/bert_model'
18 ctm_tokenizer_path = 'model_pro/sentence_bert_model' 16 ctm_tokenizer_path = 'model_pro/sentence_bert_model'
19 17
  18 +# 初始化模型
20 try: 19 try:
21 - classifier_model = torch.load(model_save_path, map_location=device)  
22 - classifier_model.eval()  
23 - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)  
24 - attention_model.to(device)  
25 - attention_model.eval()  
26 - bert_ctm_model = BERT_CTM_Model(  
27 - bert_model_path=bert_model_path,  
28 - ctm_tokenizer_path=ctm_tokenizer_path  
29 - ) 20 + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
30 except Exception as e: 21 except Exception as e:
31 print(f"模型加载失败: {e}") 22 print(f"模型加载失败: {e}")
32 23
33 def predict_sentiment(texts): 24 def predict_sentiment(texts):
34 """使用改进版模型预测情感""" 25 """使用改进版模型预测情感"""
35 try: 26 try:
36 - # 获取文本嵌入  
37 - embeddings = bert_ctm_model.get_bert_embeddings(texts)  
38 -  
39 - # 转换为tensor  
40 - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)  
41 - batch_x = torch.mean(batch_x, dim=1)  
42 -  
43 - with torch.no_grad():  
44 - # 使用注意力机制  
45 - attention_output = attention_model(batch_x, batch_x, batch_x)  
46 - # 获取分类结果  
47 - outputs = classifier_model(attention_output)  
48 - outputs = torch.mean(outputs, dim=1)  
49 - # 获取预测标签  
50 - _, predicted = torch.max(outputs, 1)  
51 -  
52 - return predicted.cpu().numpy() 27 + predictions, probabilities = model_manager.predict_batch(texts)
  28 + if predictions is not None:
  29 + return predictions, probabilities
  30 + return None, None
53 except Exception as e: 31 except Exception as e:
54 print(f"预测过程中出现错误: {e}") 32 print(f"预测过程中出现错误: {e}")
55 - return None 33 + return None, None
56 34
57 def getTypeList(): 35 def getTypeList():
58 # Return a list of unique article types 36 # Return a list of unique article types
@@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'): @@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'):
194 article_sentiments.append('不良') 172 article_sentiments.append('不良')
195 else: 173 else:
196 # 使用改进模型 174 # 使用改进模型
197 - comment_predictions = predict_sentiment(comment_texts) 175 + comment_predictions, comment_probs = predict_sentiment(comment_texts)
198 if comment_predictions is not None: 176 if comment_predictions is not None:
199 - comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] 177 + comment_sentiments = []
  178 + for pred, prob in zip(comment_predictions, comment_probs):
  179 + label = '良好' if pred == 0 else '不良'
  180 + confidence = prob[pred]
  181 + comment_sentiments.append(f"{label} ({confidence:.2%})")
200 else: 182 else:
201 comment_sentiments = [] 183 comment_sentiments = []
202 184
203 - article_predictions = predict_sentiment(article_texts) 185 + article_predictions, article_probs = predict_sentiment(article_texts)
204 if article_predictions is not None: 186 if article_predictions is not None:
205 - article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] 187 + article_sentiments = []
  188 + for pred, prob in zip(article_predictions, article_probs):
  189 + label = '良好' if pred == 0 else '不良'
  190 + confidence = prob[pred]
  191 + article_sentiments.append(f"{label} ({confidence:.2%})")
206 else: 192 else:
207 article_sentiments = [] 193 article_sentiments = []
208 194
1 -from flask import Flask, session, render_template, redirect, Blueprint, request 1 +from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify
2 from utils.mynlp import SnowNLP 2 from utils.mynlp import SnowNLP
3 from utils.getHomePageData import * 3 from utils.getHomePageData import *
4 from utils.getHotWordPageData import * 4 from utils.getHotWordPageData import *
@@ -9,9 +9,7 @@ from utils.getTopicPageData import * @@ -9,9 +9,7 @@ from utils.getTopicPageData import *
9 from utils.yuqingpredict import * 9 from utils.yuqingpredict import *
10 from utils.logger import app_logger as logging 10 from utils.logger import app_logger as logging
11 import torch 11 import torch
12 -from model_pro.MHA import MultiHeadAttentionLayer  
13 -from model_pro.classifier import FinalClassifier  
14 -from model_pro.BERT_CTM import BERT_CTM_Model 12 +from BCAT_front.predict import model_manager
15 13
16 pb = Blueprint('page', 14 pb = Blueprint('page',
17 __name__, 15 __name__,
@@ -21,47 +19,26 @@ pb = Blueprint('page', @@ -21,47 +19,26 @@ pb = Blueprint('page',
21 # 设置设备 19 # 设置设备
22 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23 21
24 -# 加载模型(全局变量,避免重复加载) 22 +# 设置模型路径
25 model_save_path = 'model_pro/final_model.pt' 23 model_save_path = 'model_pro/final_model.pt'
26 bert_model_path = 'model_pro/bert_model' 24 bert_model_path = 'model_pro/bert_model'
27 ctm_tokenizer_path = 'model_pro/sentence_bert_model' 25 ctm_tokenizer_path = 'model_pro/sentence_bert_model'
28 26
  27 +# 初始化模型
29 try: 28 try:
30 - classifier_model = torch.load(model_save_path, map_location=device)  
31 - classifier_model.eval()  
32 - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)  
33 - attention_model.to(device)  
34 - attention_model.eval()  
35 - bert_ctm_model = BERT_CTM_Model(  
36 - bert_model_path=bert_model_path,  
37 - ctm_tokenizer_path=ctm_tokenizer_path  
38 - ) 29 + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
39 except Exception as e: 30 except Exception as e:
40 - print(f"模型加载失败: {e}") 31 + logging.error(f"模型加载失败: {e}")
41 32
42 def predict_sentiment(text): 33 def predict_sentiment(text):
43 """使用改进版模型预测单个文本的情感""" 34 """使用改进版模型预测单个文本的情感"""
44 try: 35 try:
45 - # 获取文本嵌入  
46 - embeddings = bert_ctm_model.get_bert_embeddings([text])  
47 -  
48 - # 转换为tensor  
49 - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)  
50 - batch_x = torch.mean(batch_x, dim=1)  
51 -  
52 - with torch.no_grad():  
53 - # 使用注意力机制  
54 - attention_output = attention_model(batch_x, batch_x, batch_x)  
55 - # 获取分类结果  
56 - outputs = classifier_model(attention_output)  
57 - outputs = torch.mean(outputs, dim=1)  
58 - # 获取预测标签和概率  
59 - probabilities = torch.softmax(outputs, dim=1)  
60 - _, predicted = torch.max(outputs, 1)  
61 -  
62 - return predicted.item(), probabilities[0][predicted.item()].item() 36 + predictions, probabilities = model_manager.predict_batch([text])
  37 + if predictions is not None and len(predictions) > 0:
  38 + return predictions[0], probabilities[0][predictions[0]]
  39 + return None, None
63 except Exception as e: 40 except Exception as e:
64 - print(f"预测过程中出现错误: {e}") 41 + logging.error(f"预测过程中出现错误: {e}")
65 return None, None 42 return None, None
66 43
67 @pb.route('/home') 44 @pb.route('/home')
@@ -218,46 +195,51 @@ def yuqingChar(): @@ -218,46 +195,51 @@ def yuqingChar():
218 195
219 @pb.route('/yuqingpredict') 196 @pb.route('/yuqingpredict')
220 def yuqingpredict(): 197 def yuqingpredict():
221 - username = session.get('username')  
222 - TopicList = getAllTopicData()  
223 - defaultTopic = TopicList[0][0]  
224 - if request.args.get('Topic'):  
225 - defaultTopic = request.args.get('Topic')  
226 - TopicLen = getTopicLen(defaultTopic)  
227 - X, Y = getTopicCreatedAtandpredictData(defaultTopic)  
228 -  
229 - # 获取模型选择参数  
230 - model_type = request.args.get('model', 'pro') # 默认使用改进模型  
231 -  
232 - if model_type == 'basic':  
233 - # 使用基础模型(SnowNLP)  
234 - value = SnowNLP(defaultTopic).sentiments  
235 - if value == 0.5:  
236 - sentences = '中性'  
237 - elif value > 0.5:  
238 - sentences = '正面'  
239 - elif value < 0.5:  
240 - sentences = '负面'  
241 - else:  
242 - # 使用改进模型  
243 - predicted_label, confidence = predict_sentiment(defaultTopic)  
244 - if predicted_label is not None:  
245 - sentences = '良好' if predicted_label == 0 else '不良'  
246 - sentences = f"{sentences} (置信度: {confidence:.2f})" 198 + try:
  199 + username = session.get('username')
  200 + TopicList = getAllTopicData()
  201 + defaultTopic = TopicList[0][0]
  202 + if request.args.get('Topic'):
  203 + defaultTopic = request.args.get('Topic')
  204 + TopicLen = getTopicLen(defaultTopic)
  205 + X, Y = getTopicCreatedAtandpredictData(defaultTopic)
  206 +
  207 + # 获取模型选择参数
  208 + model_type = request.args.get('model', 'pro') # 默认使用改进模型
  209 +
  210 + if model_type == 'basic':
  211 + # 使用基础模型(SnowNLP)
  212 + value = SnowNLP(defaultTopic).sentiments
  213 + if value == 0.5:
  214 + sentences = '中性'
  215 + elif value > 0.5:
  216 + sentences = '正面'
  217 + elif value < 0.5:
  218 + sentences = '负面'
247 else: 219 else:
248 - sentences = '预测失败'  
249 -  
250 - comments = getCommentFilterDataTopic(defaultTopic)  
251 - return render_template('yuqingpredict.html',  
252 - username=username,  
253 - hotWordList=TopicList,  
254 - defaultHotWord=defaultTopic,  
255 - hotWordLen=TopicLen,  
256 - sentences=sentences,  
257 - xData=X,  
258 - yData=Y,  
259 - comments=comments,  
260 - model_type=model_type) 220 + # 使用改进模型
  221 + predicted_label, confidence = predict_sentiment(defaultTopic)
  222 + if predicted_label is not None:
  223 + sentences = '良好' if predicted_label == 0 else '不良'
  224 + sentences = f"{sentences} (置信度: {confidence:.2%})"
  225 + else:
  226 + sentences = '预测失败,请稍后重试'
  227 + logging.error(f"预测失败,话题: {defaultTopic}")
  228 +
  229 + comments = getCommentFilterDataTopic(defaultTopic)
  230 + return render_template('yuqingpredict.html',
  231 + username=username,
  232 + hotWordList=TopicList,
  233 + defaultHotWord=defaultTopic,
  234 + hotWordLen=TopicLen,
  235 + sentences=sentences,
  236 + xData=X,
  237 + yData=Y,
  238 + comments=comments,
  239 + model_type=model_type)
  240 + except Exception as e:
  241 + logging.error(f"舆情预测页面渲染失败: {e}")
  242 + return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试")
261 243
262 244
263 @pb.route('/articleCloud') 245 @pb.route('/articleCloud')