万朱浩 / Venue-Ops
Authored by 戒酒的李白 on 2024-10-06 11:54:32 +0800
Commit 4500b2719e5477ab6424de205665a23d15d12b50 (1 parent: f5e307d3)
Split the input into multiple attention heads
Showing 1 changed file with 23 additions and 1 deletion
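The hunk below adds the key and value projections and a forward() that projects the inputs and reshapes each projection into per-head tensors. It relies on members presumably defined earlier in __init__, outside this hunk (self.q_linear, self.num_heads, self.head_dim), and on embed_size being divisible by num_heads; with the demo values, head_dim = 512 / 8 = 64.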
model_pro/MHA.py
...
@@ -15,9 +15,31 @@ class MultiHeadAttentionLayer(nn.Module):
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]  # batch_size

        # Linear transformations for Q, K, V
        Q = self.q_linear(query)   # shape: (N, seq_len, embed_size)
        K = self.k_linear(keys)    # shape: (N, seq_len, embed_size)
        V = self.v_linear(values)  # shape: (N, seq_len, embed_size)

        # Reshape Q, K, V into multiple heads
        Q = Q.reshape(N, -1, self.num_heads, self.head_dim)
        K = K.reshape(N, -1, self.num_heads, self.head_dim)
        V = V.reshape(N, -1, self.num_heads, self.head_dim)

        return Q, K, V


if __name__ == "__main__":
    embed_size = 512
    num_heads = 8
    mha_layer = MultiHeadAttentionLayer(embed_size, num_heads)
    print("Linear layers for Q, K, V initialized.")

    # Dummy data
    values = torch.randn(2, 10, embed_size)
    keys = torch.randn(2, 10, embed_size)
    query = torch.randn(2, 10, embed_size)

    Q, K, V = mha_layer(values, keys, query)
    print(f"Q shape: {Q.shape}, K shape: {K.shape}, V shape: {V.shape}")
...
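The hunk stops after the reshape, so the returned Q, K and V still have per-head layout (N, seq_len, num_heads, head_dim) and no attention is computed in this commit. A minimal sketch of the usual next step, assuming the standard scaled dot-product formulation; the function attention_from_heads and everything inside it are illustrative, not part of this repository:

import math
import torch

def attention_from_heads(Q, K, V):
    # Hypothetical follow-on, not part of this commit.
    # Move heads in front: (N, seq_len, heads, head_dim) -> (N, heads, seq_len, head_dim)
    Q = Q.permute(0, 2, 1, 3)
    K = K.permute(0, 2, 1, 3)
    V = V.permute(0, 2, 1, 3)

    # Scaled dot-product scores: (N, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
    weights = torch.softmax(scores, dim=-1)

    # Weighted sum of values: (N, heads, seq_len, head_dim)
    out = torch.matmul(weights, V)

    # Merge the heads back into a single embedding dimension
    N, num_heads, seq_len, head_dim = out.shape
    return out.permute(0, 2, 1, 3).reshape(N, seq_len, num_heads * head_dim)

With the demo tensors above, attention_from_heads(Q, K, V) would return a (2, 10, 512) tensor. A full layer would normally also apply an output projection (for example another nn.Linear(embed_size, embed_size)), which this commit does not add.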