Li Bai Who Quit Drinking

Scaffolding for a multi-head attention layer: class definition and input-dimension validation.

import torch
import torch.nn as nn

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        # Each head operates on an equal slice of the embedding dimension.
        self.head_dim = embed_size // num_heads

        assert self.head_dim * num_heads == embed_size, \
            "embed_size must be divisible by num_heads"


if __name__ == "__main__":
    embed_size = 512
    num_heads = 8
    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
    print("Model initialized successfully.")
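
The layer as committed stops at dimension bookkeeping; no projections or attention computation are defined yet. Below is a minimal sketch of how the forward pass could be completed, assuming standard scaled dot-product attention over batch-first inputs of shape (batch, seq_len, embed_size). The q_proj/k_proj/v_proj/out_proj layers, the forward signature, and the mask handling are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

class MultiHeadAttentionLayer(nn.Module):
    # Sketch: projections and forward() are assumed additions to the committed skeleton.
    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0, "embed_size must be divisible by num_heads"
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # One learned projection each for queries, keys, and values, plus the output.
        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape

        # Project, then split the embedding into num_heads slices of head_dim.
        def split_heads(t):
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))  # (batch, heads, seq, head_dim)
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = attn @ v  # (batch, heads, seq, head_dim)

        # Merge heads back into a single embedding dimension.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, self.embed_size)
        return self.out_proj(out)


if __name__ == "__main__":
    layer = MultiHeadAttentionLayer(embed_size=512, num_heads=8)
    x = torch.randn(2, 10, 512)
    print(layer(x).shape)  # torch.Size([2, 10, 512])

One note on the design the skeleton implies: splitting embed_size across num_heads (rather than giving each head the full dimension) keeps the total parameter count of the projections equal to single-head attention, which is why the divisibility assert matters.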