Showing 1 changed file with 3 additions and 21 deletions
| @@ -219,15 +219,7 @@ class LSTMModelManager: | @@ -219,15 +219,7 @@ class LSTMModelManager: | ||
| 219 | 219 | ||
| 220 | def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522, | 220 | def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522, |
| 221 | embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1, | 221 | embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1, |
| 222 | - bidirectional=True, dropout=0.3, word2vec_path=None, random_seed=42): | ||
| 223 | - # 设置随机种子以确保可重现性 | ||
| 224 | - self.random_seed = random_seed | ||
| 225 | - random.seed(random_seed) | ||
| 226 | - np.random.seed(random_seed) | ||
| 227 | - torch.manual_seed(random_seed) | ||
| 228 | - if torch.cuda.is_available(): | ||
| 229 | - torch.cuda.manual_seed_all(random_seed) | ||
| 230 | - | 222 | + bidirectional=True, dropout=0.3, word2vec_path=None): |
| 231 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 223 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 232 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) | 224 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) |
| 233 | self.vocab_size = vocab_size | 225 | self.vocab_size = vocab_size |
| @@ -308,28 +300,18 @@ class LSTMModelManager: | @@ -308,28 +300,18 @@ class LSTMModelManager: | ||
| 308 | return loss + self.alpha * loss_adv | 300 | return loss + self.alpha * loss_adv |
| 309 | 301 | ||
| 310 | def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None): | 302 | def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None): |
| 311 | - """训练逻辑回归基线模型""" | ||
| 312 | - # 设置随机种子以确保可重现性 | ||
| 313 | - np.random.seed(self.random_seed) | ||
| 314 | - | ||
| 315 | vectorizer = TfidfVectorizer(max_features=5000) | 303 | vectorizer = TfidfVectorizer(max_features=5000) |
| 316 | X_train = vectorizer.fit_transform(train_texts) | 304 | X_train = vectorizer.fit_transform(train_texts) |
| 317 | 305 | ||
| 318 | if val_texts is None: | 306 | if val_texts is None: |
| 319 | X_train, X_val, y_train, y_val = train_test_split( | 307 | X_train, X_val, y_train, y_val = train_test_split( |
| 320 | - X_train, train_labels, test_size=0.2, | ||
| 321 | - stratify=train_labels, | ||
| 322 | - random_state=self.random_seed # 添加随机种子 | 308 | + X_train, train_labels, test_size=0.2, stratify=train_labels |
| 323 | ) | 309 | ) |
| 324 | else: | 310 | else: |
| 325 | X_val = vectorizer.transform(val_texts) | 311 | X_val = vectorizer.transform(val_texts) |
| 326 | y_train, y_val = train_labels, val_labels | 312 | y_train, y_val = train_labels, val_labels |
| 327 | 313 | ||
| 328 | - lr_model = LogisticRegression( | ||
| 329 | - class_weight='balanced', | ||
| 330 | - random_state=self.random_seed, # 添加随机种子 | ||
| 331 | - max_iter=1000 # 增加最大迭代次数以确保收敛 | ||
| 332 | - ) | 314 | + lr_model = LogisticRegression(class_weight='balanced') |
| 333 | lr_model.fit(X_train, y_train) | 315 | lr_model.fit(X_train, y_train) |
| 334 | 316 | ||
| 335 | val_pred = lr_model.predict(X_val) | 317 | val_pred = lr_model.predict(X_val) |
Please register or login to post a comment.