[cherry pick] Support optional residual add in fused ops and slice large tensor for cudnn_softmax (#43719)

cherry-pick #43635 #43681 #43474
zhangting2020 authored Jun 22, 2022
1 parent 8e6a194 commit 0660d5f
Showing 9 changed files with 487 additions and 334 deletions.
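Before the per-file diffs, the gist of the residual change: both fused ops gain an add_residual attribute (default true) that decides whether the input is added back after the final dropout. A minimal NumPy sketch of that epilogue, for illustration only (the function and argument names below are not Paddle API; inverted-dropout scaling is assumed):

    # Illustrative sketch (not Paddle API): the epilogue of both fused ops,
    # with the new `add_residual` attribute controlling the residual branch.
    import numpy as np

    def fused_epilogue(linear_out, residual, dropout_prob=0.1, add_residual=True):
        # inverted dropout on the linear output (applied in both branches)
        mask = (np.random.rand(*linear_out.shape) >= dropout_prob).astype(linear_out.dtype)
        dropped = linear_out * mask / (1.0 - dropout_prob)
        # add_residual=True (the default) keeps the old behaviour;
        # add_residual=False skips the residual connection entirely
        return residual + dropped if add_residual else dropped

    x = np.random.randn(2, 8, 16).astype("float32")
    proj = np.random.randn(2, 8, 16).astype("float32")
    print(fused_epilogue(proj, x).shape)                      # (2, 8, 16)
    print(fused_epilogue(proj, x, add_residual=False).shape)  # (2, 8, 16)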
35 changes: 25 additions & 10 deletions paddle/fluid/operators/fused/fused_attention_op.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -372,19 +373,22 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
                             "0.0 and 0.001, But received [%s].",
                             ln_epsilon));
        });
+    AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
     AddAttr<int>(
         "ring_id",
         "ring id for tensor model parallel. distributed training and inference")
         .SetDefault(-1);
 
     AddComment(R"DOC(
-  Add fused attention op whose logic is as follows:
-  // @input: [batch_size, seq_len, 3, num_head, head_dim]
+  The fused_attention operator is the same as following pseudo codes:
+  // @input: [batch_size, seq_len, embed_dim]
   // @final_out: [batch_size, seq_len, num_heads, head_dim]
+  residual = input
   if (pre_layernorm)
-    out = layer_norm(input);
-  out = compute_qkv(out) + bias;
-  // fmha module
+    query = layer_norm(input);
+  out = compute_qkv(query) + qkv_bias;
+  // fmha module
   {
     out = transpose(out, perm=[2, 0, 3, 1, 4]);
     out = q * k^t;
@@ -395,11 +399,14 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     out = transpose(out, perm=[0, 2, 1, 3]);
   }
-  out = out_linear(out);
-  if (pre_layernorm)
-    final_out = residual + dropout(bias + out);
-  else
-    final_out = layer_norm(residual + dropout(bias + out));
+  // out linear
+  out = linear(out);
+  if add_residual:
+    out = residual + dropout(out);
+  else:
+    out = dropout(out);
+  if (!pre_layernorm)
+    out = layer_norm(out);
   )DOC");
   }
 };
@@ -649,3 +656,11 @@ REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
                   ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                   ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
+
+REGISTER_OP_VERSION(fused_attention)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
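Read as a whole, the updated DOC pseudocode corresponds to the following NumPy sketch of the forward pass (illustration only: the attention mask, the qkv/linear biases and both dropouts are omitted, and the weight shapes are assumptions, not the kernel's layout):

    # Simplified NumPy sketch of the pseudocode in FusedAttentionOpMaker.
    import numpy as np

    def layer_norm(x, eps=1e-5):
        mu = x.mean(-1, keepdims=True)
        var = x.var(-1, keepdims=True)
        return (x - mu) / np.sqrt(var + eps)

    def fused_attention(x, w_qkv, w_out, num_heads,
                        pre_layernorm=True, add_residual=True):
        bsz, seq, embed = x.shape
        head_dim = embed // num_heads
        residual = x
        query_in = layer_norm(x) if pre_layernorm else x
        qkv = query_in @ w_qkv                      # [bsz, seq, 3 * embed]
        qkv = qkv.reshape(bsz, seq, 3, num_heads, head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)          # [3, bsz, heads, seq, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
        attn = np.exp(attn - attn.max(-1, keepdims=True))
        attn /= attn.sum(-1, keepdims=True)         # softmax
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(bsz, seq, embed)
        out = out @ w_out                           # out linear
        out = residual + out if add_residual else out
        return out if pre_layernorm else layer_norm(out)

    x = np.random.randn(2, 4, 8).astype("float32")
    w_qkv = np.random.randn(8, 24).astype("float32")
    w_out = np.random.randn(8, 8).astype("float32")
    print(fused_attention(x, w_qkv, w_out, num_heads=2).shape)  # (2, 4, 8)

Note that, per the kernel change below, the post-layer_norm path (pre_layernorm=False) still requires add_residual=True.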
62 changes: 33 additions & 29 deletions paddle/fluid/operators/fused/fused_attention_op.cu
@@ -245,26 +245,32 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     // tensor model parallel
     AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());
 
+    bool add_residual = ctx.Attr<bool>("add_residual");
+    const T *residual_ptr = add_residual ? x_data : nullptr;
     if (pre_layer_norm) {
       // output = (residual + dropout(input + bias))
       fused_dropout_layernorm_helper.ResidualDropoutBias(
-          ctx.cuda_device_context(), out_linear_out_data, x_data,
+          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
           out_linear_bias_data, final_out_data, dropout_mask_out_data);
     } else {
-      auto *ln_scale_2_data =
-          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
-      auto *ln_bias_2_data =
-          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
-      auto *bias_dropout_residual_out_data =
+      // TODO(Xreki): support post layer_norm case when add_residual is false.
+      PADDLE_ENFORCE_EQ(add_residual, true,
+                        platform::errors::InvalidArgument(
+                            "Attribute add_residual is expected to be true "
+                            "when pre_layer_norm is false."));
+
+      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
+      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
+      T *bias_dropout_residual_out_ptr =
          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
-      auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
-      auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
+      U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
+      U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
       // output = layernorm(residual + dropout(input + bias))
       fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
-          ctx.cuda_device_context(), out_linear_out_data, x_data,
-          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
-          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
-          ln_mean_2_data, ln_var_2_data);
+          ctx.cuda_device_context(), out_linear_out_data, residual_ptr,
+          out_linear_bias_data, ln_scale_2_ptr, ln_bias_2_ptr,
+          bias_dropout_residual_out_ptr, dropout_mask_out_data, final_out_data,
+          ln_mean_2_ptr, ln_var_2_ptr);
     }
   }
 };
@@ -418,16 +424,17 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     int output_size = 3 * hidden_size;
     int input_size = dim_embed;
 
+    bool add_residual = ctx.Attr<bool>("add_residual");
     Tensor d_residual;
-    d_residual.Resize(input_x_dims);
-    T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
+    T *d_residual_data = nullptr;
+    if (add_residual) {
+      d_residual.Resize(input_x_dims);
+      d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
+    }
 
     bool transA = false;
     bool transB = true;
-    bool compute_qkv_bias = true;
-    if (qkv_bias == nullptr) {
-      compute_qkv_bias = false;
-    }
+    bool compute_qkv_bias = qkv_bias ? true : false;
     auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                                epsilon, bsz_seq, dim_embed);
     auto qkv_compute =
@@ -536,17 +543,14 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
       // tensor model parallel
       AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
     }
-    // gradient accumulation
-    std::vector<const Tensor *> ins;
-    std::vector<Tensor *> outs;
-    ins.emplace_back(&d_residual);
-    ins.emplace_back(d_x);
-    outs.emplace_back(d_x);
-    int elewise_add_axis = -1;
-    paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
-                                                   T>(
-        ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
-        AddFunctor<T>());
+
+    if (add_residual) {
+      // gradient accumulation
+      std::vector<const Tensor *> ins = {&d_residual, d_x};
+      std::vector<Tensor *> outs = {d_x};
+      phi::funcs::ElementwiseKernel<T>(ctx.cuda_device_context(), ins, &outs,
+                                       phi::funcs::AddFunctor<T>());
+    }
   }
 };
 
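In the backward pass the residual branch simply forwards the upstream gradient, so when add_residual is true that gradient is accumulated into d_x. A small NumPy sketch of what the ElementwiseKernel/AddFunctor launch above amounts to (names here are illustrative, not the kernel's):

    # Sketch of the backward-path change: the residual gradient is only
    # allocated and accumulated into d_x when add_residual is true.
    import numpy as np

    def accumulate_residual_grad(d_x, d_residual, add_residual=True):
        # mirrors the elementwise add that writes the sum back into d_x
        if add_residual:
            d_x += d_residual
        return d_x

    d_x = np.ones((2, 4, 8), dtype="float32")
    d_residual = np.full((2, 4, 8), 0.5, dtype="float32")
    print(accumulate_residual_grad(d_x, d_residual)[0, 0, 0])  # 1.5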
7 changes: 4 additions & 3 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
@@ -150,9 +150,10 @@ class FusedDropoutHelper {
     LaunchResidualDropoutBiasGrad<T, uint8_t>(
         d_out, mask, dropout_param_.dropout_prob,
         dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
-    auto cuda_place = ctx.GetPlace();
-    memory::Copy(cuda_place, d_residual, cuda_place, d_out,
-                 rows_ * cols_ * sizeof(T), ctx.stream());
+    if (d_residual) {
+      memory::Copy(ctx.GetPlace(), d_residual, ctx.GetPlace(), d_out,
+                   rows_ * cols_ * sizeof(T), ctx.stream());
+    }
   }
 
   // out = dropout(activation(src + bias))
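The helper change above makes the residual-gradient copy conditional: d_residual receives a copy of d_out only when the caller actually allocated a buffer, which is no longer the case when add_residual is false. A rough NumPy sketch of that contract (upscale-in-train dropout assumed, bias reduction simplified; not the real helper signature):

    # Sketch of the ResidualDropoutBiasGrad contract after the change:
    # d_residual gets a copy of d_out only when a buffer was provided.
    import numpy as np

    def residual_dropout_bias_grad(d_out, mask, keep_prob, d_residual=None):
        d_src = d_out * mask / keep_prob          # dropout gradient
        d_bias = d_src.sum(axis=0)                # per-column bias gradient
        if d_residual is not None:                # guarded copy, as in the diff
            d_residual[...] = d_out
        return d_src, d_bias

    d_out = np.random.randn(4, 8).astype("float32")
    mask = (np.random.rand(4, 8) >= 0.1).astype("float32")
    d_src, d_bias = residual_dropout_bias_grad(d_out, mask, keep_prob=0.9)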
39 changes: 28 additions & 11 deletions paddle/fluid/operators/fused/fused_feedforward_op.cc
@@ -193,20 +193,29 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(false);
     AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
     AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
+    AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
     AddAttr<int>("ring_id", "ring id for tensor model parallel.")
         .SetDefault(-1);
     AddComment(R"DOC(
-  the function of fused_feedforward operator is the same as the following pseudo code:
-  residual = src;
-  ln1_out = src;
-  if(pre_layer_norm){
-    ln1_out = layer_norm(src);
-  }
-  out = linear(dropout(activation(dropout(linear(ln1_out)))));
-  if(!pre_layer_norm) {
-    out = layer_norm(out);
-  }
-  )DOC");
+  The fused_feedforward operator is the same as the following pseudo codes:
+  residual = src;
+  if (pre_layer_norm)
+    ln1_out = layer_norm(src);
+  else
+    ln1_out = src;
+  // linear 1
+  out = linear(ln1_out);
+  out = dropout(activation(out));
+  // linear 2
+  out = linear(out);
+  if (add_residual)
+    out = residual + dropout(out);
+  else
+    out = dropout(out);
+  if (!pre_layer_norm)
+    out = layer_norm(out);
+  )DOC");
   }
 };

@@ -366,3 +375,11 @@ REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
                   ops::FusedFeedForwardOpGradMaker<paddle::framework::OpDesc>,
                   ops::FusedFeedForwardOpGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(fused_feedforward_grad, ops::FusedFeedForwardOpGrad);
+
+REGISTER_OP_VERSION(fused_feedforward)
+    .AddCheckpoint(
+        R"ROC(
+              Add a new attribute [add_residual] )ROC",
+        paddle::framework::compatible::OpVersionDesc().NewAttr(
+            "add_residual", "A flag to indicate whether to add residual.",
+            true));
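As with fused_attention, the updated feedforward pseudocode reads as the following NumPy sketch (illustration only: both dropouts omitted, ReLU assumed as the activation, weight shapes assumed):

    # NumPy sketch of the FusedFeedForwardOpMaker pseudocode.
    import numpy as np

    def layer_norm(x, eps=1e-5):
        mu = x.mean(-1, keepdims=True)
        var = x.var(-1, keepdims=True)
        return (x - mu) / np.sqrt(var + eps)

    def fused_feedforward(src, w1, w2, pre_layer_norm=True, add_residual=True):
        residual = src
        ln1_out = layer_norm(src) if pre_layer_norm else src
        out = np.maximum(ln1_out @ w1, 0.0)     # linear 1 + activation
        out = out @ w2                          # linear 2
        out = residual + out if add_residual else out
        return out if pre_layer_norm else layer_norm(out)

    src = np.random.randn(2, 4, 8).astype("float32")
    w1 = np.random.randn(8, 32).astype("float32")
    w2 = np.random.randn(32, 8).astype("float32")
    print(fused_feedforward(src, w1, w2).shape)  # (2, 4, 8)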
(5 more changed files not shown)
