戒酒的李白

Enhanced the public opinion prediction system by integrating an LSTM-based sentiment model.

  1 +import torch
  2 +import torch.nn as nn
  3 +import torch.optim as optim
  4 +from torch.utils.data import Dataset, DataLoader
  5 +import numpy as np
  6 +import pandas as pd
  7 +from sklearn.model_selection import train_test_split
  8 +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
  9 +import jieba
  10 +from transformers import BertTokenizer
  11 +import logging
  12 +import os
  13 +
# Configure module-level logging (timestamped, INFO and above).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('LSTM_model')
  17 +
class TextDataset(Dataset):
    """Map-style dataset that tokenizes raw texts lazily on item access.

    Each item is a dict with the original text, padded/truncated BERT
    input ids and attention mask (both flattened to 1-D), and the label
    as a long tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample_text = str(self.texts[idx])

        # Tokenize with special tokens, padding to max_length and truncating
        # longer inputs; tensors come back batched, hence the flatten below.
        encoded = self.tokenizer.encode_plus(
            sample_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'text': sample_text,
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }
  51 +
class LSTMSentimentModel(nn.Module):
    """Sentiment classifier: embedding -> (bi)LSTM -> dropout -> linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2,
                 bidirectional=True, dropout=0.5, pad_idx=0):
        super().__init__()

        # Token embeddings; the pad index stays a zero vector.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        # Inter-layer dropout only makes sense with more than one LSTM layer.
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True
        )

        # Classification head; a bidirectional LSTM doubles the feature width.
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text, attention_mask=None):
        """Classify a batch of token-id sequences.

        Args:
            text: [batch, seq_len] token ids.
            attention_mask: optional [batch, seq_len] 0/1 mask; when given,
                sequences are packed so padding does not influence the
                final hidden state.

        Returns:
            [batch, output_dim] unnormalized class scores.
        """
        embedded = self.dropout(self.embedding(text))

        if attention_mask is None:
            _, (hidden, _) = self.lstm(embedded)
        else:
            # Real (unpadded) length of each sequence; lengths must live on CPU.
            seq_lengths = attention_mask.sum(dim=1).to('cpu')
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, seq_lengths, batch_first=True, enforce_sorted=False
            )
            packed_out, (hidden, _) = self.lstm(packed)
            # Unpack the per-step outputs (kept for parity; only `hidden` is
            # consumed below).
            unpacked, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Final representation: concatenate the last layer's forward and
        # backward states for a bidirectional LSTM, else take the last layer.
        if self.lstm.bidirectional:
            final_state = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            final_state = hidden[-1]

        return self.fc(self.dropout(final_state))
  109 +
class LSTMModelManager:
    """Manage the LSTM sentiment model: training, evaluation and prediction.

    A BERT tokenizer is used only for tokenization; the classifier itself is
    the pure-LSTM ``LSTMSentimentModel``. All tensors are moved to GPU when
    one is available.
    """

    def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
                 embedding_dim=128, hidden_dim=256, output_dim=2, n_layers=2,
                 bidirectional=True, dropout=0.5):
        # Prefer GPU when available, otherwise fall back to CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
        self.vocab_size = vocab_size
        self.model = LSTMSentimentModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            n_layers=n_layers,
            bidirectional=bidirectional,
            dropout=dropout,
            pad_idx=self.tokenizer.pad_token_id
        ).to(self.device)

        self.model_save_path = model_save_path
        # Restore previously trained weights if a checkpoint already exists.
        if model_save_path and os.path.exists(model_save_path):
            self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
            logger.info(f"已从 {model_save_path} 加载模型")

    def train(self, train_texts, train_labels, val_texts=None, val_labels=None,
              batch_size=32, learning_rate=2e-5, epochs=10, validation_split=0.2):
        """Train the model.

        Parameters:
            train_texts, train_labels: training samples and integer labels.
            val_texts, val_labels: optional validation set; when omitted, a
                ``validation_split`` fraction of the training data is held out.
            batch_size, learning_rate, epochs: optimization hyperparameters.
            validation_split: hold-out fraction used only when no validation
                set is supplied.

        Returns:
            (train_loss, val_loss, val_accuracy, val_f1) of the LAST epoch
            (not necessarily the best one; the best checkpoint is what gets
            saved to ``model_save_path``).
        """
        logger.info("开始训练模型...")

        # Carve a validation set out of the training data when none is given.
        if val_texts is None or val_labels is None:
            train_texts, val_texts, train_labels, val_labels = train_test_split(
                train_texts, train_labels, test_size=validation_split, random_state=42
            )

        # Build datasets and loaders.
        train_dataset = TextDataset(train_texts, train_labels, self.tokenizer)
        val_dataset = TextDataset(val_texts, val_labels, self.tokenizer)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

        # Optimizer and loss function.
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Training loop; the checkpoint with the lowest validation loss wins.
        best_val_loss = float('inf')
        for epoch in range(epochs):
            # --- training pass ---
            self.model.train()
            train_loss = 0
            train_preds = []
            train_labels_list = []

            for batch in train_dataloader:
                # Move the batch to the active device.
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                # Forward pass.
                optimizer.zero_grad()
                outputs = self.model(input_ids, attention_mask)

                # Loss (accumulated per batch; averaged after the loop).
                loss = criterion(outputs, labels)
                train_loss += loss.item()

                # Backward pass and parameter update.
                loss.backward()
                optimizer.step()

                # Collect predictions and gold labels for epoch-level metrics.
                _, predicted = torch.max(outputs, 1)
                train_preds.extend(predicted.cpu().numpy())
                train_labels_list.extend(labels.cpu().numpy())

            # Epoch-level training metrics.
            train_accuracy = accuracy_score(train_labels_list, train_preds)
            train_f1 = f1_score(train_labels_list, train_preds, average='macro')

            # --- validation pass (no gradients) ---
            self.model.eval()
            val_loss = 0
            val_preds = []
            val_labels_list = []

            with torch.no_grad():
                for batch in val_dataloader:
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['label'].to(self.device)

                    outputs = self.model(input_ids, attention_mask)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

                    _, predicted = torch.max(outputs, 1)
                    val_preds.extend(predicted.cpu().numpy())
                    val_labels_list.extend(labels.cpu().numpy())

            # Epoch-level validation metrics.
            val_accuracy = accuracy_score(val_labels_list, val_preds)
            val_f1 = f1_score(val_labels_list, val_preds, average='macro')

            # Convert accumulated sums into per-batch averages.
            train_loss /= len(train_dataloader)
            val_loss /= len(val_dataloader)

            logger.info(f'Epoch {epoch+1}/{epochs} | '
                        f'Train Loss: {train_loss:.4f} | '
                        f'Train Acc: {train_accuracy:.4f} | '
                        f'Train F1: {train_f1:.4f} | '
                        f'Val Loss: {val_loss:.4f} | '
                        f'Val Acc: {val_accuracy:.4f} | '
                        f'Val F1: {val_f1:.4f}')

            # Persist the best-so-far model (by validation loss).
            if val_loss < best_val_loss and self.model_save_path:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.model_save_path)
                logger.info(f"模型已保存到 {self.model_save_path}")

        # Safety net: if a save path was given but nothing was ever saved
        # (only possible when epochs == 0), save the current weights.
        if self.model_save_path and best_val_loss == float('inf'):
            torch.save(self.model.state_dict(), self.model_save_path)
            logger.info(f"最终模型已保存到 {self.model_save_path}")

        return train_loss, val_loss, val_accuracy, val_f1

    def evaluate(self, test_texts, test_labels, batch_size=32):
        """Evaluate the model on a labelled test set.

        Returns a dict with the average loss, accuracy/precision/recall/F1
        (macro-averaged), confusion matrix, per-sample predictions and
        per-class probabilities.
        """
        logger.info("评估模型...")

        # Build the test dataset and loader.
        test_dataset = TextDataset(test_texts, test_labels, self.tokenizer)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        # Switch to evaluation mode (disables dropout).
        self.model.eval()

        criterion = nn.CrossEntropyLoss()
        test_loss = 0
        test_preds = []
        test_probs = []
        test_labels_list = []

        with torch.no_grad():
            for batch in test_dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                # Softmax for class probabilities; argmax for the prediction.
                probs = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                test_preds.extend(predicted.cpu().numpy())
                test_probs.extend(probs.cpu().numpy())
                test_labels_list.extend(labels.cpu().numpy())

        # Average loss over batches.
        test_loss /= len(test_dataloader)

        # Aggregate evaluation metrics.
        accuracy = accuracy_score(test_labels_list, test_preds)
        precision = precision_score(test_labels_list, test_preds, average='macro')
        recall = recall_score(test_labels_list, test_preds, average='macro')
        f1 = f1_score(test_labels_list, test_preds, average='macro')
        conf_matrix = confusion_matrix(test_labels_list, test_preds)

        logger.info(f'Test Loss: {test_loss:.4f}')
        logger.info(f'Accuracy: {accuracy:.4f}')
        logger.info(f'Precision: {precision:.4f}')
        logger.info(f'Recall: {recall:.4f}')
        logger.info(f'F1 Score: {f1:.4f}')
        logger.info(f'Confusion Matrix:\n{conf_matrix}')

        return {
            'loss': test_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': conf_matrix,
            'predictions': test_preds,
            'probabilities': test_probs
        }

    def predict_batch(self, texts, batch_size=32):
        """Predict sentiment for a batch of texts.

        Accepts a list of strings (a bare string is wrapped into a list).
        Returns (predictions, probabilities) as lists, or (None, None) for
        empty input.
        """
        if not texts:
            return None, None

        # Normalize a single string to a one-element list.
        if isinstance(texts, str):
            texts = [texts]

        # Dataset requires labels; use zeros as placeholders (never read).
        dummy_labels = [0] * len(texts)
        dataset = TextDataset(texts, dummy_labels, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        # Switch to evaluation mode (disables dropout).
        self.model.eval()

        all_preds = []
        all_probs = []

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                probs = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                all_preds.extend(predicted.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        return all_preds, all_probs

    def predict(self, text):
        """Predict sentiment for a single text.

        Returns (predicted_label, class_probabilities) or (None, None) when
        prediction is not possible.
        """
        predictions, probabilities = self.predict_batch([text])
        if predictions is not None and len(predictions) > 0:
            return predictions[0], probabilities[0]
        return None, None
  345 +
# Module-level singleton: constructing it loads the tokenizer from disk and,
# when a checkpoint exists at model_save_path, restores the saved weights.
lstm_model_manager = LSTMModelManager(
    bert_model_path='model_pro/bert_model',
    model_save_path='model_pro/lstm_model.pt'
)
  351 +
# Smoke test: train on the CSV splits, evaluate, then run demo predictions.
if __name__ == "__main__":
    # Load the pre-split datasets.
    train_data = pd.read_csv('model_pro/train.csv')
    dev_data = pd.read_csv('model_pro/dev.csv')
    test_data = pd.read_csv('model_pro/test.csv')

    # Pull out text/label columns as arrays.
    train_texts, train_labels = train_data['text'].values, train_data['label'].values
    dev_texts, dev_labels = dev_data['text'].values, dev_data['label'].values
    test_texts, test_labels = test_data['text'].values, test_data['label'].values

    # Train with the dev split as the validation set.
    lstm_model_manager.train(
        train_texts, train_labels,
        val_texts=dev_texts, val_labels=dev_labels,
        batch_size=32, epochs=5
    )

    # Score on the held-out test split.
    results = lstm_model_manager.evaluate(test_texts, test_labels)

    # Sanity-check predictions on a few sample sentences.
    test_sentences = [
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    ]
    for sentence in test_sentences:
        pred, prob = lstm_model_manager.predict(sentence)
        label = '良好' if pred == 0 else '不良'
        confidence = prob[pred]
        print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
  1 +import torch
  2 +import os
  3 +import logging
  4 +from LSTM_model import lstm_model_manager
  5 +
# Configure module-level logging (timestamped, INFO and above).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('lstm_predict')
  9 +
class LSTMPredictor:
    """LSTM predictor exposing the same interface as the system's model_manager.

    Thin wrapper around the shared ``lstm_model_manager`` singleton; adds
    path checks, input validation and error logging around its calls.
    """

    def __init__(self):
        # Weights are considered loaded only after load_models()/train_model().
        self.model_loaded = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"初始化LSTM预测器,使用设备: {self.device}")

    def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
        """Verify model availability; mirrors model_manager.load_models.

        Args:
            model_save_path: path of the saved LSTM weights.
            bert_model_path: path of the BERT model directory.
            tokenizer_path: unused (the LSTM reuses BERT's tokenizer).

        Returns:
            True when the shared manager holds a model, False otherwise.
        """
        try:
            # Guard: required files must exist on disk.
            if not os.path.exists(model_save_path):
                logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
                return False
            if not os.path.exists(bert_model_path):
                logger.error(f"BERT模型路径 {bert_model_path} 不存在")
                return False

            # The weights were already restored when lstm_model_manager was
            # constructed; here we only confirm the model object exists.
            if lstm_model_manager.model is None:
                logger.error("LSTM模型加载失败")
                return False

            self.model_loaded = True
            logger.info("LSTM模型已加载成功")
            return True
        except Exception as e:
            logger.error(f"加载模型过程中出错: {e}")
            return False

    def predict_batch(self, texts):
        """Predict sentiment for a batch of texts.

        Args:
            texts: list of text strings.

        Returns:
            (predictions, probabilities) where 0 = good, 1 = bad, or
            (None, None) when the model is unavailable / input is empty /
            prediction fails.
        """
        if lstm_model_manager.model is None and not self.model_loaded:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if not texts:
            logger.warning("未提供文本,无法进行预测")
            return None, None

        try:
            # Delegate to the shared manager's batch predictor.
            return lstm_model_manager.predict_batch(texts)
        except Exception as e:
            logger.error(f"预测过程中出错: {e}")
            return None, None

    def predict(self, text):
        """Predict sentiment for a single text.

        Args:
            text: text string.

        Returns:
            (prediction, probability) where 0 = good, 1 = bad, or
            (None, None) when the model is unavailable / text is blank /
            prediction fails.
        """
        if lstm_model_manager.model is None and not self.model_loaded:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if not text or len(text.strip()) == 0:
            logger.warning("未提供文本或文本为空,无法进行预测")
            return None, None

        try:
            # Delegate to the shared manager's single-text predictor.
            return lstm_model_manager.predict(text)
        except Exception as e:
            logger.error(f"预测过程中出错: {e}")
            return None, None

    def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
                    batch_size=32, learning_rate=2e-5, epochs=10):
        """Train the underlying LSTM model.

        Args:
            train_texts, train_labels: training samples and labels.
            val_texts, val_labels: optional validation set.
            batch_size, learning_rate, epochs: optimization hyperparameters.

        Returns:
            The manager's training result tuple, or None on error.
        """
        try:
            outcome = lstm_model_manager.train(
                train_texts, train_labels, val_texts, val_labels,
                batch_size, learning_rate, epochs
            )
        except Exception as e:
            logger.error(f"训练模型过程中出错: {e}")
            return None
        self.model_loaded = True
        return outcome
  130 +
# Global predictor instance shared by the module-level helpers below.
lstm_predictor = LSTMPredictor()

# Compatibility shim: same predict_batch signature as model_manager.
def predict_batch(texts):
    return lstm_predictor.predict_batch(texts)

# Compatibility shim: same load_models signature as model_manager.
def load_models(model_save_path, bert_model_path, tokenizer_path=None):
    return lstm_predictor.load_models(model_save_path, bert_model_path, tokenizer_path)
  141 +
# Smoke test: load the saved model and run a few demo predictions.
if __name__ == "__main__":
    # Restore model weights before predicting.
    load_models(
        model_save_path="model_pro/lstm_model.pt",
        bert_model_path="model_pro/bert_model"
    )

    # Sample sentences covering positive, negative and mixed sentiment.
    test_sentences = [
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    ]

    for sentence in test_sentences:
        pred, prob = lstm_predictor.predict(sentence)
        if pred is None:
            print(f"句子: '{sentence}' 预测失败")
        else:
            label = '良好' if pred == 0 else '不良'
            confidence = prob[pred]
            print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
@@ -20,6 +20,7 @@ from functools import wraps @@ -20,6 +20,7 @@ from functools import wraps
20 import bleach 20 import bleach
21 import re 21 import re
22 from datetime import datetime, timedelta 22 from datetime import datetime, timedelta
  23 +from model_pro.lstm_predict import lstm_predictor
23 24
24 pb = Blueprint('page', 25 pb = Blueprint('page',
25 __name__, 26 __name__,
@@ -75,12 +76,15 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -75,12 +76,15 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
75 76
76 # 设置模型路径 77 # 设置模型路径
77 model_save_path = 'model_pro/final_model.pt' 78 model_save_path = 'model_pro/final_model.pt'
  79 +lstm_model_path = 'model_pro/lstm_model.pt'
78 bert_model_path = 'model_pro/bert_model' 80 bert_model_path = 'model_pro/bert_model'
79 ctm_tokenizer_path = 'model_pro/sentence_bert_model' 81 ctm_tokenizer_path = 'model_pro/sentence_bert_model'
80 82
81 # 初始化模型 83 # 初始化模型
82 try: 84 try:
83 model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) 85 model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path)
  86 + # 同时初始化LSTM模型
  87 + lstm_predictor.load_models(lstm_model_path, bert_model_path)
84 except Exception as e: 88 except Exception as e:
85 logging.error(f"模型加载失败: {e}") 89 logging.error(f"模型加载失败: {e}")
86 90
@@ -315,7 +319,7 @@ def yuqingpredict(): @@ -315,7 +319,7 @@ def yuqingpredict():
315 X, Y = getTopicCreatedAtandpredictData(defaultTopic) 319 X, Y = getTopicCreatedAtandpredictData(defaultTopic)
316 320
317 model_type = sanitize_input(request.args.get('model', 'pro')) 321 model_type = sanitize_input(request.args.get('model', 'pro'))
318 - if model_type not in ['pro', 'basic']: 322 + if model_type not in ['pro', 'basic', 'lstm']:
319 return abort(400, "无效的模型类型") 323 return abort(400, "无效的模型类型")
320 324
321 # 尝试从缓存获取预测结果 325 # 尝试从缓存获取预测结果
@@ -333,6 +337,14 @@ def yuqingpredict(): @@ -333,6 +337,14 @@ def yuqingpredict():
333 sentences = '正面' 337 sentences = '正面'
334 elif value < 0.5: 338 elif value < 0.5:
335 sentences = '负面' 339 sentences = '负面'
  340 + elif model_type == 'lstm':
  341 + predicted_label, confidence = lstm_predictor.predict(defaultTopic)
  342 + if predicted_label is not None:
  343 + sentences = '良好' if predicted_label == 0 else '不良'
  344 + sentences = f"{sentences} (LSTM置信度: {confidence[predicted_label]:.2%})"
  345 + else:
  346 + sentences = 'LSTM预测失败,请稍后重试'
  347 + logging.error(f"LSTM预测失败,话题: {defaultTopic}")
336 else: 348 else:
337 predicted_label, confidence = predict_sentiment(defaultTopic) 349 predicted_label, confidence = predict_sentiment(defaultTopic)
338 if predicted_label is not None: 350 if predicted_label is not None:
@@ -165,23 +165,10 @@ @@ -165,23 +165,10 @@
165 <div class="col-lg-12"> 165 <div class="col-lg-12">
166 <div class="form-group"> 166 <div class="form-group">
167 <label for="modelSelect">选择分析模型:</label> 167 <label for="modelSelect">选择分析模型:</label>
168 - <select class="form-control" id="modelSelect" onchange="updateModel(this.value)">  
169 - <optgroup label="基础模型">  
170 - <option value="basic" {% if model_type == 'basic' %}selected{% endif %}>SnowNLP</option>  
171 - </optgroup>  
172 - <optgroup label="OpenAI 模型">  
173 - <option value="gpt-3.5-turbo" {% if model_type == 'gpt-3.5-turbo' %}selected{% endif %}>GPT-3.5-Turbo</option>  
174 - <option value="gpt-4" {% if model_type == 'gpt-4' %}selected{% endif %}>GPT-4</option>  
175 - </optgroup>  
176 - <optgroup label="Claude 模型">  
177 - <option value="claude-3-opus-20240229" {% if model_type == 'claude-3-opus-20240229' %}selected{% endif %}>Claude-3 Opus</option>  
178 - <option value="claude-3-sonnet-20240229" {% if model_type == 'claude-3-sonnet-20240229' %}selected{% endif %}>Claude-3 Sonnet</option>  
179 - <option value="claude-3-haiku-20240307" {% if model_type == 'claude-3-haiku-20240307' %}selected{% endif %}>Claude-3 Haiku</option>  
180 - </optgroup>  
181 - <optgroup label="DeepSeek 模型">  
182 - <option value="deepseek-chat" {% if model_type == 'deepseek-chat' %}selected{% endif %}>DeepSeek-V3</option>  
183 - <option value="deepseek-reasoner" {% if model_type == 'deepseek-reasoner' %}selected{% endif %}>DeepSeek-R1</option>  
184 - </optgroup> 168 + <select class="custom-select" onchange="updateModel(this.value)">
  169 + <option value="basic" {% if model_type == 'basic' %}selected{% endif %}>基础模型 (SnowNLP)</option>
  170 + <option value="pro" {% if model_type == 'pro' %}selected{% endif %}>进阶模型 (BERT+CTM)</option>
  171 + <option value="lstm" {% if model_type == 'lstm' %}selected{% endif %}>LSTM模型 (新增)</option>
185 </select> 172 </select>
186 </div> 173 </div>
187 </div> 174 </div>