Enhanced the public opinion prediction system by integrating an LSTM model.
Showing
4 changed files
with
573 additions
and
18 deletions
model_pro/LSTM_model.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | +import torch.optim as optim | ||
| 4 | +from torch.utils.data import Dataset, DataLoader | ||
| 5 | +import numpy as np | ||
| 6 | +import pandas as pd | ||
| 7 | +from sklearn.model_selection import train_test_split | ||
| 8 | +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix | ||
| 9 | +import jieba | ||
| 10 | +from transformers import BertTokenizer | ||
| 11 | +import logging | ||
| 12 | +import os | ||
| 13 | + | ||
# Configure module-level logging: timestamped INFO messages to the root handler.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('LSTM_model')
| 17 | + | ||
class TextDataset(Dataset):
    """Dataset that tokenizes raw text with a BERT tokenizer on access.

    Each item is a dict with the raw text, padded/truncated input ids,
    the attention mask, and the integer label as a long tensor.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        sample = str(self.texts[idx])
        target = self.labels[idx]

        # Tokenize with [CLS]/[SEP], pad/truncate to max_length, and get the
        # attention mask back as a tensor of shape (1, max_length).
        encoded = self.tokenizer.encode_plus(
            sample,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        item = {
            'text': sample,
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'label': torch.tensor(target, dtype=torch.long),
        }
        return item
| 51 | + | ||
class LSTMSentimentModel(nn.Module):
    """Sentiment classifier: learned embedding -> (bi)LSTM -> linear head."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2,
                 bidirectional=True, dropout=0.5, pad_idx=0):
        super().__init__()

        # Token embedding; the pad_idx row stays zero and gets no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        # Inter-layer dropout is only meaningful with more than one LSTM layer.
        recurrent_dropout = dropout if n_layers > 1 else 0
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=recurrent_dropout,
            batch_first=True,
        )

        # A bidirectional LSTM exposes forward + backward states, hence 2x width.
        classifier_in = hidden_dim * (2 if bidirectional else 1)
        self.fc = nn.Linear(classifier_in, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text, attention_mask=None):
        """Classify a batch of token-id sequences.

        text: [batch, seq_len] token ids.
        attention_mask: optional [batch, seq_len] 0/1 mask; when given, the
            batch is packed so the LSTM skips padded positions.
        Returns: [batch, output_dim] logits.
        """
        # [batch, seq_len] -> [batch, seq_len, embedding_dim]
        embedded = self.dropout(self.embedding(text))

        if attention_mask is None:
            _, (hidden, _) = self.lstm(embedded)
        else:
            # Pack by true sequence length so padding never enters the LSTM.
            lengths = attention_mask.sum(dim=1).to('cpu')
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
            _, (hidden, _) = self.lstm(packed)

        # Final state of the last layer: concatenate forward (-2) and
        # backward (-1) directions when bidirectional.
        if self.lstm.bidirectional:
            final_state = torch.cat([hidden[-2], hidden[-1]], dim=1)
        else:
            final_state = hidden[-1]

        return self.fc(self.dropout(final_state))
| 109 | + | ||
class LSTMModelManager:
    """Manager for the LSTM sentiment model: training, evaluation, prediction.

    Wraps an ``LSTMSentimentModel`` plus a BERT tokenizer so callers can work
    with raw text and integer labels directly. Per the calling code, label 0
    is the positive/"良好" class and label 1 the negative/"不良" class.
    """

    def __init__(self, bert_model_path, model_save_path=None, vocab_size=None,
                 embedding_dim=128, hidden_dim=256, output_dim=2, n_layers=2,
                 bidirectional=True, dropout=0.5):
        """
        Args:
            bert_model_path: directory holding the pretrained BERT tokenizer.
            model_save_path: optional checkpoint path; loaded when it exists.
            vocab_size: size of the embedding table. Defaults to the
                tokenizer's own vocabulary size. (The previous hard-coded
                default of 30522 is the *English* BERT vocab; tokenizers with
                a larger vocabulary would emit ids outside the embedding
                table and crash with an IndexError.)
            embedding_dim / hidden_dim / output_dim / n_layers /
            bidirectional / dropout: forwarded to ``LSTMSentimentModel``.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)

        # Derive vocab size from the tokenizer (includes added tokens) unless
        # explicitly overridden, so every token id is a valid embedding index.
        if vocab_size is None:
            vocab_size = len(self.tokenizer)
        self.vocab_size = vocab_size

        self.model = LSTMSentimentModel(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            n_layers=n_layers,
            bidirectional=bidirectional,
            dropout=dropout,
            pad_idx=self.tokenizer.pad_token_id
        ).to(self.device)

        self.model_save_path = model_save_path
        if model_save_path and os.path.exists(model_save_path):
            self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
            logger.info(f"已从 {model_save_path} 加载模型")

    def train(self, train_texts, train_labels, val_texts=None, val_labels=None,
              batch_size=32, learning_rate=2e-5, epochs=10, validation_split=0.2):
        """Train the model.

        Returns:
            (train_loss, val_loss, val_accuracy, val_f1) of the final epoch.
        """
        logger.info("开始训练模型...")

        # Carve a validation set out of the training data when none is given.
        if val_texts is None or val_labels is None:
            train_texts, val_texts, train_labels, val_labels = train_test_split(
                train_texts, train_labels, test_size=validation_split, random_state=42
            )

        train_dataset = TextDataset(train_texts, train_labels, self.tokenizer)
        val_dataset = TextDataset(val_texts, val_labels, self.tokenizer)

        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        best_val_loss = float('inf')
        for epoch in range(epochs):
            # ---- training pass ----
            self.model.train()
            train_loss = 0
            train_preds = []
            train_labels_list = []

            for batch in train_dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                optimizer.zero_grad()
                outputs = self.model(input_ids, attention_mask)

                loss = criterion(outputs, labels)
                train_loss += loss.item()

                loss.backward()
                optimizer.step()

                _, predicted = torch.max(outputs, 1)
                train_preds.extend(predicted.cpu().numpy())
                train_labels_list.extend(labels.cpu().numpy())

            train_accuracy = accuracy_score(train_labels_list, train_preds)
            train_f1 = f1_score(train_labels_list, train_preds, average='macro')

            # ---- validation pass (no gradients) ----
            self.model.eval()
            val_loss = 0
            val_preds = []
            val_labels_list = []

            with torch.no_grad():
                for batch in val_dataloader:
                    input_ids = batch['input_ids'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)
                    labels = batch['label'].to(self.device)

                    outputs = self.model(input_ids, attention_mask)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()

                    _, predicted = torch.max(outputs, 1)
                    val_preds.extend(predicted.cpu().numpy())
                    val_labels_list.extend(labels.cpu().numpy())

            val_accuracy = accuracy_score(val_labels_list, val_preds)
            val_f1 = f1_score(val_labels_list, val_preds, average='macro')

            # Average per-batch losses for reporting.
            train_loss /= len(train_dataloader)
            val_loss /= len(val_dataloader)

            logger.info(f'Epoch {epoch+1}/{epochs} | '
                        f'Train Loss: {train_loss:.4f} | '
                        f'Train Acc: {train_accuracy:.4f} | '
                        f'Train F1: {train_f1:.4f} | '
                        f'Val Loss: {val_loss:.4f} | '
                        f'Val Acc: {val_accuracy:.4f} | '
                        f'Val F1: {val_f1:.4f}')

            # Keep the checkpoint with the lowest validation loss.
            if val_loss < best_val_loss and self.model_save_path:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), self.model_save_path)
                logger.info(f"模型已保存到 {self.model_save_path}")

        # If a save path is set but nothing was ever saved, persist the final model.
        if self.model_save_path and best_val_loss == float('inf'):
            torch.save(self.model.state_dict(), self.model_save_path)
            logger.info(f"最终模型已保存到 {self.model_save_path}")

        return train_loss, val_loss, val_accuracy, val_f1

    def evaluate(self, test_texts, test_labels, batch_size=32):
        """Evaluate the model on a labelled test set.

        Returns:
            dict with loss, accuracy, precision, recall, f1 (macro-averaged),
            confusion matrix, raw predictions and class probabilities.
        """
        logger.info("评估模型...")

        test_dataset = TextDataset(test_texts, test_labels, self.tokenizer)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        self.model.eval()

        criterion = nn.CrossEntropyLoss()
        test_loss = 0
        test_preds = []
        test_probs = []
        test_labels_list = []

        with torch.no_grad():
            for batch in test_dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                loss = criterion(outputs, labels)
                test_loss += loss.item()

                probs = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                test_preds.extend(predicted.cpu().numpy())
                test_probs.extend(probs.cpu().numpy())
                test_labels_list.extend(labels.cpu().numpy())

        test_loss /= len(test_dataloader)

        accuracy = accuracy_score(test_labels_list, test_preds)
        precision = precision_score(test_labels_list, test_preds, average='macro')
        recall = recall_score(test_labels_list, test_preds, average='macro')
        f1 = f1_score(test_labels_list, test_preds, average='macro')
        conf_matrix = confusion_matrix(test_labels_list, test_preds)

        logger.info(f'Test Loss: {test_loss:.4f}')
        logger.info(f'Accuracy: {accuracy:.4f}')
        logger.info(f'Precision: {precision:.4f}')
        logger.info(f'Recall: {recall:.4f}')
        logger.info(f'F1 Score: {f1:.4f}')
        logger.info(f'Confusion Matrix:\n{conf_matrix}')

        return {
            'loss': test_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': conf_matrix,
            'predictions': test_preds,
            'probabilities': test_probs
        }

    def predict_batch(self, texts, batch_size=32):
        """Predict sentiment for a list of texts (a lone str is accepted too).

        Returns:
            (predictions, probabilities) lists, or (None, None) for empty input.
        """
        if not texts:
            return None, None

        if isinstance(texts, str):
            texts = [texts]

        # Dummy labels satisfy the TextDataset interface; they are never used.
        dummy_labels = [0] * len(texts)
        dataset = TextDataset(texts, dummy_labels, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        self.model.eval()

        all_preds = []
        all_probs = []

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)

                outputs = self.model(input_ids, attention_mask)
                probs = torch.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs, 1)

                all_preds.extend(predicted.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        return all_preds, all_probs

    def predict(self, text):
        """Predict sentiment for one text; returns (label, prob_vector) or (None, None)."""
        predictions, probabilities = self.predict_batch([text])
        if predictions is not None and len(predictions) > 0:
            return predictions[0], probabilities[0]
        return None, None
| 345 | + | ||
# Global model manager instance, constructed at import time.
# NOTE(review): importing this module eagerly loads the BERT tokenizer from
# 'model_pro/bert_model' (and the checkpoint if present) — confirm these
# paths exist in every deployment that imports this file.
lstm_model_manager = LSTMModelManager(
    bert_model_path='model_pro/bert_model',
    model_save_path='model_pro/lstm_model.pt'
)
| 351 | + | ||
# Script entry point: train, evaluate and smoke-test the model.
if __name__ == "__main__":
    # Load the three data splits from CSV.
    train_df = pd.read_csv('model_pro/train.csv')
    dev_df = pd.read_csv('model_pro/dev.csv')
    test_df = pd.read_csv('model_pro/test.csv')

    # Pull (texts, labels) arrays out of each split.
    splits = {
        name: (frame['text'].values, frame['label'].values)
        for name, frame in (('train', train_df), ('dev', dev_df), ('test', test_df))
    }

    # Train with an explicit dev split.
    lstm_model_manager.train(
        splits['train'][0], splits['train'][1],
        val_texts=splits['dev'][0], val_labels=splits['dev'][1],
        batch_size=32, epochs=5
    )

    # Evaluate on the held-out test split.
    results = lstm_model_manager.evaluate(splits['test'][0], splits['test'][1])

    # Quick smoke test of the single-text prediction API.
    test_sentences = [
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    ]
    for sentence in test_sentences:
        pred, prob = lstm_model_manager.predict(sentence)
        label = '良好' if pred == 0 else '不良'
        confidence = prob[pred]
        print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
model_pro/lstm_predict.py
0 → 100644
import torch
import os
import logging
from LSTM_model import lstm_model_manager

# Configure module-level logging: timestamped INFO messages to the root handler.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('lstm_predict')
| 9 | + | ||
class LSTMPredictor:
    """LSTM predictor exposing the same interface as the system's model_manager."""

    def __init__(self):
        # Becomes True once load_models() or train_model() succeeds.
        self.model_loaded = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"初始化LSTM预测器,使用设备: {self.device}")

    def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
        """Load the LSTM checkpoint; compatible with model_manager.load_models.

        Args:
            model_save_path: path of the saved LSTM state dict.
            bert_model_path: BERT model/tokenizer directory (existence checked).
            tokenizer_path: unused (the LSTM reuses BERT's tokenizer).

        Returns:
            True on success, False otherwise.
        """
        try:
            if not os.path.exists(model_save_path):
                logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
                return False

            if not os.path.exists(bert_model_path):
                logger.error(f"BERT模型路径 {bert_model_path} 不存在")
                return False

            # Fix: actually load the requested checkpoint into the shared
            # manager's model. The original only checked that a model object
            # existed, so the model_save_path argument was silently ignored
            # and whatever weights the manager loaded at import time were used.
            state_dict = torch.load(model_save_path, map_location=self.device)
            lstm_model_manager.model.load_state_dict(state_dict)
            lstm_model_manager.model.eval()

            self.model_loaded = True
            logger.info("LSTM模型已加载成功")
            return True
        except Exception as e:
            logger.error(f"加载模型过程中出错: {e}")
            return False

    def predict_batch(self, texts):
        """Predict sentiment for a list of texts.

        Args:
            texts: list of strings.

        Returns:
            (predictions, probabilities) where 0 = positive/良好 and
            1 = negative/不良, or (None, None) on failure.
        """
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if not texts:
            logger.warning("未提供文本,无法进行预测")
            return None, None

        try:
            # Delegate to the shared manager's batched inference.
            predictions, probabilities = lstm_model_manager.predict_batch(texts)
            return predictions, probabilities
        except Exception as e:
            logger.error(f"预测过程中出错: {e}")
            return None, None

    def predict(self, text):
        """Predict sentiment for a single text.

        Args:
            text: the input string.

        Returns:
            (prediction, probability_vector) with 0 = positive/良好 and
            1 = negative/不良, or (None, None) on failure/empty input.
        """
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if not text or len(text.strip()) == 0:
            logger.warning("未提供文本或文本为空,无法进行预测")
            return None, None

        try:
            # Delegate to the shared manager's single-text inference.
            prediction, probability = lstm_model_manager.predict(text)
            return prediction, probability
        except Exception as e:
            logger.error(f"预测过程中出错: {e}")
            return None, None

    def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
                    batch_size=32, learning_rate=2e-5, epochs=10):
        """Train the underlying model via the shared manager.

        Args:
            train_texts / train_labels: training data.
            val_texts / val_labels: optional validation data.
            batch_size / learning_rate / epochs: training hyperparameters.

        Returns:
            The manager's training result tuple, or None on error.
        """
        try:
            results = lstm_model_manager.train(
                train_texts, train_labels, val_texts, val_labels,
                batch_size, learning_rate, epochs
            )
            self.model_loaded = True
            return results
        except Exception as e:
            logger.error(f"训练模型过程中出错: {e}")
            return None
| 130 | + | ||
# Global predictor instance shared by the compatibility wrappers below.
lstm_predictor = LSTMPredictor()

# Module-level wrapper mirroring model_manager.predict_batch for drop-in use.
def predict_batch(texts):
    return lstm_predictor.predict_batch(texts)

# Module-level wrapper mirroring model_manager.load_models for drop-in use.
def load_models(model_save_path, bert_model_path, tokenizer_path=None):
    return lstm_predictor.load_models(model_save_path, bert_model_path, tokenizer_path)
| 141 | + | ||
# Script entry point: load the checkpoint and smoke-test prediction.
if __name__ == "__main__":
    # Load weights from the default locations.
    load_models(
        model_save_path="model_pro/lstm_model.pt",
        bert_model_path="model_pro/bert_model"
    )

    # A few sample sentences covering both sentiment classes.
    test_sentences = [
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    ]

    for sentence in test_sentences:
        pred, prob = lstm_predictor.predict(sentence)
        if pred is None:
            print(f"句子: '{sentence}' 预测失败")
        else:
            label = '良好' if pred == 0 else '不良'
            confidence = prob[pred]
            print(f"句子: '{sentence}' 预测结果: {label} (置信度: {confidence:.2%})")
| @@ -20,6 +20,7 @@ from functools import wraps | @@ -20,6 +20,7 @@ from functools import wraps | ||
| 20 | import bleach | 20 | import bleach |
| 21 | import re | 21 | import re |
| 22 | from datetime import datetime, timedelta | 22 | from datetime import datetime, timedelta |
| 23 | +from model_pro.lstm_predict import lstm_predictor | ||
| 23 | 24 | ||
| 24 | pb = Blueprint('page', | 25 | pb = Blueprint('page', |
| 25 | __name__, | 26 | __name__, |
| @@ -75,12 +76,15 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | @@ -75,12 +76,15 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 75 | 76 | ||
| 76 | # 设置模型路径 | 77 | # 设置模型路径 |
| 77 | model_save_path = 'model_pro/final_model.pt' | 78 | model_save_path = 'model_pro/final_model.pt' |
| 79 | +lstm_model_path = 'model_pro/lstm_model.pt' | ||
| 78 | bert_model_path = 'model_pro/bert_model' | 80 | bert_model_path = 'model_pro/bert_model' |
| 79 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' | 81 | ctm_tokenizer_path = 'model_pro/sentence_bert_model' |
| 80 | 82 | ||
| 81 | # 初始化模型 | 83 | # 初始化模型 |
| 82 | try: | 84 | try: |
| 83 | model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) | 85 | model_manager.load_models(model_save_path, bert_model_path, ctm_tokenizer_path) |
| 86 | + # 同时初始化LSTM模型 | ||
| 87 | + lstm_predictor.load_models(lstm_model_path, bert_model_path) | ||
| 84 | except Exception as e: | 88 | except Exception as e: |
| 85 | logging.error(f"模型加载失败: {e}") | 89 | logging.error(f"模型加载失败: {e}") |
| 86 | 90 | ||
| @@ -315,7 +319,7 @@ def yuqingpredict(): | @@ -315,7 +319,7 @@ def yuqingpredict(): | ||
| 315 | X, Y = getTopicCreatedAtandpredictData(defaultTopic) | 319 | X, Y = getTopicCreatedAtandpredictData(defaultTopic) |
| 316 | 320 | ||
| 317 | model_type = sanitize_input(request.args.get('model', 'pro')) | 321 | model_type = sanitize_input(request.args.get('model', 'pro')) |
| 318 | - if model_type not in ['pro', 'basic']: | 322 | + if model_type not in ['pro', 'basic', 'lstm']: |
| 319 | return abort(400, "无效的模型类型") | 323 | return abort(400, "无效的模型类型") |
| 320 | 324 | ||
| 321 | # 尝试从缓存获取预测结果 | 325 | # 尝试从缓存获取预测结果 |
| @@ -333,6 +337,14 @@ def yuqingpredict(): | @@ -333,6 +337,14 @@ def yuqingpredict(): | ||
| 333 | sentences = '正面' | 337 | sentences = '正面' |
| 334 | elif value < 0.5: | 338 | elif value < 0.5: |
| 335 | sentences = '负面' | 339 | sentences = '负面' |
| 340 | + elif model_type == 'lstm': | ||
| 341 | + predicted_label, confidence = lstm_predictor.predict(defaultTopic) | ||
| 342 | + if predicted_label is not None: | ||
| 343 | + sentences = '良好' if predicted_label == 0 else '不良' | ||
| 344 | + sentences = f"{sentences} (LSTM置信度: {confidence[predicted_label]:.2%})" | ||
| 345 | + else: | ||
| 346 | + sentences = 'LSTM预测失败,请稍后重试' | ||
| 347 | + logging.error(f"LSTM预测失败,话题: {defaultTopic}") | ||
| 336 | else: | 348 | else: |
| 337 | predicted_label, confidence = predict_sentiment(defaultTopic) | 349 | predicted_label, confidence = predict_sentiment(defaultTopic) |
| 338 | if predicted_label is not None: | 350 | if predicted_label is not None: |
| @@ -165,23 +165,10 @@ | @@ -165,23 +165,10 @@ | ||
| 165 | <div class="col-lg-12"> | 165 | <div class="col-lg-12"> |
| 166 | <div class="form-group"> | 166 | <div class="form-group"> |
| 167 | <label for="modelSelect">选择分析模型:</label> | 167 | <label for="modelSelect">选择分析模型:</label> |
| 168 | - <select class="form-control" id="modelSelect" onchange="updateModel(this.value)"> | ||
| 169 | - <optgroup label="基础模型"> | ||
| 170 | - <option value="basic" {% if model_type == 'basic' %}selected{% endif %}>SnowNLP</option> | ||
| 171 | - </optgroup> | ||
| 172 | - <optgroup label="OpenAI 模型"> | ||
| 173 | - <option value="gpt-3.5-turbo" {% if model_type == 'gpt-3.5-turbo' %}selected{% endif %}>GPT-3.5-Turbo</option> | ||
| 174 | - <option value="gpt-4" {% if model_type == 'gpt-4' %}selected{% endif %}>GPT-4</option> | ||
| 175 | - </optgroup> | ||
| 176 | - <optgroup label="Claude 模型"> | ||
| 177 | - <option value="claude-3-opus-20240229" {% if model_type == 'claude-3-opus-20240229' %}selected{% endif %}>Claude-3 Opus</option> | ||
| 178 | - <option value="claude-3-sonnet-20240229" {% if model_type == 'claude-3-sonnet-20240229' %}selected{% endif %}>Claude-3 Sonnet</option> | ||
| 179 | - <option value="claude-3-haiku-20240307" {% if model_type == 'claude-3-haiku-20240307' %}selected{% endif %}>Claude-3 Haiku</option> | ||
| 180 | - </optgroup> | ||
| 181 | - <optgroup label="DeepSeek 模型"> | ||
| 182 | - <option value="deepseek-chat" {% if model_type == 'deepseek-chat' %}selected{% endif %}>DeepSeek-V3</option> | ||
| 183 | - <option value="deepseek-reasoner" {% if model_type == 'deepseek-reasoner' %}selected{% endif %}>DeepSeek-R1</option> | ||
| 184 | - </optgroup> | 168 | + <select class="custom-select" onchange="updateModel(this.value)"> |
| 169 | + <option value="basic" {% if model_type == 'basic' %}selected{% endif %}>基础模型 (SnowNLP)</option> | ||
| 170 | + <option value="pro" {% if model_type == 'pro' %}selected{% endif %}>进阶模型 (BERT+CTM)</option> | ||
| 171 | + <option value="lstm" {% if model_type == 'lstm' %}selected{% endif %}>LSTM模型 (新增)</option> | ||
| 185 | </select> | 172 | </select> |
| 186 | </div> | 173 | </div> |
| 187 | </div> | 174 | </div> |
-
Please register or login to post a comment