戒酒的李白

Enhance LSTM model for small datasets and improve performance.

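This change adds an attention layer over the LSTM outputs, synonym-based data augmentation for small corpora, FGM-style adversarial perturbation of the embeddings, stratified 5-fold cross-validation with early stopping and per-fold checkpoints, optional pretrained word2vec embeddings, and a TF-IDF + logistic regression baseline that the LSTM is compared against.

A minimal sketch of how the new training entry point is expected to be driven. The module name, file paths, and data below are placeholders for illustration, not part of this diff:

```python
import numpy as np
from lstm_sentiment import LSTMModelManager  # module name assumed

# Hypothetical paths: a local BERT vocab/model, a checkpoint prefix
# (per-fold files are saved as <prefix>_fold<k>.pt), and optional
# binary word2vec vectors for initializing the embedding layer.
manager = LSTMModelManager(
    bert_model_path="bert-base-chinese",
    model_save_path="checkpoints/lstm_sentiment",
    word2vec_path="vectors/zh_word2vec.bin",
)

# Toy labeled data; a real run needs enough samples per class for
# StratifiedKFold(n_splits=5) to produce valid splits.
texts = np.array(["服务很好,非常满意", "物流太差,很失望"] * 50)
labels = np.array([0, 1] * 50)

# Trains the TF-IDF + LogisticRegression baseline first, augments the
# data when there are fewer than 1000 samples, then runs 5-fold CV.
avg_loss, avg_acc, avg_f1 = manager.train(texts, labels, batch_size=16, epochs=10)

pred, prob = manager.predict("物流很快,服务态度也好")
```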
@@ -4,12 +4,19 @@ import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import pandas as pd
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
 import jieba
-from transformers import BertTokenizer
+from transformers import BertTokenizer, BertModel
 import logging
 import os
+import random
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from gensim.models import KeyedVectors
+import json
+import torch.nn.functional as F
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -49,17 +56,90 @@ class TextDataset(Dataset):
             'label': torch.tensor(label, dtype=torch.long)
         }
 
+class AttentionLayer(nn.Module):
+    """Attention layer"""
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.attention = nn.Linear(hidden_dim, 1)
+
+    def forward(self, lstm_output):  # lstm_output: [batch_size, seq_len, hidden_dim]
+        attention_weights = torch.softmax(self.attention(lstm_output), dim=1)
+        context_vector = torch.sum(attention_weights * lstm_output, dim=1)
+        return context_vector, attention_weights
+
+# Data augmentation class
+class TextAugmenter:
+    def __init__(self, language='zh', synonyms_file=None):
+        self.language = language
+        self.synonyms_dict = self._load_synonyms(synonyms_file)
+
+    def _load_synonyms(self, file_path):
+        base_dict = {
+            "很好": ["非常好", "太好了", "特别好", "相当好", "真不错"],
+            "糟糕": ["差劲", "很差", "不好", "太差", "糟透了"],
+            "一般": ["还行", "凑合", "普通", "马马虎虎", "中等"],
+            "满意": ["很满意", "挺好", "不错", "称心如意"],
+            "生气": ["愤怒", "恼火", "不爽", "气愤"],
+            "失望": ["伤心", "难过", "不满意", "遗憾"],
+            # Add more sentiment word pairs here
+        }
+
+        if file_path and os.path.exists(file_path):
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    custom_dict = json.load(f)
+                    base_dict.update(custom_dict)
+            except Exception as e:
+                logger.warning(f"Failed to load the synonym dictionary: {e}")
+
+        return base_dict
+
+    def synonym_replacement(self, text, n=1):
+        words = list(jieba.cut(text))
+        new_words = words.copy()
+        num_replaced = 0
+
+        for word in list(set(words)):
+            if len(word) > 1 and num_replaced < n:
+                synonyms = self._get_synonyms(word)
+                if synonyms:
+                    synonym = random.choice(synonyms)
+                    new_words = [synonym if w == word else w for w in new_words]
+                    num_replaced += 1
+
+        return ''.join(new_words)
+
+    def _get_synonyms(self, word):
+        return self.synonyms_dict.get(word, [])
+
+    def augment(self, texts, labels, augment_ratio=0.5):
+        augmented_texts = []
+        augmented_labels = []
+
+        for text, label in zip(texts, labels):
+            augmented_texts.append(text)
+            augmented_labels.append(label)
+
+            if random.random() < augment_ratio:
+                aug_text = self.synonym_replacement(text)
+                augmented_texts.append(aug_text)
+                augmented_labels.append(label)
+
+        return np.array(augmented_texts), np.array(augmented_labels)
+
 class LSTMSentimentModel(nn.Module):
     """LSTM-based sentiment analysis model"""
 
-    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2,
-                 bidirectional=True, dropout=0.5, pad_idx=0):
+    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=1,
+                 bidirectional=True, dropout=0.3, pad_idx=0, pretrained_embeddings=None):
         super().__init__()
 
         # Embedding layer
         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
+        if pretrained_embeddings is not None:
+            self.embedding.weight.data.copy_(pretrained_embeddings)
+            self.embedding.weight.requires_grad = True
 
-        # LSTM layer
         self.lstm = nn.LSTM(
             embedding_dim,
             hidden_dim,
@@ -69,11 +149,17 @@ class LSTMSentimentModel(nn.Module):
             batch_first=True
         )
 
-        # Fully connected layer; the input dim doubles for a bidirectional LSTM
-        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
+        # Attention layer
+        self.attention = AttentionLayer(hidden_dim * 2 if bidirectional else hidden_dim)
+
+        # Fully connected layers
+        fc_dim = hidden_dim * 2 if bidirectional else hidden_dim
+        self.fc1 = nn.Linear(fc_dim, fc_dim // 2)
+        self.fc2 = nn.Linear(fc_dim // 2, output_dim)
 
-        # Dropout layer
         self.dropout = nn.Dropout(dropout)
+        self.bn = nn.BatchNorm1d(fc_dim // 2)
+        self.relu = nn.ReLU()
 
     def forward(self, text, attention_mask=None):
         # Pass text through the embedding layer: [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
@@ -95,27 +181,50 @@ class LSTMSentimentModel(nn.Module):
         else:
             output, (hidden, cell) = self.lstm(embedded)
 
-        # For a bidirectional LSTM, concatenate the last layer's forward and backward hidden states
-        if self.lstm.bidirectional:
-            hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
-        else:
-            hidden = hidden[-1]
+        # Apply the attention mechanism
+        context_vector, attention_weights = self.attention(output)
 
-        # Apply dropout
-        hidden = self.dropout(hidden)
+        # Apply dropout and the fully connected layers
+        x = self.dropout(context_vector)
+        x = self.fc1(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
 
-        # Fully connected layer
-        return self.fc(hidden)
+        return x, attention_weights
+
+# Early stopping helper
+class EarlyStopping:
+    def __init__(self, patience=5, min_delta=0):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.best_loss = None
+        self.early_stop = False
+
+    def __call__(self, val_loss):
+        if self.best_loss is None:
+            self.best_loss = val_loss
+        elif val_loss > self.best_loss - self.min_delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_loss = val_loss
+            self.counter = 0
+        return self.early_stop  # callers rely on this to decide when to stop
 
 class LSTMModelManager:
     """LSTM model manager for training, evaluation, and prediction"""
 
     def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
-                 embedding_dim=128, hidden_dim=256, output_dim=2, n_layers=2,
-                 bidirectional=True, dropout=0.5):
+                 embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1,
+                 bidirectional=True, dropout=0.3, word2vec_path=None):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
         self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
         self.model = LSTMSentimentModel(
             vocab_size=vocab_size,
             embedding_dim=embedding_dim,
@@ -131,113 +240,204 @@ class LSTMModelManager:
         if model_save_path and os.path.exists(model_save_path):
             self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
             logger.info(f"Loaded model from {model_save_path}")
+
+        self.augmenter = TextAugmenter()
+        self.early_stopping = EarlyStopping(patience=5)
+
+        # Load pretrained word vectors
+        self.pretrained_embeddings = None
+        if word2vec_path and os.path.exists(word2vec_path):
+            try:
+                word_vectors = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)
+                self.pretrained_embeddings = self._build_embedding_matrix(word_vectors)
+                logger.info("Pretrained word vectors loaded successfully")
+            except Exception as e:
+                logger.warning(f"Failed to load pretrained word vectors: {e}")
+
+        # Initialize adversarial training parameters (FGM-style)
+        self.epsilon = 0.01
+        self.alpha = 0.001
+
+    def _build_embedding_matrix(self, word_vectors):
+        embedding_matrix = torch.zeros(self.vocab_size, self.embedding_dim)
+        for i in range(self.vocab_size):
+            try:
+                word = self.tokenizer.convert_ids_to_tokens(i)
+                if word in word_vectors:
+                    embedding_matrix[i] = torch.tensor(word_vectors[word])
+            except Exception:
+                continue
+        return embedding_matrix
+
+    def adversarial_training(self, batch, criterion):
+        """Adversarial training step"""
+        # Compute the clean loss
+        input_ids = batch['input_ids'].to(self.device)
+        attention_mask = batch['attention_mask'].to(self.device)
+        labels = batch['label'].to(self.device)
+
+        outputs, _ = self.model(input_ids, attention_mask)
+        loss = criterion(outputs, labels)
+
+        # Compute gradients, keeping the graph for the combined backward pass
+        loss.backward(retain_graph=True)
+
+        # Get the embedding layer's gradient
+        grad_embed = self.model.embedding.weight.grad.data
+
+        # Generate the adversarial perturbation
+        perturb = self.epsilon * torch.sign(grad_embed)
+        self.model.zero_grad()  # clear these grads so the caller's backward doesn't double-count the clean loss
+
+        # Apply the perturbation
+        self.model.embedding.weight.data.add_(perturb)
+
+        # Compute the adversarial loss
+        outputs_adv, _ = self.model(input_ids, attention_mask)
+        loss_adv = criterion(outputs_adv, labels)
+
+        # Restore the original embeddings
+        self.model.embedding.weight.data.sub_(perturb)
+
+        return loss + self.alpha * loss_adv
+
+    def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None):
+        vectorizer = TfidfVectorizer(max_features=5000)
+        X_train = vectorizer.fit_transform(train_texts)
+
+        if val_texts is None:
+            X_train, X_val, y_train, y_val = train_test_split(
+                X_train, train_labels, test_size=0.2, stratify=train_labels
+            )
+        else:
+            X_val = vectorizer.transform(val_texts)
+            y_train, y_val = train_labels, val_labels
+
+        lr_model = LogisticRegression(class_weight='balanced')
+        lr_model.fit(X_train, y_train)
+
+        val_pred = lr_model.predict(X_val)
+        lr_accuracy = accuracy_score(y_val, val_pred)
+        lr_f1 = f1_score(y_val, val_pred, average='macro')
+
+        return lr_accuracy, lr_f1
 
     def train(self, train_texts, train_labels, val_texts=None, val_labels=None,
-              batch_size=32, learning_rate=2e-5, epochs=10, validation_split=0.2):
+              batch_size=16, epochs=10, learning_rate=2e-4):
         """Train the model"""
         logger.info("Starting model training...")
 
-        # If no validation set is provided, carve one out of the training set
-        if val_texts is None or val_labels is None:
-            train_texts, val_texts, train_labels, val_labels = train_test_split(
-                train_texts, train_labels, test_size=validation_split, random_state=42
-            )
+        # First train a logistic regression baseline
+        lr_accuracy, lr_f1 = self.train_logistic_regression(train_texts, train_labels, val_texts, val_labels)
+        logger.info(f"Logistic regression baseline - accuracy: {lr_accuracy:.4f}, F1: {lr_f1:.4f}")
 
-        # Create datasets and data loaders
-        train_dataset = TextDataset(train_texts, train_labels, self.tokenizer)
-        val_dataset = TextDataset(val_texts, val_labels, self.tokenizer)
+        # Augment the data when the training set has fewer than 1000 samples
+        if len(train_texts) < 1000:
+            train_texts, train_labels = self.augmenter.augment(train_texts, train_labels)
+            logger.info(f"Training set size after augmentation: {len(train_texts)}")
 
-        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
+        # Set up stratified k-fold cross-validation
+        kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+        fold_results = []
 
-        # Optimizer and loss function
-        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
-        criterion = nn.CrossEntropyLoss()
-
-        # Training loop
-        best_val_loss = float('inf')
-        for epoch in range(epochs):
-            # Training mode
-            self.model.train()
-            train_loss = 0
-            train_preds = []
-            train_labels_list = []
+        for fold, (train_idx, val_idx) in enumerate(kf.split(train_texts, train_labels)):
+            logger.info(f"Training fold {fold+1}...")
 
-            for batch in train_dataloader:
-                # Get the data
-                input_ids = batch['input_ids'].to(self.device)
-                attention_mask = batch['attention_mask'].to(self.device)
-                labels = batch['label'].to(self.device)
-
-                # Forward pass
-                optimizer.zero_grad()
-                outputs = self.model(input_ids, attention_mask)
+            # Reset the model and the early-stopping state for each fold
+            self.model = self._create_model()
+            self.early_stopping = EarlyStopping(patience=5)
+            optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)
+            scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)
+            criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(self.device))
+
+            # Prepare the data
+            X_train, X_val = train_texts[train_idx], train_texts[val_idx]
+            y_train, y_val = train_labels[train_idx], train_labels[val_idx]
+
+            train_dataset = TextDataset(X_train, y_train, self.tokenizer)
+            val_dataset = TextDataset(X_val, y_val, self.tokenizer)
+
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size)
+
+            best_val_loss = float('inf')
+            for epoch in range(epochs):
+                # Train and validate
+                train_loss = self._train_epoch(train_loader, optimizer, criterion)
+                val_loss, val_acc, val_f1 = self._validate(val_loader, criterion)
 
-                # Compute the loss
-                loss = criterion(outputs, labels)
-                train_loss += loss.item()
+                scheduler.step(val_loss)
 
-                # Backward pass
-                loss.backward()
-                optimizer.step()
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    if self.model_save_path:
+                        torch.save(self.model.state_dict(),
+                                   f"{self.model_save_path}_fold{fold}.pt")
 
-                # Collect predictions and labels
-                _, predicted = torch.max(outputs, 1)
-                train_preds.extend(predicted.cpu().numpy())
-                train_labels_list.extend(labels.cpu().numpy())
+                if self.early_stopping(val_loss):
+                    break
 
-            # Compute training metrics
-            train_accuracy = accuracy_score(train_labels_list, train_preds)
-            train_f1 = f1_score(train_labels_list, train_preds, average='macro')
-
-            # Validation mode
-            self.model.eval()
-            val_loss = 0
-            val_preds = []
-            val_labels_list = []
-
-            with torch.no_grad():
-                for batch in val_dataloader:
-                    input_ids = batch['input_ids'].to(self.device)
-                    attention_mask = batch['attention_mask'].to(self.device)
-                    labels = batch['label'].to(self.device)
-
-                    outputs = self.model(input_ids, attention_mask)
-                    loss = criterion(outputs, labels)
-                    val_loss += loss.item()
-
-                    _, predicted = torch.max(outputs, 1)
-                    val_preds.extend(predicted.cpu().numpy())
-                    val_labels_list.extend(labels.cpu().numpy())
+            fold_results.append({
+                'val_loss': val_loss,
+                'val_accuracy': val_acc,
+                'val_f1': val_f1
+            })
+
+        # Compute the averaged results
+        avg_val_loss = np.mean([res['val_loss'] for res in fold_results])
+        avg_val_acc = np.mean([res['val_accuracy'] for res in fold_results])
+        avg_val_f1 = np.mean([res['val_f1'] for res in fold_results])
+
+        logger.info(f"Cross-validation averages - loss: {avg_val_loss:.4f}, accuracy: {avg_val_acc:.4f}, F1: {avg_val_f1:.4f}")
+
+        # Warn if the LSTM underperforms the logistic regression baseline
+        if avg_val_acc < lr_accuracy:
+            logger.warning("LSTM performance is below the logistic regression baseline; consider using the logistic regression model")
+
+        return avg_val_loss, avg_val_acc, avg_val_f1
+
+    def _train_epoch(self, train_loader, optimizer, criterion):
+        self.model.train()
+        total_loss = 0
+        for batch in train_loader:
+            optimizer.zero_grad()
 
-            # Compute validation metrics
-            val_accuracy = accuracy_score(val_labels_list, val_preds)
-            val_f1 = f1_score(val_labels_list, val_preds, average='macro')
+            # Use adversarial training to get the combined loss
+            loss = self.adversarial_training(batch, criterion)
 
-            # Compute average losses
-            train_loss /= len(train_dataloader)
-            val_loss /= len(val_dataloader)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
 
-            logger.info(f'Epoch {epoch+1}/{epochs} | '
-                        f'Train Loss: {train_loss:.4f} | '
-                        f'Train Acc: {train_accuracy:.4f} | '
-                        f'Train F1: {train_f1:.4f} | '
-                        f'Val Loss: {val_loss:.4f} | '
-                        f'Val Acc: {val_accuracy:.4f} | '
-                        f'Val F1: {val_f1:.4f}')
+            total_loss += loss.item()
 
-            # Save the best model
-            if val_loss < best_val_loss and self.model_save_path:
-                best_val_loss = val_loss
-                torch.save(self.model.state_dict(), self.model_save_path)
-                logger.info(f"Model saved to {self.model_save_path}")
-
-        # If a save path was given but no model was ever saved, save the final epoch's model
-        if self.model_save_path and best_val_loss == float('inf'):
-            torch.save(self.model.state_dict(), self.model_save_path)
-            logger.info(f"Final model saved to {self.model_save_path}")
-
-        return train_loss, val_loss, val_accuracy, val_f1
+        return total_loss / len(train_loader)
+
+    def _validate(self, val_loader, criterion):
+        self.model.eval()
+        total_loss = 0
+        val_preds = []
+        val_labels_list = []
+
+        with torch.no_grad():
+            for batch in val_loader:
+                input_ids = batch['input_ids'].to(self.device)
+                attention_mask = batch['attention_mask'].to(self.device)
+                labels = batch['label'].to(self.device)
+
+                outputs, _ = self.model(input_ids, attention_mask)
+                loss = criterion(outputs, labels)
+                total_loss += loss.item()
+
+                _, predicted = torch.max(outputs, 1)
+                val_preds.extend(predicted.cpu().numpy())
+                val_labels_list.extend(labels.cpu().numpy())
+
+        avg_loss = total_loss / len(val_loader)
+        accuracy = accuracy_score(val_labels_list, val_preds)
+        f1 = f1_score(val_labels_list, val_preds, average='macro')
+
+        return avg_loss, accuracy, f1
 
     def evaluate(self, test_texts, test_labels, batch_size=32):
         """Evaluate the model"""
@@ -263,7 +463,7 @@ class LSTMModelManager:
                 attention_mask = batch['attention_mask'].to(self.device)
                 labels = batch['label'].to(self.device)
 
-                outputs = self.model(input_ids, attention_mask)
+                outputs, _ = self.model(input_ids, attention_mask)
                 loss = criterion(outputs, labels)
                 test_loss += loss.item()
 
@@ -327,7 +527,7 @@ class LSTMModelManager:
                 input_ids = batch['input_ids'].to(self.device)
                 attention_mask = batch['attention_mask'].to(self.device)
 
-                outputs = self.model(input_ids, attention_mask)
+                outputs, _ = self.model(input_ids, attention_mask)
                 probs = torch.softmax(outputs, dim=1)
                 _, predicted = torch.max(outputs, 1)
 
@@ -388,4 +588,4 @@ if __name__ == "__main__":
         pred, prob = lstm_model_manager.predict(sentence)
         label = '良好' if pred == 0 else '不良'
         confidence = prob[pred]
-        print(f"Sentence: '{sentence}' Prediction: {label} (confidence: {confidence:.2%})")
\ No newline at end of file
+        print(f"Sentence: '{sentence}' Prediction: {label} (confidence: {confidence:.2%})")