Enhance LSTM model for small datasets and improve performance.
Showing 1 changed file with 309 additions and 112 deletions.
@@ -4,12 +4,19 @@ import torch.optim as optim
 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import pandas as pd
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
 import jieba
-from transformers import BertTokenizer
+from transformers import BertTokenizer, BertModel
 import logging
 import os
+import random
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from gensim.models import KeyedVectors
+import json
+import torch.nn.functional as F
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -49,17 +56,90 @@ class TextDataset(Dataset):
             'label': torch.tensor(label, dtype=torch.long)
         }
 
+class AttentionLayer(nn.Module):
+    """Attention layer"""
+    def __init__(self, hidden_dim):
+        super().__init__()
+        self.attention = nn.Linear(hidden_dim, 1)
+
+    def forward(self, lstm_output):
+        attention_weights = torch.softmax(self.attention(lstm_output), dim=1)
+        context_vector = torch.sum(attention_weights * lstm_output, dim=1)
+        return context_vector, attention_weights
+
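Note: the layer above scores every timestep with a single linear unit and pools the LSTM outputs by a softmax-weighted sum over the sequence axis. Padded positions are scored too, since no mask reaches the softmax. A quick shape check, assuming the class exactly as defined in this diff (run in this module's context, where torch is imported):

    layer = AttentionLayer(hidden_dim=128)   # 128 = lstm hidden_dim * 2 when bidirectional
    lstm_output = torch.randn(16, 50, 128)   # [batch, seq_len, hidden]
    context, weights = layer(lstm_output)
    assert context.shape == (16, 128)        # one pooled vector per sequence
    assert weights.shape == (16, 50, 1)      # softmax over dim=1 sums to 1 per sequence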
+# Data augmentation class
+class TextAugmenter:
+    def __init__(self, language='zh', synonyms_file=None):
+        self.language = language
+        self.synonyms_dict = self._load_synonyms(synonyms_file)
+
+    def _load_synonyms(self, file_path):
+        base_dict = {
+            "很好": ["非常好", "太好了", "特别好", "相当好", "真不错"],
+            "糟糕": ["差劲", "很差", "不好", "太差", "糟透了"],
+            "一般": ["还行", "凑合", "普通", "马马虎虎", "中等"],
+            "满意": ["很满意", "挺好", "不错", "称心如意"],
+            "生气": ["愤怒", "恼火", "不爽", "气愤"],
+            "失望": ["伤心", "难过", "不满意", "遗憾"],
+            # Add more sentiment word pairs here
+        }
+
+        if file_path and os.path.exists(file_path):
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    custom_dict = json.load(f)
+                    base_dict.update(custom_dict)
+            except Exception as e:
+                logger.warning(f"Failed to load synonym dictionary: {e}")
+
+        return base_dict
+
+    def synonym_replacement(self, text, n=1):
+        words = list(jieba.cut(text))
+        new_words = words.copy()
+        num_replaced = 0
+
+        for word in list(set(words)):
+            if len(word) > 1 and num_replaced < n:
+                synonyms = self._get_synonyms(word)
+                if synonyms:
+                    synonym = random.choice(synonyms)
+                    new_words = [synonym if w == word else w for w in new_words]
+                    num_replaced += 1
+
+        return ''.join(new_words)
+
+    def _get_synonyms(self, word):
+        return self.synonyms_dict.get(word, [])
+
+    def augment(self, texts, labels, augment_ratio=0.5):
+        augmented_texts = []
+        augmented_labels = []
+
+        for text, label in zip(texts, labels):
+            augmented_texts.append(text)
+            augmented_labels.append(label)
+
+            if random.random() < augment_ratio:
+                aug_text = self.synonym_replacement(text)
+                augmented_texts.append(aug_text)
+                augmented_labels.append(label)
+
+        return np.array(augmented_texts), np.array(augmented_labels)
+
 class LSTMSentimentModel(nn.Module):
     """LSTM-based sentiment analysis model"""
 
-    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=2,
-                 bidirectional=True, dropout=0.5, pad_idx=0):
+    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers=1,
+                 bidirectional=True, dropout=0.3, pad_idx=0, pretrained_embeddings=None):
         super().__init__()
 
         # Embedding layer
         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
+        if pretrained_embeddings is not None:
+            self.embedding.weight.data.copy_(pretrained_embeddings)
+            self.embedding.weight.requires_grad = True
 
-        # LSTM layer
         self.lstm = nn.LSTM(
             embedding_dim,
             hidden_dim,
@@ -69,11 +149,17 @@ class LSTMSentimentModel(nn.Module):
             batch_first=True
         )
 
-        # Fully connected layer; the input dimension doubles for a bidirectional LSTM
-        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
+        # Attention layer
+        self.attention = AttentionLayer(hidden_dim * 2 if bidirectional else hidden_dim)
+
+        # Fully connected layers
+        fc_dim = hidden_dim * 2 if bidirectional else hidden_dim
+        self.fc1 = nn.Linear(fc_dim, fc_dim // 2)
+        self.fc2 = nn.Linear(fc_dim // 2, output_dim)
 
-        # Dropout layer
         self.dropout = nn.Dropout(dropout)
+        self.bn = nn.BatchNorm1d(fc_dim // 2)
+        self.relu = nn.ReLU()
 
     def forward(self, text, attention_mask=None):
         # Pass the text through the embedding layer: [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
@@ -95,27 +181,49 @@ class LSTMSentimentModel(nn.Module):
         else:
             output, (hidden, cell) = self.lstm(embedded)
 
-        # For a bidirectional LSTM, concatenate the last layer's forward and backward hidden states
-        if self.lstm.bidirectional:
-            hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
-        else:
-            hidden = hidden[-1]
+        # Apply the attention mechanism
+        context_vector, attention_weights = self.attention(output)
 
-        # Apply dropout
-        hidden = self.dropout(hidden)
+        # Apply dropout and the fully connected layers
+        x = self.dropout(context_vector)
+        x = self.fc1(x)
+        x = self.bn(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
 
-        # Fully connected layer
-        return self.fc(hidden)
+        return x, attention_weights
+
+# Early stopping class
+class EarlyStopping:
+    def __init__(self, patience=5, min_delta=0):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.best_loss = None
+        self.early_stop = False
+
+    def __call__(self, val_loss):
+        if self.best_loss is None:
+            self.best_loss = val_loss
+        elif val_loss > self.best_loss - self.min_delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_loss = val_loss
+            self.counter = 0
+        return self.early_stop  # train() breaks on this return value
 
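Note: the `return self.early_stop` line is essential — `train()` below tests `if self.early_stopping(val_loss): break`, which would always see `None` without it. A small walkthrough under the patience semantics above:

    stopper = EarlyStopping(patience=2)
    for loss in [1.0, 0.9, 0.95, 0.94]:
        if stopper(loss):
            break
    # best_loss settles at 0.9; 0.95 and 0.94 are non-improving,
    # so the counter reaches patience=2 and the loop stops at 0.94.

Also worth noting: `LSTMModelManager.__init__` creates a single `EarlyStopping` and reuses it across all five folds, so `counter` and `best_loss` carry over between folds unless reset per fold.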
 class LSTMModelManager:
     """Manager class for training, evaluating, and predicting with the LSTM model"""
 
     def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
-                 embedding_dim=128, hidden_dim=256, output_dim=2, n_layers=2,
-                 bidirectional=True, dropout=0.5):
+                 embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1,
+                 bidirectional=True, dropout=0.3, word2vec_path=None):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
         self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
         self.model = LSTMSentimentModel(
             vocab_size=vocab_size,
             embedding_dim=embedding_dim,
@@ -131,113 +239,202 @@ class LSTMModelManager:
         if model_save_path and os.path.exists(model_save_path):
             self.model.load_state_dict(torch.load(model_save_path, map_location=self.device))
             logger.info(f"Loaded model from {model_save_path}")
+
+        self.augmenter = TextAugmenter()
+        self.early_stopping = EarlyStopping(patience=5)
+
+        # Load pretrained word vectors
+        self.pretrained_embeddings = None
+        if word2vec_path and os.path.exists(word2vec_path):
+            try:
+                word_vectors = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)
+                self.pretrained_embeddings = self._build_embedding_matrix(word_vectors)
+                logger.info("Successfully loaded pretrained word vectors")
+            except Exception as e:
+                logger.warning(f"Failed to load pretrained word vectors: {e}")
+
+        # Initialize adversarial training parameters
+        self.epsilon = 0.01
+        self.alpha = 0.001
+
+    def _build_embedding_matrix(self, word_vectors):
+        embedding_matrix = torch.zeros(self.vocab_size, self.embedding_dim)
+        for i in range(self.vocab_size):
+            try:
+                word = self.tokenizer.convert_ids_to_tokens(i)
+                if word in word_vectors:
+                    embedding_matrix[i] = torch.tensor(word_vectors[word])
+            except Exception:
+                continue
+        return embedding_matrix
+
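Note: `_build_embedding_matrix` maps BERT's wordpiece vocabulary onto word2vec entries, and coverage is often low — wordpieces like `##ing` or single Chinese characters rarely appear in a word-level word2vec vocabulary, so many rows stay zero. A quick coverage check (the vectors path is a placeholder, not part of this diff):

    wv = KeyedVectors.load_word2vec_format("vectors.bin", binary=True)
    hits = sum(1 for i in range(tokenizer.vocab_size)
               if tokenizer.convert_ids_to_tokens(i) in wv)
    print(f"{hits}/{tokenizer.vocab_size} BERT tokens found in word2vec")

The dimensions must also match: a vector whose length differs from `embedding_dim` (100 by default here) raises on assignment and is silently skipped by the `except`, leaving that row zero.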
+    def adversarial_training(self, batch, criterion):
+        """Adversarial training step (FGM on the embedding table)"""
+        # Compute the loss on the clean input
+        input_ids = batch['input_ids'].to(self.device)
+        attention_mask = batch['attention_mask'].to(self.device)
+        labels = batch['label'].to(self.device)
+
+        outputs, _ = self.model(input_ids, attention_mask)
+        loss = criterion(outputs, labels)
+
+        # Compute gradients
+        loss.backward(retain_graph=True)
+
+        # Get the embedding layer's gradient
+        grad_embed = self.model.embedding.weight.grad.data
+
+        # Generate the adversarial perturbation
+        perturb = self.epsilon * torch.sign(grad_embed)
+
+        # Apply the perturbation
+        self.model.embedding.weight.data.add_(perturb)
+
+        # Compute the adversarial loss
+        outputs_adv, _ = self.model(input_ids, attention_mask)
+        loss_adv = criterion(outputs_adv, labels)
+
+        # Restore the original embeddings
+        self.model.embedding.weight.data.sub_(perturb)
+
+        return loss.detach() + self.alpha * loss_adv  # detach: the clean gradient was already accumulated above
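Note: this is FGM-style (Fast Gradient Method) adversarial training applied to the embedding table. The `.detach()` on the clean loss matters: its gradient is already accumulated by the `backward(retain_graph=True)` call, so the single `backward()` in `_train_epoch` should only add the adversarial term. Per batch, the gradient the optimizer ends up following:

    # optimizer.zero_grad()                 (in _train_epoch)
    # loss.backward(retain_graph=True)  ->  grad  = d(loss)/dθ
    # returned_loss.backward()          ->  grad += alpha * d(loss_adv)/dθ
    # optimizer.step() follows d(loss)/dθ + alpha * d(loss_adv)/dθ

Without the detach, the clean gradient would be counted twice.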
+
+    def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None):
+        vectorizer = TfidfVectorizer(max_features=5000)
+        X_train = vectorizer.fit_transform(train_texts)
+
+        if val_texts is None:
+            X_train, X_val, y_train, y_val = train_test_split(
+                X_train, train_labels, test_size=0.2, stratify=train_labels
+            )
+        else:
+            X_val = vectorizer.transform(val_texts)
+            y_train, y_val = train_labels, val_labels
+
+        lr_model = LogisticRegression(class_weight='balanced')
+        lr_model.fit(X_train, y_train)
+
+        val_pred = lr_model.predict(X_val)
+        lr_accuracy = accuracy_score(y_val, val_pred)
+        lr_f1 = f1_score(y_val, val_pred, average='macro')
+
+        return lr_accuracy, lr_f1
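Note: a TF-IDF + logistic regression baseline is a sensible sanity check on small datasets, but `TfidfVectorizer`'s default `token_pattern` cannot segment unspaced Chinese — contiguous runs of characters collapse into single sentence-long tokens. Passing a jieba tokenizer is the usual remedy (a sketch, not part of this diff):

    vectorizer = TfidfVectorizer(max_features=5000, tokenizer=jieba.lcut)

Without it, the baseline numbers logged in `train()` below will understate what logistic regression can actually do on Chinese text.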
 
     def train(self, train_texts, train_labels, val_texts=None, val_labels=None,
-              batch_size=32, learning_rate=2e-5, epochs=10, validation_split=0.2):
+              batch_size=16, epochs=10, learning_rate=2e-4):
         """Train the model"""
         logger.info("Starting model training...")
 
-        # If no validation set is provided, carve one out of the training set
-        if val_texts is None or val_labels is None:
-            train_texts, val_texts, train_labels, val_labels = train_test_split(
-                train_texts, train_labels, test_size=validation_split, random_state=42
-            )
+        # First train a logistic regression baseline
+        lr_accuracy, lr_f1 = self.train_logistic_regression(train_texts, train_labels, val_texts, val_labels)
+        logger.info(f"Logistic regression baseline - accuracy: {lr_accuracy:.4f}, F1: {lr_f1:.4f}")
 
-        # Create the datasets and data loaders
-        train_dataset = TextDataset(train_texts, train_labels, self.tokenizer)
-        val_dataset = TextDataset(val_texts, val_labels, self.tokenizer)
+        # Augment the data when there are fewer than 1000 samples
+        if len(train_texts) < 1000:
+            train_texts, train_labels = self.augmenter.augment(train_texts, train_labels)
+            logger.info(f"Training set size after augmentation: {len(train_texts)}")
 
-        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-        val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
+        # Set up k-fold cross-validation
+        kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+        fold_results = []
 
-        # Optimizer and loss function
-        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
-        criterion = nn.CrossEntropyLoss()
-
-        # Training loop
-        best_val_loss = float('inf')
-        for epoch in range(epochs):
-            # Training mode
-            self.model.train()
-            train_loss = 0
-            train_preds = []
-            train_labels_list = []
+        for fold, (train_idx, val_idx) in enumerate(kf.split(train_texts, train_labels)):
+            logger.info(f"Training fold {fold+1}...")
 
-            for batch in train_dataloader:
-                # Get the data
-                input_ids = batch['input_ids'].to(self.device)
-                attention_mask = batch['attention_mask'].to(self.device)
-                labels = batch['label'].to(self.device)
-
-                # Forward pass
-                optimizer.zero_grad()
-                outputs = self.model(input_ids, attention_mask)
+            # Reset the model for this fold
+            self.model = self._create_model()
+            optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)
+            scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=2)
+            criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(self.device))
+
+            # Prepare the data
+            X_train, X_val = train_texts[train_idx], train_texts[val_idx]
+            y_train, y_val = train_labels[train_idx], train_labels[val_idx]
+
+            train_dataset = TextDataset(X_train, y_train, self.tokenizer)
+            val_dataset = TextDataset(X_val, y_val, self.tokenizer)
+
+            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+            val_loader = DataLoader(val_dataset, batch_size=batch_size)
+
+            best_val_loss = float('inf')
+            for epoch in range(epochs):
+                # Train and validate
+                train_loss = self._train_epoch(train_loader, optimizer, criterion)
+                val_loss, val_acc, val_f1 = self._validate(val_loader, criterion)
 
-                # Compute the loss
-                loss = criterion(outputs, labels)
-                train_loss += loss.item()
+                scheduler.step(val_loss)
 
-                # Backward pass
-                loss.backward()
-                optimizer.step()
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss
+                    if self.model_save_path:
+                        torch.save(self.model.state_dict(),
+                                   f"{self.model_save_path}_fold{fold}.pt")
 
-                # Collect predictions and labels
-                _, predicted = torch.max(outputs, 1)
-                train_preds.extend(predicted.cpu().numpy())
-                train_labels_list.extend(labels.cpu().numpy())
+                if self.early_stopping(val_loss):
+                    break
 
-            # Compute metrics on the training set
-            train_accuracy = accuracy_score(train_labels_list, train_preds)
-            train_f1 = f1_score(train_labels_list, train_preds, average='macro')
-
-            # Validation mode
-            self.model.eval()
-            val_loss = 0
-            val_preds = []
-            val_labels_list = []
-
-            with torch.no_grad():
-                for batch in val_dataloader:
-                    input_ids = batch['input_ids'].to(self.device)
-                    attention_mask = batch['attention_mask'].to(self.device)
-                    labels = batch['label'].to(self.device)
-
-                    outputs = self.model(input_ids, attention_mask)
-                    loss = criterion(outputs, labels)
-                    val_loss += loss.item()
-
-                    _, predicted = torch.max(outputs, 1)
-                    val_preds.extend(predicted.cpu().numpy())
-                    val_labels_list.extend(labels.cpu().numpy())
+            fold_results.append({
+                'val_loss': val_loss,
+                'val_accuracy': val_acc,
+                'val_f1': val_f1
+            })
+
+        # Average the fold results
+        avg_val_loss = np.mean([res['val_loss'] for res in fold_results])
+        avg_val_acc = np.mean([res['val_accuracy'] for res in fold_results])
+        avg_val_f1 = np.mean([res['val_f1'] for res in fold_results])
+
+        logger.info(f"Cross-validation averages - loss: {avg_val_loss:.4f}, accuracy: {avg_val_acc:.4f}, F1: {avg_val_f1:.4f}")
+
+        # Warn when the LSTM underperforms the logistic regression baseline
+        if avg_val_acc < lr_accuracy:
+            logger.warning("LSTM performance is below the logistic regression baseline; consider using the logistic regression model")
+
+        return avg_val_loss, avg_val_acc, avg_val_f1
+
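Note: `train()` saves one best checkpoint per fold (`<model_save_path>_fold0.pt` … `_fold4.pt`), while `self.model` is left holding the last fold's final weights. For deployment you would reload the fold you trust most, e.g. (the fold index chosen from `fold_results`, as a sketch):

    state = torch.load(f"{manager.model_save_path}_fold{best_fold}.pt",
                       map_location=manager.device)
    manager.model.load_state_dict(state)

Also note that `_create_model()`, called above to reset the model between folds, is not defined anywhere in this diff.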
+    def _train_epoch(self, train_loader, optimizer, criterion):
+        self.model.train()
+        total_loss = 0
+        for batch in train_loader:
+            optimizer.zero_grad()
 
-            # Compute metrics on the validation set
-            val_accuracy = accuracy_score(val_labels_list, val_preds)
-            val_f1 = f1_score(val_labels_list, val_preds, average='macro')
+            # Use adversarial training
+            loss = self.adversarial_training(batch, criterion)
 
-            # Average the losses
-            train_loss /= len(train_dataloader)
-            val_loss /= len(val_dataloader)
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            optimizer.step()
 
-            logger.info(f'Epoch {epoch+1}/{epochs} | '
-                        f'Train Loss: {train_loss:.4f} | '
-                        f'Train Acc: {train_accuracy:.4f} | '
-                        f'Train F1: {train_f1:.4f} | '
-                        f'Val Loss: {val_loss:.4f} | '
-                        f'Val Acc: {val_accuracy:.4f} | '
-                        f'Val F1: {val_f1:.4f}')
+            total_loss += loss.item()
 
-            # Save the best model
-            if val_loss < best_val_loss and self.model_save_path:
-                best_val_loss = val_loss
-                torch.save(self.model.state_dict(), self.model_save_path)
-                logger.info(f"Model saved to {self.model_save_path}")
-
-        # If a save path was given but nothing was saved, save the final epoch's model
-        if self.model_save_path and best_val_loss == float('inf'):
-            torch.save(self.model.state_dict(), self.model_save_path)
-            logger.info(f"Final model saved to {self.model_save_path}")
-
-        return train_loss, val_loss, val_accuracy, val_f1
+        return total_loss / len(train_loader)
+
+    def _validate(self, val_loader, criterion):
+        self.model.eval()
+        total_loss = 0
+        val_preds = []
+        val_labels_list = []
+
+        with torch.no_grad():
+            for batch in val_loader:
+                input_ids = batch['input_ids'].to(self.device)
+                attention_mask = batch['attention_mask'].to(self.device)
+                labels = batch['label'].to(self.device)
+
+                outputs, _ = self.model(input_ids, attention_mask)
+                loss = criterion(outputs, labels)
+                total_loss += loss.item()
+
+                _, predicted = torch.max(outputs, 1)
+                val_preds.extend(predicted.cpu().numpy())
+                val_labels_list.extend(labels.cpu().numpy())
+
+        avg_loss = total_loss / len(val_loader)
+        accuracy = accuracy_score(val_labels_list, val_preds)
+        f1 = f1_score(val_labels_list, val_preds, average='macro')
+
+        return avg_loss, accuracy, f1
 
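Note: the return signature of `train()` changed — the old revision returned `(train_loss, val_loss, val_accuracy, val_f1)` for a single split, while this one returns the three cross-validation averages. Existing call sites need updating, e.g.:

    avg_loss, avg_acc, avg_f1 = manager.train(train_texts, train_labels)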
     def evaluate(self, test_texts, test_labels, batch_size=32):
         """Evaluate the model"""
@@ -263,7 +460,7 @@ class LSTMModelManager:
                 attention_mask = batch['attention_mask'].to(self.device)
                 labels = batch['label'].to(self.device)
 
-                outputs = self.model(input_ids, attention_mask)
+                outputs, _ = self.model(input_ids, attention_mask)
                 loss = criterion(outputs, labels)
                 test_loss += loss.item()
 
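Note: every call site changes from `outputs = self.model(...)` to `outputs, _ = self.model(...)` because the model now returns `(logits, attention_weights)`. The weights are available for inspection, e.g.:

    logits, attn = manager.model(input_ids, attention_mask)
    probs = torch.softmax(logits, dim=1)   # attn has shape [batch, seq_len, 1]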
@@ -327,7 +524,7 @@ class LSTMModelManager:
                 input_ids = batch['input_ids'].to(self.device)
                 attention_mask = batch['attention_mask'].to(self.device)
 
-                outputs = self.model(input_ids, attention_mask)
+                outputs, _ = self.model(input_ids, attention_mask)
                 probs = torch.softmax(outputs, dim=1)
                 _, predicted = torch.max(outputs, 1)
 
@@ -388,4 +585,4 @@ if __name__ == "__main__":
         pred, prob = lstm_model_manager.predict(sentence)
         label = '良好' if pred == 0 else '不良'
         confidence = prob[pred]
-        print(f"Sentence: '{sentence}' prediction: {label} (confidence: {confidence:.2%})")
+        print(f"Sentence: '{sentence}' prediction: {label} (confidence: {confidence:.2%})")
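Note: an end-to-end sketch of the revised API (paths and data are placeholders, not from this diff, and it assumes the `_create_model` helper referenced in `train()` is defined; `predict` returns a label index — 0 = '良好'/good, 1 = '不良'/bad — and a probability vector):

    manager = LSTMModelManager(
        bert_model_path="bert-base-chinese",   # any BERT checkpoint with a tokenizer
        model_save_path="lstm_sentiment",
        word2vec_path=None,                    # optional pretrained vectors
    )
    texts = np.array(["服务很好", "质量糟糕"] * 50)   # toy data; StratifiedKFold needs enough samples per class
    labels = np.array([0, 1] * 50)
    avg_loss, avg_acc, avg_f1 = manager.train(texts, labels, batch_size=16, epochs=10)
    pred, prob = manager.predict("这个产品很满意")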