戒酒的李白

The final classification layer is complete

  1 +import torch
  2 +import torch.nn as nn
  3 +
  4 +class FinalClassifier(nn.Module):
  5 + def __init__(self, input_dim, num_classes, hidden_dim=512, dropout_rate=0.3):
  6 + super(FinalClassifier, self).__init__()
  7 + # 增加一个隐藏层
  8 + self.fc1 = nn.Linear(input_dim, hidden_dim) # 第一层全连接层
  9 + self.fc2 = nn.Linear(hidden_dim, num_classes) # 第二层全连接层
  10 + self.dropout = nn.Dropout(dropout_rate) # Dropout 防止过拟合
  11 + self.relu = nn.ReLU() # 激活函数
  12 +
  13 + def forward(self, x):
  14 + x = self.relu(self.fc1(x)) # 第一层全连接 + ReLU 激活
  15 + x = self.dropout(x) # Dropout
  16 + out = self.fc2(x) # 最终输出层(未应用 softmax)
  17 + return out