Optimize model loading and prediction performance, implement the singleton patte…
…rn, and provide comprehensive error handling and error messages, along with confidence level display.
Showing
3 changed files
with
158 additions
and
120 deletions
| @@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer | @@ -13,9 +13,83 @@ from model_pro.MHA import MultiHeadAttentionLayer | ||
| 13 | from model_pro.classifier import FinalClassifier | 13 | from model_pro.classifier import FinalClassifier |
| 14 | from model_pro.BERT_CTM import BERT_CTM_Model | 14 | from model_pro.BERT_CTM import BERT_CTM_Model |
| 15 | 15 | ||
| 16 | -# 设置设备 | ||
| 17 | -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 18 | - | 16 | +class ModelManager: |
| 17 | + _instance = None | ||
| 18 | + _initialized = False | ||
| 19 | + | ||
| 20 | + def __new__(cls): | ||
| 21 | + if cls._instance is None: | ||
| 22 | + cls._instance = super(ModelManager, cls).__new__(cls) | ||
| 23 | + return cls._instance | ||
| 24 | + | ||
| 25 | + def __init__(self): | ||
| 26 | + if not self._initialized: | ||
| 27 | + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 28 | + self.classifier_model = None | ||
| 29 | + self.attention_model = None | ||
| 30 | + self.bert_ctm_model = None | ||
| 31 | + self._initialized = True | ||
| 32 | + | ||
| 33 | + def load_models(self, model_save_path, bert_model_path, ctm_tokenizer_path): | ||
| 34 | + """加载所有需要的模型""" | ||
| 35 | + try: | ||
| 36 | + if self.classifier_model is None: | ||
| 37 | + self.classifier_model = torch.load(model_save_path, map_location=self.device) | ||
| 38 | + self.classifier_model.eval() | ||
| 39 | + | ||
| 40 | + if self.attention_model is None: | ||
| 41 | + self.attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 42 | + self.attention_model.to(self.device) | ||
| 43 | + self.attention_model.eval() | ||
| 44 | + | ||
| 45 | + if self.bert_ctm_model is None: | ||
| 46 | + self.bert_ctm_model = BERT_CTM_Model( | ||
| 47 | + bert_model_path=bert_model_path, | ||
| 48 | + ctm_tokenizer_path=ctm_tokenizer_path | ||
| 49 | + ) | ||
| 50 | + return True | ||
| 51 | + except Exception as e: | ||
| 52 | + print(f"模型加载失败: {e}") | ||
| 53 | + return False | ||
| 54 | + | ||
| 55 | + def predict_batch(self, texts, batch_size=32): | ||
| 56 | + """批量预测文本情感""" | ||
| 57 | + try: | ||
| 58 | + all_predictions = [] | ||
| 59 | + all_probabilities = [] | ||
| 60 | + | ||
| 61 | + # 分批处理文本 | ||
| 62 | + for i in range(0, len(texts), batch_size): | ||
| 63 | + batch_texts = texts[i:i + batch_size] | ||
| 64 | + | ||
| 65 | + # 获取文本嵌入 | ||
| 66 | + embeddings = self.bert_ctm_model.get_bert_embeddings(batch_texts) | ||
| 67 | + | ||
| 68 | + # 转换为tensor | ||
| 69 | + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(self.device) | ||
| 70 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 71 | + | ||
| 72 | + with torch.no_grad(): | ||
| 73 | + # 使用注意力机制 | ||
| 74 | + attention_output = self.attention_model(batch_x, batch_x, batch_x) | ||
| 75 | + # 获取分类结果 | ||
| 76 | + outputs = self.classifier_model(attention_output) | ||
| 77 | + outputs = torch.mean(outputs, dim=1) | ||
| 78 | + # 获取预测概率 | ||
| 79 | + probabilities = torch.softmax(outputs, dim=1) | ||
| 80 | + # 获取预测标签 | ||
| 81 | + _, predicted = torch.max(outputs, 1) | ||
| 82 | + | ||
| 83 | + all_predictions.extend(predicted.cpu().numpy()) | ||
| 84 | + all_probabilities.extend(probabilities.cpu().numpy()) | ||
| 85 | + | ||
| 86 | + return all_predictions, all_probabilities | ||
| 87 | + except Exception as e: | ||
| 88 | + print(f"预测过程中出现错误: {e}") | ||
| 89 | + return None, None | ||
| 90 | + | ||
| 91 | +# 创建全局的模型管理器实例 | ||
| 92 | +model_manager = ModelManager() | ||
| 19 | 93 | ||
| 20 | def detect_file_encoding(file_path, num_bytes=10000): | 94 | def detect_file_encoding(file_path, num_bytes=10000): |
| 21 | """ | 95 | """ |
| @@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | @@ -59,12 +133,8 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | ||
| 59 | try: | 133 | try: |
| 60 | # 加载模型 | 134 | # 加载模型 |
| 61 | print("加载模型...") | 135 | print("加载模型...") |
| 62 | - classifier_model = torch.load(model_save_path, map_location=device) | ||
| 63 | - classifier_model.eval() | ||
| 64 | - | ||
| 65 | - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 66 | - attention_model.to(device) | ||
| 67 | - attention_model.eval() | 136 | + if not model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path): |
| 137 | + return False | ||
| 68 | 138 | ||
| 69 | # 检测文件编码 | 139 | # 检测文件编码 |
| 70 | encoding = detect_file_encoding(input_data_path) | 140 | encoding = detect_file_encoding(input_data_path) |
| @@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | @@ -88,14 +158,14 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | ||
| 88 | print("开始预测...") | 158 | print("开始预测...") |
| 89 | with torch.no_grad(): | 159 | with torch.no_grad(): |
| 90 | for batch in tqdm(data_loader, desc="预测进度"): | 160 | for batch in tqdm(data_loader, desc="预测进度"): |
| 91 | - batch_x = batch[0].to(device) | 161 | + batch_x = batch[0].to(model_manager.device) |
| 92 | batch_x = torch.mean(batch_x, dim=1) | 162 | batch_x = torch.mean(batch_x, dim=1) |
| 93 | 163 | ||
| 94 | # 使用注意力机制 | 164 | # 使用注意力机制 |
| 95 | - attention_output = attention_model(batch_x, batch_x, batch_x) | 165 | + attention_output = model_manager.attention_model(batch_x, batch_x, batch_x) |
| 96 | 166 | ||
| 97 | # 获取分类结果 | 167 | # 获取分类结果 |
| 98 | - outputs = classifier_model(attention_output) | 168 | + outputs = model_manager.classifier_model(attention_output) |
| 99 | outputs = torch.mean(outputs, dim=1) | 169 | outputs = torch.mean(outputs, dim=1) |
| 100 | 170 | ||
| 101 | # 获取预测概率 | 171 | # 获取预测概率 |
| @@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval | @@ -2,9 +2,7 @@ from utils.getPublicData import * # Import utility functions for data retrieval | ||
| 2 | from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis | 2 | from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis |
| 3 | from collections import Counter # Import Counter for counting occurrences | 3 | from collections import Counter # Import Counter for counting occurrences |
| 4 | import torch | 4 | import torch |
| 5 | -from model_pro.MHA import MultiHeadAttentionLayer | ||
| 6 | -from model_pro.classifier import FinalClassifier | ||
| 7 | -from model_pro.BERT_CTM import BERT_CTM_Model | 5 | +from BCAT_front.predict import model_manager |
| 8 | 6 | ||
| 9 | articleList = getAllArticleData() # Retrieve all article data | 7 | articleList = getAllArticleData() # Retrieve all article data |
| 10 | commentList = getAllCommentsData() # Retrieve all comment data | 8 | commentList = getAllCommentsData() # Retrieve all comment data |
| @@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data | @@ -12,47 +10,27 @@ commentList = getAllCommentsData() # Retrieve all comment data | ||
| 12 | # 设置设备 | 10 | # 设置设备 |
| 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 14 | 12 | ||
| 15 | -# 加载模型(全局变量,避免重复加载) | 13 | +# 设置模型路径 |
| 16 | model_save_path = 'model_pro/final_model.pt' | 14 | model_save_path = 'model_pro/final_model.pt' |
| 17 | bert_model_path = 'model_pro/bert_model' | 15 | bert_model_path = 'model_pro/bert_model' |
| 18 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' | 16 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' |
| 19 | 17 | ||
| 18 | +# 初始化模型 | ||
| 20 | try: | 19 | try: |
| 21 | - classifier_model = torch.load(model_save_path, map_location=device) | ||
| 22 | - classifier_model.eval() | ||
| 23 | - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 24 | - attention_model.to(device) | ||
| 25 | - attention_model.eval() | ||
| 26 | - bert_ctm_model = BERT_CTM_Model( | ||
| 27 | - bert_model_path=bert_model_path, | ||
| 28 | - ctm_tokenizer_path=ctm_tokenizer_path | ||
| 29 | - ) | 20 | + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) |
| 30 | except Exception as e: | 21 | except Exception as e: |
| 31 | print(f"模型加载失败: {e}") | 22 | print(f"模型加载失败: {e}") |
| 32 | 23 | ||
| 33 | def predict_sentiment(texts): | 24 | def predict_sentiment(texts): |
| 34 | """使用改进版模型预测情感""" | 25 | """使用改进版模型预测情感""" |
| 35 | try: | 26 | try: |
| 36 | - # 获取文本嵌入 | ||
| 37 | - embeddings = bert_ctm_model.get_bert_embeddings(texts) | ||
| 38 | - | ||
| 39 | - # 转换为tensor | ||
| 40 | - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) | ||
| 41 | - batch_x = torch.mean(batch_x, dim=1) | ||
| 42 | - | ||
| 43 | - with torch.no_grad(): | ||
| 44 | - # 使用注意力机制 | ||
| 45 | - attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 46 | - # 获取分类结果 | ||
| 47 | - outputs = classifier_model(attention_output) | ||
| 48 | - outputs = torch.mean(outputs, dim=1) | ||
| 49 | - # 获取预测标签 | ||
| 50 | - _, predicted = torch.max(outputs, 1) | ||
| 51 | - | ||
| 52 | - return predicted.cpu().numpy() | 27 | + predictions, probabilities = model_manager.predict_batch(texts) |
| 28 | + if predictions is not None: | ||
| 29 | + return predictions, probabilities | ||
| 30 | + return None, None | ||
| 53 | except Exception as e: | 31 | except Exception as e: |
| 54 | print(f"预测过程中出现错误: {e}") | 32 | print(f"预测过程中出现错误: {e}") |
| 55 | - return None | 33 | + return None, None |
| 56 | 34 | ||
| 57 | def getTypeList(): | 35 | def getTypeList(): |
| 58 | # Return a list of unique article types | 36 | # Return a list of unique article types |
| @@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'): | @@ -194,15 +172,23 @@ def getYuQingCharDataTwo(model_type='pro'): | ||
| 194 | article_sentiments.append('不良') | 172 | article_sentiments.append('不良') |
| 195 | else: | 173 | else: |
| 196 | # 使用改进模型 | 174 | # 使用改进模型 |
| 197 | - comment_predictions = predict_sentiment(comment_texts) | 175 | + comment_predictions, comment_probs = predict_sentiment(comment_texts) |
| 198 | if comment_predictions is not None: | 176 | if comment_predictions is not None: |
| 199 | - comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] | 177 | + comment_sentiments = [] |
| 178 | + for pred, prob in zip(comment_predictions, comment_probs): | ||
| 179 | + label = '良好' if pred == 0 else '不良' | ||
| 180 | + confidence = prob[pred] | ||
| 181 | + comment_sentiments.append(f"{label} ({confidence:.2%})") | ||
| 200 | else: | 182 | else: |
| 201 | comment_sentiments = [] | 183 | comment_sentiments = [] |
| 202 | 184 | ||
| 203 | - article_predictions = predict_sentiment(article_texts) | 185 | + article_predictions, article_probs = predict_sentiment(article_texts) |
| 204 | if article_predictions is not None: | 186 | if article_predictions is not None: |
| 205 | - article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] | 187 | + article_sentiments = [] |
| 188 | + for pred, prob in zip(article_predictions, article_probs): | ||
| 189 | + label = '良好' if pred == 0 else '不良' | ||
| 190 | + confidence = prob[pred] | ||
| 191 | + article_sentiments.append(f"{label} ({confidence:.2%})") | ||
| 206 | else: | 192 | else: |
| 207 | article_sentiments = [] | 193 | article_sentiments = [] |
| 208 | 194 |
| 1 | -from flask import Flask, session, render_template, redirect, Blueprint, request | 1 | +from flask import Flask, session, render_template, redirect, Blueprint, request, jsonify |
| 2 | from utils.mynlp import SnowNLP | 2 | from utils.mynlp import SnowNLP |
| 3 | from utils.getHomePageData import * | 3 | from utils.getHomePageData import * |
| 4 | from utils.getHotWordPageData import * | 4 | from utils.getHotWordPageData import * |
| @@ -9,9 +9,7 @@ from utils.getTopicPageData import * | @@ -9,9 +9,7 @@ from utils.getTopicPageData import * | ||
| 9 | from utils.yuqingpredict import * | 9 | from utils.yuqingpredict import * |
| 10 | from utils.logger import app_logger as logging | 10 | from utils.logger import app_logger as logging |
| 11 | import torch | 11 | import torch |
| 12 | -from model_pro.MHA import MultiHeadAttentionLayer | ||
| 13 | -from model_pro.classifier import FinalClassifier | ||
| 14 | -from model_pro.BERT_CTM import BERT_CTM_Model | 12 | +from BCAT_front.predict import model_manager |
| 15 | 13 | ||
| 16 | pb = Blueprint('page', | 14 | pb = Blueprint('page', |
| 17 | __name__, | 15 | __name__, |
| @@ -21,47 +19,26 @@ pb = Blueprint('page', | @@ -21,47 +19,26 @@ pb = Blueprint('page', | ||
| 21 | # 设置设备 | 19 | # 设置设备 |
| 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 23 | 21 | ||
| 24 | -# 加载模型(全局变量,避免重复加载) | 22 | +# 设置模型路径 |
| 25 | model_save_path = 'model_pro/final_model.pt' | 23 | model_save_path = 'model_pro/final_model.pt' |
| 26 | bert_model_path = 'model_pro/bert_model' | 24 | bert_model_path = 'model_pro/bert_model' |
| 27 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' | 25 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' |
| 28 | 26 | ||
| 27 | +# 初始化模型 | ||
| 29 | try: | 28 | try: |
| 30 | - classifier_model = torch.load(model_save_path, map_location=device) | ||
| 31 | - classifier_model.eval() | ||
| 32 | - attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 33 | - attention_model.to(device) | ||
| 34 | - attention_model.eval() | ||
| 35 | - bert_ctm_model = BERT_CTM_Model( | ||
| 36 | - bert_model_path=bert_model_path, | ||
| 37 | - ctm_tokenizer_path=ctm_tokenizer_path | ||
| 38 | - ) | 29 | + model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) |
| 39 | except Exception as e: | 30 | except Exception as e: |
| 40 | - print(f"模型加载失败: {e}") | 31 | + logging.error(f"模型加载失败: {e}") |
| 41 | 32 | ||
| 42 | def predict_sentiment(text): | 33 | def predict_sentiment(text): |
| 43 | """使用改进版模型预测单个文本的情感""" | 34 | """使用改进版模型预测单个文本的情感""" |
| 44 | try: | 35 | try: |
| 45 | - # 获取文本嵌入 | ||
| 46 | - embeddings = bert_ctm_model.get_bert_embeddings([text]) | ||
| 47 | - | ||
| 48 | - # 转换为tensor | ||
| 49 | - batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) | ||
| 50 | - batch_x = torch.mean(batch_x, dim=1) | ||
| 51 | - | ||
| 52 | - with torch.no_grad(): | ||
| 53 | - # 使用注意力机制 | ||
| 54 | - attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 55 | - # 获取分类结果 | ||
| 56 | - outputs = classifier_model(attention_output) | ||
| 57 | - outputs = torch.mean(outputs, dim=1) | ||
| 58 | - # 获取预测标签和概率 | ||
| 59 | - probabilities = torch.softmax(outputs, dim=1) | ||
| 60 | - _, predicted = torch.max(outputs, 1) | ||
| 61 | - | ||
| 62 | - return predicted.item(), probabilities[0][predicted.item()].item() | 36 | + predictions, probabilities = model_manager.predict_batch([text]) |
| 37 | + if predictions is not None and len(predictions) > 0: | ||
| 38 | + return predictions[0], probabilities[0][predictions[0]] | ||
| 39 | + return None, None | ||
| 63 | except Exception as e: | 40 | except Exception as e: |
| 64 | - print(f"预测过程中出现错误: {e}") | 41 | + logging.error(f"预测过程中出现错误: {e}") |
| 65 | return None, None | 42 | return None, None |
| 66 | 43 | ||
| 67 | @pb.route('/home') | 44 | @pb.route('/home') |
| @@ -218,46 +195,51 @@ def yuqingChar(): | @@ -218,46 +195,51 @@ def yuqingChar(): | ||
| 218 | 195 | ||
| 219 | @pb.route('/yuqingpredict') | 196 | @pb.route('/yuqingpredict') |
| 220 | def yuqingpredict(): | 197 | def yuqingpredict(): |
| 221 | - username = session.get('username') | ||
| 222 | - TopicList = getAllTopicData() | ||
| 223 | - defaultTopic = TopicList[0][0] | ||
| 224 | - if request.args.get('Topic'): | ||
| 225 | - defaultTopic = request.args.get('Topic') | ||
| 226 | - TopicLen = getTopicLen(defaultTopic) | ||
| 227 | - X, Y = getTopicCreatedAtandpredictData(defaultTopic) | ||
| 228 | - | ||
| 229 | - # 获取模型选择参数 | ||
| 230 | - model_type = request.args.get('model', 'pro') # 默认使用改进模型 | ||
| 231 | - | ||
| 232 | - if model_type == 'basic': | ||
| 233 | - # 使用基础模型(SnowNLP) | ||
| 234 | - value = SnowNLP(defaultTopic).sentiments | ||
| 235 | - if value == 0.5: | ||
| 236 | - sentences = '中性' | ||
| 237 | - elif value > 0.5: | ||
| 238 | - sentences = '正面' | ||
| 239 | - elif value < 0.5: | ||
| 240 | - sentences = '负面' | ||
| 241 | - else: | ||
| 242 | - # 使用改进模型 | ||
| 243 | - predicted_label, confidence = predict_sentiment(defaultTopic) | ||
| 244 | - if predicted_label is not None: | ||
| 245 | - sentences = '良好' if predicted_label == 0 else '不良' | ||
| 246 | - sentences = f"{sentences} (置信度: {confidence:.2f})" | 198 | + try: |
| 199 | + username = session.get('username') | ||
| 200 | + TopicList = getAllTopicData() | ||
| 201 | + defaultTopic = TopicList[0][0] | ||
| 202 | + if request.args.get('Topic'): | ||
| 203 | + defaultTopic = request.args.get('Topic') | ||
| 204 | + TopicLen = getTopicLen(defaultTopic) | ||
| 205 | + X, Y = getTopicCreatedAtandpredictData(defaultTopic) | ||
| 206 | + | ||
| 207 | + # 获取模型选择参数 | ||
| 208 | + model_type = request.args.get('model', 'pro') # 默认使用改进模型 | ||
| 209 | + | ||
| 210 | + if model_type == 'basic': | ||
| 211 | + # 使用基础模型(SnowNLP) | ||
| 212 | + value = SnowNLP(defaultTopic).sentiments | ||
| 213 | + if value == 0.5: | ||
| 214 | + sentences = '中性' | ||
| 215 | + elif value > 0.5: | ||
| 216 | + sentences = '正面' | ||
| 217 | + elif value < 0.5: | ||
| 218 | + sentences = '负面' | ||
| 247 | else: | 219 | else: |
| 248 | - sentences = '预测失败' | ||
| 249 | - | ||
| 250 | - comments = getCommentFilterDataTopic(defaultTopic) | ||
| 251 | - return render_template('yuqingpredict.html', | ||
| 252 | - username=username, | ||
| 253 | - hotWordList=TopicList, | ||
| 254 | - defaultHotWord=defaultTopic, | ||
| 255 | - hotWordLen=TopicLen, | ||
| 256 | - sentences=sentences, | ||
| 257 | - xData=X, | ||
| 258 | - yData=Y, | ||
| 259 | - comments=comments, | ||
| 260 | - model_type=model_type) | 220 | + # 使用改进模型 |
| 221 | + predicted_label, confidence = predict_sentiment(defaultTopic) | ||
| 222 | + if predicted_label is not None: | ||
| 223 | + sentences = '良好' if predicted_label == 0 else '不良' | ||
| 224 | + sentences = f"{sentences} (置信度: {confidence:.2%})" | ||
| 225 | + else: | ||
| 226 | + sentences = '预测失败,请稍后重试' | ||
| 227 | + logging.error(f"预测失败,话题: {defaultTopic}") | ||
| 228 | + | ||
| 229 | + comments = getCommentFilterDataTopic(defaultTopic) | ||
| 230 | + return render_template('yuqingpredict.html', | ||
| 231 | + username=username, | ||
| 232 | + hotWordList=TopicList, | ||
| 233 | + defaultHotWord=defaultTopic, | ||
| 234 | + hotWordLen=TopicLen, | ||
| 235 | + sentences=sentences, | ||
| 236 | + xData=X, | ||
| 237 | + yData=Y, | ||
| 238 | + comments=comments, | ||
| 239 | + model_type=model_type) | ||
| 240 | + except Exception as e: | ||
| 241 | + logging.error(f"舆情预测页面渲染失败: {e}") | ||
| 242 | + return render_template('error.html', error_message="加载舆情预测页面失败,请稍后重试") | ||
| 261 | 243 | ||
| 262 | 244 | ||
| 263 | @pb.route('/articleCloud') | 245 | @pb.route('/articleCloud') |
-
Please register or login to post a comment