Add multi-head attention mechanism infrastructure and input dimension settings.
Showing 1 changed file with 18 additions and 0 deletions
model_pro/MHA.py
0 → 100644
+import torch
+import torch.nn as nn
+
+class MultiHeadAttentionLayer(nn.Module):
+    def __init__(self, embed_size, num_heads):
+        super(MultiHeadAttentionLayer, self).__init__()
+        self.embed_size = embed_size
+        self.num_heads = num_heads
+        self.head_dim = embed_size // num_heads
+
+        assert (self.head_dim * num_heads == embed_size), "Embedding size needs to be divisible by num_heads"
+
+
+if __name__ == "__main__":
+    embed_size = 512
+    num_heads = 8
+    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
+    print("Model initialized successfully.")
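The assertion in __init__ guards the head split that later attention code relies on: embed_size must factor exactly into num_heads chunks of head_dim. As a minimal illustrative sketch (not part of this commit; the tensor shapes and variable names below are assumptions), this is how that split would typically be applied to an input tensor:

import torch

# Assumed example dimensions; head_dim uses the same computation as the layer.
batch_size, seq_len = 2, 10
embed_size, num_heads = 512, 8
head_dim = embed_size // num_heads  # 64

x = torch.randn(batch_size, seq_len, embed_size)          # (2, 10, 512)
heads = x.view(batch_size, seq_len, num_heads, head_dim)  # (2, 10, 8, 64)
heads = heads.transpose(1, 2)                             # (2, 8, 10, 64): one 64-dim slice per head
print(heads.shape)  # torch.Size([2, 8, 10, 64])

If embed_size were not divisible by num_heads, the view call above would fail, which is exactly the condition the assert in the layer rejects up front.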