万朱浩 / Venue-Ops
Authored by 戒酒的李白 on 2024-10-07 09:51:29 +0800
Commit 9af61e2ade8e2dd5ce77590804aba65c200632e2 (1 parent: 4500b271)

Calculates scaled dot-product attention
Showing 1 changed file with 13 additions and 8 deletions
model_pro/MHA.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, embed_size, num_heads):
...
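The __init__ body is collapsed in this view. Based on the attributes forward() uses (q_linear, k_linear, v_linear, num_heads, head_dim), it presumably looks something like the sketch below; this is a reconstruction, not part of the diff:

    def __init__(self, embed_size, num_heads):
        # Reconstructed from the attributes referenced in forward();
        # the collapsed diff does not show these lines.
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads  # assumes embed_size divides evenly
        self.q_linear = nn.Linear(embed_size, embed_size)
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)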
@@ -19,16 +20,21 @@ class MultiHeadAttentionLayer(nn.Module):
         N = query.shape[0]  # batch_size

         # Linear transformations for Q, K, V
-        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
-        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
-        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)
+        Q = self.q_linear(query)
+        K = self.k_linear(keys)
+        V = self.v_linear(values)

-        # Reshape Q, K, V into multiple heads
+        # Reshape into multiple heads
         Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
         K = K.reshape(N, -1, self.num_heads, self.head_dim)
         V = V.reshape(N, -1, self.num_heads, self.head_dim)

-        return Q, K, V
+        # Compute scaled dot-product attention scores
+        attention_scores = torch.einsum("nqhd,nkhd->nhqk", [Q, K])
+        attention_scores = attention_scores / (self.head_dim ** 0.5)
+        attention = torch.softmax(attention_scores, dim=-1)  # Normalize
+        return attention
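Note that after this change forward() returns only the softmax-normalized weights, i.e. softmax(QK^T / sqrt(head_dim)); V is projected and reshaped but never applied. A minimal sketch of the usual next step (not part of this commit; fc_out is a hypothetical nn.Linear(embed_size, embed_size)):

        # Sketch only: weight V by the attention and merge the heads
        # back into embed_size, then project. Not in this commit.
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, V])
        out = out.reshape(N, -1, self.num_heads * self.head_dim)
        return self.fc_out(out)  # hypothetical output projection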
if __name__ == "__main__":
...
@@ -36,10 +42,9 @@ if __name__ == "__main__":
     num_heads = 8
     mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)

     # Dummy data
     values = torch.randn(2, 10, embed_size)
     keys = torch.randn(2, 10, embed_size)
     query = torch.randn(2, 10, embed_size)

-    Q, K, V = mha_layer(values, keys, query)
-    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
+    attention = mha_layer(values, keys, query)
+    print(f"Attention shape: {attention.shape}")
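For the dummy tensors above (N=2, seq_len=10, num_heads=8), the printed shape should be torch.Size([2, 8, 10, 10]), i.e. (N, num_heads, query_len, key_len), and each row of weights sums to 1. A quick check one could append:

    # Sanity checks on the returned attention weights (shapes assumed above)
    assert attention.shape == (2, 8, 10, 10)
    assert torch.allclose(attention.sum(dim=-1), torch.ones(2, 8, 10))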
...