Showing
1 changed file
with
17 additions
and
0 deletions
model_pro/classifier.py
0 → 100644
| 1 | +import torch | ||
| 2 | +import torch.nn as nn | ||
| 3 | + | ||
| 4 | +class FinalClassifier(nn.Module): | ||
| 5 | + def __init__(self, input_dim, num_classes, hidden_dim=512, dropout_rate=0.3): | ||
| 6 | + super(FinalClassifier, self).__init__() | ||
| 7 | + # 增加一个隐藏层 | ||
| 8 | + self.fc1 = nn.Linear(input_dim, hidden_dim) # 第一层全连接层 | ||
| 9 | + self.fc2 = nn.Linear(hidden_dim, num_classes) # 第二层全连接层 | ||
| 10 | + self.dropout = nn.Dropout(dropout_rate) # Dropout 防止过拟合 | ||
| 11 | + self.relu = nn.ReLU() # 激活函数 | ||
| 12 | + | ||
| 13 | + def forward(self, x): | ||
| 14 | + x = self.relu(self.fc1(x)) # 第一层全连接 + ReLU 激活 | ||
| 15 | + x = self.dropout(x) # Dropout | ||
| 16 | + out = self.fc2(x) # 最终输出层(未应用 softmax) | ||
| 17 | + return out |
-
Please register or login to post a comment