"""
# * @Author: DuTim
# * @Date: 2023-06-17 21:55:57
# * @LastEditTime: 2023-06-18 12:05:23
# * @Description: 自己实现的multihead attention +FFN 可以用于实现 self_attention cross_attention
"""
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=768, num_heads=12, att_drop_prob=0.1, state_drop_prob=0.5,
                 use_ffn=False, return_att_score=False, device="cuda:0"):
        super().__init__()
        self.dim = d_model
        self.num_heads = num_heads
        self.att_drop_prob = att_drop_prob
        self.state_drop_prob = state_drop_prob
        self.use_ffn = use_ffn
        self.return_att_score = return_att_score
        self.device = device
        self.size_per_head = self.dim // self.num_heads  # e.g. 768 // 12 = 64

        # Projections for queries, keys, and values (no bias, as in the original Transformer).
        self.Wq = nn.Linear(self.dim, self.num_heads * self.size_per_head, bias=False)
        self.Wk = nn.Linear(self.dim, self.num_heads * self.size_per_head, bias=False)
        self.Wv = nn.Linear(self.dim, self.num_heads * self.size_per_head, bias=False)
        # Output projection back to d_model, followed by dropout + residual + LayerNorm.
        self.W = nn.Linear(self.num_heads * self.size_per_head, self.dim)
        self.lm = nn.LayerNorm(self.dim)
        self.att_drop = nn.Dropout(self.att_drop_prob)
        self.state_drop = nn.Dropout(self.state_drop_prob)

        # FFN sub-layer (used only when use_ffn=True): d_model -> 4*d_model -> d_model.
        self.ffn1 = nn.Linear(self.dim, self.dim * 4)
        self.ffn2 = nn.Linear(self.dim * 4, self.dim)
        self.act = nn.GELU()
        self.lm_ffn = nn.LayerNorm(self.dim)
    def calc_mask_score(self, attention_mask, aim_shape):
        """
        * @description: Compute the additive mask score.
        * @param attention_mask : (B, S); 1 = normal token, 0 = masked token
        * @param aim_shape : (B, H_num, L, S), shape of the attention score tensor
        * @return mask score of shape aim_shape: 0 for normal tokens, a large negative value for masked tokens
        """
        mask_score = torch.zeros(size=aim_shape).to(self.device)
        # Broadcast the (B, S) mask over heads and query positions, then turn 0-entries
        # into a large negative bias so that softmax assigns them ~zero probability.
        mask_score = mask_score + attention_mask[:, None, None, :]
        mask_score = (1.0 - mask_score) * -1000000.0
        return mask_score
    def SelfAttention(self, q, k, v, attention_mask):
        """
        * @description: Scaled dot-product attention with residual connection and LayerNorm.
        * @param q : B x L x d
        * @param k : B x S x d
        * @param v : B x S x d
        * @param attention_mask : B x S
                                  1 normal token
                                  0 masked token
        * @return O: B x L x d ; mean_head_att_score: B x L x S
        """
        Q_new_size = q.size()[:-1] + (self.num_heads, self.size_per_head)  # B, L, H, h_dim
        K_new_size = k.size()[:-1] + (self.num_heads, self.size_per_head)  # B, S, H, h_dim
        Q = self.Wq(q).view(*Q_new_size).permute(0, 2, 1, 3)  # B, H, L, h_dim
        K = self.Wk(k).view(*K_new_size).permute(0, 2, 1, 3)  # B, H, S, h_dim
        V = self.Wv(v).view(*K_new_size).permute(0, 2, 1, 3)  # B, H, S, h_dim
        # Scaled dot-product scores: B, H, L, S
        attention_score = torch.matmul(Q, K.transpose(2, 3)) / math.sqrt(self.size_per_head)
        # Apply the attention mask as a large negative additive bias.
        attention_score = attention_score + self.calc_mask_score(attention_mask, attention_score.shape)
        attention_score = nn.Softmax(dim=-1)(attention_score)
        attention_score = self.att_drop(attention_score)
        O = torch.matmul(attention_score, V)  # B, H, L, h_dim
        # Merge heads, project back to d_model, then dropout + residual + LayerNorm.
        O = self.W(O.permute(0, 2, 1, 3).contiguous().view(q.size(0), q.size(1), -1))  # B x L x d
        O = self.state_drop(O)
        O = self.lm(q + O)
        mean_head_att_score = attention_score.mean(dim=1)  # average over heads: B x L x S
        return O, mean_head_att_score
    def FFN(self, x):
        # Position-wise feed-forward sub-layer: Linear -> GELU -> Linear, then dropout + residual + LayerNorm.
        hidden = self.act(self.ffn1(x))
        output = self.ffn2(hidden)
        output = self.state_drop(output)
        output = self.lm_ffn(x + output)
        return output
    def forward(self, q, k, v, attention_mask):
        """
        * @description: Attention with residual connection; pass q == k == v for self-attention,
                        or q from one sequence and k, v from another for cross-attention.
        * @param q : B x L x d
        * @param k : B x S x d
        * @param v : B x S x d
        * @param attention_mask : B x S
                                  1 normal token
                                  0 masked token
        * @return x: B x L x d ; att_score: B x L x S (only when return_att_score=True)
        """
        x, att_score = self.SelfAttention(q, k, v, attention_mask)
        if self.use_ffn:
            x = self.FFN(x)
        if self.return_att_score:
            return x, att_score
        else:
            return x
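

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): shows how this module might be
    # called for self-attention (q == k == v) and for cross-attention (queries from one
    # sequence, keys/values from another). The shapes and the toy padding mask below are
    # illustrative assumptions, not values taken from the original code.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    mha = MultiHeadAttention(d_model=768, num_heads=12, use_ffn=True,
                             return_att_score=True, device=device).to(device)
    mha.eval()  # disable dropout for a deterministic sanity check

    B, L, S, d = 2, 10, 16, 768
    enc = torch.randn(B, S, d, device=device)   # encoder-side sequence
    dec = torch.randn(B, L, d, device=device)   # decoder-side sequence
    enc_mask = torch.ones(B, S, device=device)  # 1 = keep, 0 = masked
    enc_mask[:, -4:] = 0                        # pretend the last 4 tokens are padding

    # Self-attention over the encoder sequence.
    out_self, att_self = mha(enc, enc, enc, enc_mask)
    print(out_self.shape, att_self.shape)       # (B, S, d), (B, S, S)

    # Cross-attention: queries from dec, keys/values from enc.
    out_cross, att_cross = mha(dec, enc, enc, enc_mask)
    print(out_cross.shape, att_cross.shape)     # (B, L, d), (B, L, S)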