-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfmha.py
204 lines (159 loc) · 7.41 KB
/
fmha.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
import torch.nn.functional as F
# from apex.contrib.fmha import FMHAFun
from apex import fused_dense
from collections import OrderedDict
import numpy as np
import fmhalib as mha
class FMHAFun(torch.autograd.Function):
    """Autograd function wrapping the ``fmhalib`` fused multi-head attention kernels.

    When ``cu_seqlens.numel() - 1`` is below 4 the ``fwd_nl`` / ``bwd_nl`` kernels
    are used and ``max_s`` is overridden to 512; otherwise ``fwd`` / ``bwd`` are
    used. Only ``qkv`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training):
        num_seqs = cu_seqlens.numel() - 1
        if num_seqs < 4:
            # Small batches go through the *_nl kernel with max_s hard-coded to 512.
            max_s = 512
            attn_kernel = mha.fwd_nl
        else:
            attn_kernel = mha.fwd
        context, S_dmask = attn_kernel(qkv, cu_seqlens, p_dropout, max_s, is_training, None)

        # Stash everything backward needs; note max_s is recorded after the override.
        ctx.save_for_backward(qkv, S_dmask)
        ctx.cu_seqlens = cu_seqlens
        ctx.p_dropout = p_dropout
        ctx.max_s = max_s
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, S_dmask = ctx.saved_tensors
        cu_seqlens = ctx.cu_seqlens
        kernel_args = (dout, qkv, S_dmask, cu_seqlens, ctx.p_dropout, ctx.max_s)
        if cu_seqlens.numel() - 1 < 4:
            dqkv, _dp, _dkv = mha.bwd_nl(*kernel_args)
        else:
            dqkv, _dp = mha.bwd(*kernel_args)
        # One entry per forward input (qkv, cu_seqlens, p_dropout, max_s, is_training).
        return dqkv, None, None, None, None
class TestParam(torch.nn.Parameter):
    """``torch.nn.Parameter`` subclass that also carries ``tag`` and ``counter`` fields.

    Args:
        data: tensor assigned to ``self.data``.
        requires_grad: value assigned to ``self.requires_grad`` (default ``True``).
    """

    def __init__(self, data, requires_grad=True):
        super().__init__()
        self.data, self.requires_grad = data, requires_grad
        # Fixed label plus a counter starting at zero.
        self.tag = 'qkv'
        self.counter = 0
class NoopCat(torch.autograd.Function):
    """Present the fused Wqkv/Bqkv buffers to autograd as if built from Wq/Wk/Wv and Bq/Bk/Bv.

    No data is copied. Forward re-points the six per-projection tensors at slices
    of the fused buffers. It then returns new tensor objects that share the fused
    buffers' storage and have ``requires_grad`` set. Backward splits the fused
    gradients into one gradient per per-projection input.
    """

    @staticmethod
    def forward(ctx, Wq, Wk, Wv, Bq, Bk, Bv, Wqkv, Bqkv, hidden_size):
        # The fused buffers themselves must not require grad.
        assert not Wqkv.requires_grad and not Bqkv.requires_grad, "hye!"
        Wtmp = Wqkv.view(3, hidden_size, hidden_size)
        Btmp = Bqkv.view(3, hidden_size)
        # Re-point each per-projection tensor at its (q, k, v) slice of the fused storage.
        Wq.data = Wtmp[0, :, :]
        Wk.data = Wtmp[1, :, :]
        Wv.data = Wtmp[2, :, :]
        Bq.data = Btmp[0, :]
        Bk.data = Btmp[1, :]
        Bv.data = Btmp[2, :]
        # Build fresh tensors over the exact storage/offset/size/stride of Wqkv and
        # Bqkv and mark them as requiring grad; these are what callers consume.
        # NOTE(review): presumably set_ is used instead of a view so autograd does
        # not treat the outputs as views of the inputs — confirm.
        Wtmp = Wqkv.new()
        Wtmp.set_(Wqkv.storage(), Wqkv.storage_offset(), Wqkv.size(), Wqkv.stride())
        Wtmp.requires_grad = True
        Btmp = Bqkv.new()
        Btmp.set_(Bqkv.storage(), Bqkv.storage_offset(), Bqkv.size(), Bqkv.stride())
        Btmp.requires_grad = True
        ctx.save_for_backward(Wqkv, Bqkv, Wq, Wk, Wv, Bq, Bk, Bv)
        ctx.hidden_size = hidden_size
        return Wtmp, Btmp

    @staticmethod
    def backward(ctx, dWqkv, dBqkv):
        Wqkv, Bqkv, Wq, Wk, Wv, Bq, Bk, Bv = ctx.saved_tensors
        # Re-point the per-projection tensors at the fused storage again, as in forward.
        # NOTE(review): why this is repeated in backward is not evident here — confirm.
        Wtmp = Wqkv.view(3, ctx.hidden_size, ctx.hidden_size)
        Btmp = Bqkv.view(3, ctx.hidden_size)
        Wq.data = Wtmp[0, :, :]
        Wk.data = Wtmp[1, :, :]
        Wv.data = Wtmp[2, :, :]
        Bq.data = Btmp[0, :]
        Bk.data = Btmp[1, :]
        Bv.data = Btmp[2, :]
        # Split the fused gradients into q/k/v pieces; Wqkv, Bqkv and hidden_size
        # receive no gradient.
        dWtmp = dWqkv.view(3, ctx.hidden_size, ctx.hidden_size)
        dBtmp = dBqkv.view(3, ctx.hidden_size)
        return dWtmp[0, :, :], dWtmp[1, :, :], dWtmp[2, :, :], dBtmp[0, :], dBtmp[1, :], dBtmp[2, :], None, None, None
class FMHA(torch.nn.Module):
    """Self-attention block: fused QKV projection followed by the FMHA kernel.

    The Q/K/V projection weights and biases are stored in two buffers:

    - ``Wqkv``, of shape ``(3 * hidden_size, hidden_size)``;
    - ``Bqkv``, of shape ``(3 * hidden_size,)``.

    Both are stacked in q, k, v order. The module's parameters ``Wq``, ``Bq``,
    ``Wk``, ``Bk``, ``Wv`` and ``Bv`` are exposed through the ``_parameters``
    property. They are re-pointed at slices of those buffers, so in-place writes
    to the parameters land directly in the fused storage.

    Args:
        config: object exposing ``attention_probs_dropout_prob``,
            ``num_attention_heads``, ``hidden_size`` and ``fused_bias_mha``.
    """

    def __init__(self, config):
        super(FMHA, self).__init__()

        self.p_dropout = config.attention_probs_dropout_prob
        self.h = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.d = self.hidden_size // self.h  # per-head size
        self.fuse_bias = config.fused_bias_mha
        assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads"

        # Fused projection storage. The buffers never require grad themselves;
        # the per-projection parameters alias their storage instead.
        self.register_buffer("Wqkv", torch.zeros(3 * config.hidden_size, config.hidden_size))
        self.register_buffer("Bqkv", torch.zeros(3 * config.hidden_size))
        self.Wqkv.requires_grad = False
        self.Bqkv.requires_grad = False

        with torch.no_grad():
            Wtmp = self.Wqkv.view(3, self.hidden_size, self.hidden_size)
            Btmp = self.Bqkv.view(3, self.hidden_size)
            params = []
            for idx, tag in enumerate('qkv'):
                params.append(('W' + tag, torch.nn.Parameter(Wtmp[idx, :, :])))
                params.append(('B' + tag, torch.nn.Parameter(Btmp[idx, :])))
            self.param_views = OrderedDict(params)
            self._reset_param_views()

        def prep_weights(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            """Load-state-dict pre-hook: fuse separate query/key/value weights.

            A checkpoint containing ``<prefix>query.weight`` / ``.bias`` (and the
            key/value equivalents) is rewritten in place. Its six tensors are
            popped and replaced by ``Wqkv`` / ``Bqkv`` plus the per-projection
            ``W*`` / ``B*`` entries. A state dict without ``query.weight`` is
            left untouched, for example one produced by this module's own
            ``state_dict()``. Previously such a dict raised ``KeyError``.
            """
            if prefix + 'query.weight' not in state_dict:
                return
            H = self.hidden_size
            Wq = state_dict.pop(prefix + 'query.weight')
            bq = state_dict.pop(prefix + 'query.bias')
            Wk = state_dict.pop(prefix + 'key.weight')
            bk = state_dict.pop(prefix + 'key.bias')
            Wv = state_dict.pop(prefix + 'value.weight')
            bv = state_dict.pop(prefix + 'value.bias')
            weight = torch.cat([Wq.view(self.h, self.d, H),
                                Wk.view(self.h, self.d, H),
                                Wv.view(self.h, self.d, H)],
                               dim=0).reshape(3 * H, H).contiguous()
            bias = torch.cat([bq.view(self.h, self.d),
                              bk.view(self.h, self.d),
                              bv.view(self.h, self.d)],
                             dim=0).reshape(3 * H).contiguous()
            state_dict[prefix + 'Wqkv'] = weight
            state_dict[prefix + 'Bqkv'] = bias
            state_dict[prefix + 'Wq'] = Wq
            state_dict[prefix + 'Wk'] = Wk
            state_dict[prefix + 'Wv'] = Wv
            state_dict[prefix + 'Bq'] = bq
            state_dict[prefix + 'Bk'] = bk
            state_dict[prefix + 'Bv'] = bv

        self._register_load_state_dict_pre_hook(prep_weights)

    def _reset_param_views(self):
        """Re-point each per-projection parameter at its slice of ``Wqkv`` / ``Bqkv``."""
        with torch.no_grad():
            Wtmp = self.Wqkv.view(3, self.hidden_size, self.hidden_size)
            Btmp = self.Bqkv.view(3, self.hidden_size)
            for idx, tag in enumerate('qkv'):
                self.param_views['W' + tag].data = Wtmp[idx, :, :]
                self.param_views['B' + tag].data = Btmp[idx, :]

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` to the fused buffers (and any grads), then rebuild the parameter views.

        Args:
            fn: tensor transform supplied by ``nn.Module`` helpers such as
                ``to``, ``cuda`` or ``double``.
            recurse: accepted for compatibility with callers passing
                ``recurse=``. This class registers no submodules, so it is unused.

        Returns:
            ``self``, as ``nn.Module._apply`` does. Without it ``module.to(...)``
            would return ``None``.
        """
        with torch.no_grad():
            for name in ('Wqkv', 'Bqkv'):
                old = getattr(self, name)
                new = fn(old)
                # Read the grad from the pre-conversion tensor; a freshly created
                # tensor has no grad, so checking `new` would drop it.
                if old.grad is not None:
                    new.grad = fn(old.grad)
                setattr(self, name, new)
        self._reset_param_views()
        return self

    @property
    def _parameters(self):
        """Per-projection parameters, refreshed to alias the current fused buffers."""
        self._reset_param_views()
        return self.param_views

    @_parameters.setter
    def _parameters(self, _):
        # Assignments are ignored: parameters are always derived from param_views.
        # The stack dump below is leftover debug instrumentation. It only fires if
        # 'Wqkv' ended up in the instance __dict__ on cuda:0.
        if 'Wqkv' in self.__dict__ and self.Wqkv is not None and self.Wqkv.device == torch.device('cuda:0'):
            import traceback
            traceback.print_stack()

    def forward(self, hidden_states, cu_seqlens, seqlens, max_s, is_training=True):
        """Project ``hidden_states`` to Q/K/V and run fused attention.

        Args:
            hidden_states: input whose last dimension is ``hidden_size``. It is
                multiplied by the ``(3 * hidden_size, hidden_size)`` fused weight.
            cu_seqlens: forwarded unchanged to ``FMHAFun`` (presumably cumulative
                sequence offsets — confirm against callers).
            seqlens: unused; kept for interface compatibility.
            max_s: forwarded to ``FMHAFun``.
            is_training: forwarded to ``FMHAFun``.

        Returns:
            The attention context reshaped to ``(-1, hidden_size)``.
        """
        # Argument order matches NoopCat.forward: Wq, Wk, Wv, Bq, Bk, Bv.
        Wqkv, Bqkv = NoopCat.apply(*[self.param_views[x + y] for x in 'WB' for y in 'qkv'],
                                   self.Wqkv, self.Bqkv, self.hidden_size)
        if self.fuse_bias:
            qkv = fused_dense.fused_dense_function(hidden_states, Wqkv, Bqkv)
        else:
            qkv = F.linear(hidden_states, Wqkv, Bqkv)
        context = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training)
        return context.view(-1, self.hidden_size)