Showing
1 changed file
with
6 additions
and
1 deletions
| @@ -308,6 +308,10 @@ class LSTMModelManager: | @@ -308,6 +308,10 @@ class LSTMModelManager: | ||
| 308 | return loss + self.alpha * loss_adv | 308 | return loss + self.alpha * loss_adv |
| 309 | 309 | ||
| 310 | def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None): | 310 | def train_logistic_regression(self, train_texts, train_labels, val_texts=None, val_labels=None): |
| 311 | + """训练逻辑回归基线模型""" | ||
| 312 | + # 设置随机种子以确保可重现性 | ||
| 313 | + np.random.seed(self.random_seed) | ||
| 314 | + | ||
| 311 | vectorizer = TfidfVectorizer(max_features=5000) | 315 | vectorizer = TfidfVectorizer(max_features=5000) |
| 312 | X_train = vectorizer.fit_transform(train_texts) | 316 | X_train = vectorizer.fit_transform(train_texts) |
| 313 | 317 | ||
| @@ -323,7 +327,8 @@ class LSTMModelManager: | @@ -323,7 +327,8 @@ class LSTMModelManager: | ||
| 323 | 327 | ||
| 324 | lr_model = LogisticRegression( | 328 | lr_model = LogisticRegression( |
| 325 | class_weight='balanced', | 329 | class_weight='balanced', |
| 326 | - random_state=self.random_seed # 添加随机种子 | 330 | + random_state=self.random_seed, # 添加随机种子 |
| 331 | + max_iter=1000 # 增加最大迭代次数以确保收敛 | ||
| 327 | ) | 332 | ) |
| 328 | lr_model.fit(X_train, y_train) | 333 | lr_model.fit(X_train, y_train) |
| 329 | 334 |
-
Please register or login to post a comment