Toggle navigation
Toggle navigation
This project
Loading...
Sign in
万朱浩
/
Venue-Ops
Go to a project
Toggle navigation
Projects
Groups
Snippets
Help
Toggle navigation pinning
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Network
Create a new issue
Builds
Commits
Authored by
戒酒的李白
2025-04-02 20:35:22 +0800
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
296d8e4b1e44d0c54ccfa708c521acc22641e281
296d8e4b
1 parent
0ce3c011
Fix: Provide a seed for the random_state parameter.
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
3 deletions
model_pro/LSTM_model.py
model_pro/LSTM_model.py
View file @
296d8e4
...
...
@@ -219,7 +219,15 @@ class LSTMModelManager:
def
__init__
(
self
,
bert_model_path
,
model_save_path
=
None
,
vocab_size
=
30522
,
embedding_dim
=
100
,
hidden_dim
=
64
,
output_dim
=
2
,
n_layers
=
1
,
bidirectional
=
True
,
dropout
=
0.3
,
word2vec_path
=
None
):
bidirectional
=
True
,
dropout
=
0.3
,
word2vec_path
=
None
,
random_seed
=
42
):
# 设置随机种子以确保可重现性
self
.
random_seed
=
random_seed
random
.
seed
(
random_seed
)
np
.
random
.
seed
(
random_seed
)
torch
.
manual_seed
(
random_seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
random_seed
)
self
.
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
self
.
tokenizer
=
BertTokenizer
.
from_pretrained
(
bert_model_path
)
self
.
vocab_size
=
vocab_size
...
...
@@ -305,13 +313,18 @@ class LSTMModelManager:
if
val_texts
is
None
:
X_train
,
X_val
,
y_train
,
y_val
=
train_test_split
(
X_train
,
train_labels
,
test_size
=
0.2
,
stratify
=
train_labels
X_train
,
train_labels
,
test_size
=
0.2
,
stratify
=
train_labels
,
random_state
=
self
.
random_seed
# 添加随机种子
)
else
:
X_val
=
vectorizer
.
transform
(
val_texts
)
y_train
,
y_val
=
train_labels
,
val_labels
lr_model
=
LogisticRegression
(
class_weight
=
'balanced'
)
lr_model
=
LogisticRegression
(
class_weight
=
'balanced'
,
random_state
=
self
.
random_seed
# 添加随机种子
)
lr_model
.
fit
(
X_train
,
y_train
)
val_pred
=
lr_model
.
predict
(
X_val
)
...
...
Please
register
or
login
to post a comment