Showing
1 changed file
with
16 additions
and
3 deletions
| @@ -219,7 +219,15 @@ class LSTMModelManager: | @@ -219,7 +219,15 @@ class LSTMModelManager: | ||
| 219 | 219 | ||
| 220 | def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522, | 220 | def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522, |
| 221 | embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1, | 221 | embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1, |
| 222 | - bidirectional=True, dropout=0.3, word2vec_path=None): | 222 | + bidirectional=True, dropout=0.3, word2vec_path=None, random_seed=42): |
| 223 | + # 设置随机种子以确保可重现性 | ||
| 224 | + self.random_seed = random_seed | ||
| 225 | + random.seed(random_seed) | ||
| 226 | + np.random.seed(random_seed) | ||
| 227 | + torch.manual_seed(random_seed) | ||
| 228 | + if torch.cuda.is_available(): | ||
| 229 | + torch.cuda.manual_seed_all(random_seed) | ||
| 230 | + | ||
| 223 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | 231 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 224 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) | 232 | self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) |
| 225 | self.vocab_size = vocab_size | 233 | self.vocab_size = vocab_size |
| @@ -305,13 +313,18 @@ class LSTMModelManager: | @@ -305,13 +313,18 @@ class LSTMModelManager: | ||
| 305 | 313 | ||
| 306 | if val_texts is None: | 314 | if val_texts is None: |
| 307 | X_train, X_val, y_train, y_val = train_test_split( | 315 | X_train, X_val, y_train, y_val = train_test_split( |
| 308 | - X_train, train_labels, test_size=0.2, stratify=train_labels | 316 | + X_train, train_labels, test_size=0.2, |
| 317 | + stratify=train_labels, | ||
| 318 | + random_state=self.random_seed # 添加随机种子 | ||
| 309 | ) | 319 | ) |
| 310 | else: | 320 | else: |
| 311 | X_val = vectorizer.transform(val_texts) | 321 | X_val = vectorizer.transform(val_texts) |
| 312 | y_train, y_val = train_labels, val_labels | 322 | y_train, y_val = train_labels, val_labels |
| 313 | 323 | ||
| 314 | - lr_model = LogisticRegression(class_weight='balanced') | 324 | + lr_model = LogisticRegression( |
| 325 | + class_weight='balanced', | ||
| 326 | + random_state=self.random_seed # 添加随机种子 | ||
| 327 | + ) | ||
| 315 | lr_model.fit(X_train, y_train) | 328 | lr_model.fit(X_train, y_train) |
| 316 | 329 | ||
| 317 | val_pred = lr_model.predict(X_val) | 330 | val_pred = lr_model.predict(X_val) |
-
Please register or login to post a comment