戒酒的李白

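Fix: Add a `random_seed` constructor parameter (default 42) to `LSTMModelManager` that seeds Python's `random`, NumPy, and PyTorch (including CUDA when available), and pass it as `random_state` to `train_test_split` and `LogisticRegression` so results are reproducible.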

@@ -219,7 +219,15 @@ class LSTMModelManager: @@ -219,7 +219,15 @@ class LSTMModelManager:
219 219
220 def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522, 220 def __init__(self, bert_model_path, model_save_path=None, vocab_size=30522,
221 embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1, 221 embedding_dim=100, hidden_dim=64, output_dim=2, n_layers=1,
222 - bidirectional=True, dropout=0.3, word2vec_path=None): 222 + bidirectional=True, dropout=0.3, word2vec_path=None, random_seed=42):
  223 + # 设置随机种子以确保可重现性
  224 + self.random_seed = random_seed
  225 + random.seed(random_seed)
  226 + np.random.seed(random_seed)
  227 + torch.manual_seed(random_seed)
  228 + if torch.cuda.is_available():
  229 + torch.cuda.manual_seed_all(random_seed)
  230 +
223 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 231 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
224 self.tokenizer = BertTokenizer.from_pretrained(bert_model_path) 232 self.tokenizer = BertTokenizer.from_pretrained(bert_model_path)
225 self.vocab_size = vocab_size 233 self.vocab_size = vocab_size
@@ -305,13 +313,18 @@ class LSTMModelManager: @@ -305,13 +313,18 @@ class LSTMModelManager:
305 313
306 if val_texts is None: 314 if val_texts is None:
307 X_train, X_val, y_train, y_val = train_test_split( 315 X_train, X_val, y_train, y_val = train_test_split(
308 - X_train, train_labels, test_size=0.2, stratify=train_labels 316 + X_train, train_labels, test_size=0.2,
  317 + stratify=train_labels,
  318 + random_state=self.random_seed # 添加随机种子
309 ) 319 )
310 else: 320 else:
311 X_val = vectorizer.transform(val_texts) 321 X_val = vectorizer.transform(val_texts)
312 y_train, y_val = train_labels, val_labels 322 y_train, y_val = train_labels, val_labels
313 323
314 - lr_model = LogisticRegression(class_weight='balanced') 324 + lr_model = LogisticRegression(
  325 + class_weight='balanced',
  326 + random_state=self.random_seed # 添加随机种子
  327 + )
315 lr_model.fit(X_train, y_train) 328 lr_model.fit(X_train, y_train)
316 329
317 val_pred = lr_model.predict(X_val) 330 val_pred = lr_model.predict(X_val)