Add fused adaln scale residual xpu pass
sungenglab committed Feb 25, 2025
1 parent 6a6f084 commit fac8844
Showing 11 changed files with 394 additions and 4 deletions.
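For reference, here is a minimal scalar sketch of the per-row semantics the fused kernel implements, reconstructed from the DRR source pattern in this commit. The function and parameter names are illustrative only, and it assumes bias_after_scale == true:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference-only model of fused_adaLN_scale_residual_xpu_kernel on one
// normalized row (everything after begin_norm_axis flattened to n).
std::vector<float> adaln_scale_residual_ref(
    const std::vector<float> &x1, const std::vector<float> &x2,
    const std::vector<float> &u1, const std::vector<float> &u2,
    const std::vector<float> &u3, const std::vector<float> &w,
    const std::vector<float> &b, float epsilon, float scale_weight,
    float scale_bias) {
  const std::size_t n = x1.size();
  std::vector<float> t(n), out(n);
  // 1) scaled residual: t = x1 * u1 + x2
  for (std::size_t i = 0; i < n; ++i) t[i] = x1[i] * u1[i] + x2[i];
  // 2) layer_norm over the row
  float mean = 0.f, var = 0.f;
  for (float v : t) mean += v;
  mean /= n;
  for (float v : t) var += (v - mean) * (v - mean);
  var /= n;
  const float inv_std = 1.f / std::sqrt(var + epsilon);
  // 3) adaLN modulation plus shift:
  //    out = ln(t) * (u2 * scale_weight + scale_bias) + u3
  for (std::size_t i = 0; i < n; ++i) {
    const float ln = (t[i] - mean) * inv_std * w[i] + b[i];
    out[i] = ln * (u2[i] * scale_weight + scale_bias) + u3[i];
  }
  return out;
}
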
8 changes: 5 additions & 3 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -502,6 +502,8 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"map_op_to_another_pass",
// "quant_dequant_xpu_pass", open this pass when use old int8 model
"delete_quant_dequant_linear_op_pass",
"adaptive_layernorm_xpu_fuse_pass",
"fused_adaLN_scale_residual_xpu_pass",
"delete_weight_dequant_linear_op_pass",
"delete_assign_op_pass",
"delete_dropout_op_pass",
@@ -560,7 +562,6 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_layernorm_xpu_fuse_pass",
"layer_norm_act_xpu_fuse_pass",
"fast_layernorm_xpu_fuse_pass",
"adaptive_layernorm_xpu_fuse_pass",
"bn_act_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
@@ -626,8 +627,9 @@ const std::vector<std::string> kPirXpuPasses{
"conv2d_bn_xpu_fuse_pass",
"conv2d_add_xpu_fuse_pass",
"group_norm_silu_fuse_pass",
"fc_xpu_fuse_pass",
"adaptive_layernorm_xpu_fuse_pass"};
"fused_adaLN_scale_residual_xpu_pass",
"adaptive_layernorm_xpu_fuse_pass",
"fc_xpu_fuse_pass"};

const std::vector<std::string> kPirMkldnnPasses {
"add_shadow_output_after_dead_parameter_pass",
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
@@ -102,6 +102,7 @@ USE_PIR_PASS(elementwise_mul_add_xpu_fuse_pass);
USE_PIR_PASS(conv2d_bn_xpu_fuse_pass);
USE_PIR_PASS(conv2d_add_xpu_fuse_pass);
USE_PIR_PASS(fc_xpu_fuse_pass);
USE_PIR_PASS(fused_adaLN_scale_residual_xpu_pass);
USE_PIR_PASS(adaptive_layernorm_xpu_fuse_pass);
#endif

173 changes: 173 additions & 0 deletions paddle/fluid/pir/transforms/xpu/fused_adaLN_scale_residual_xpu_pass.cc
@@ -0,0 +1,173 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/xpu/fused_adaLN_scale_residual_xpu_pass.h"
#include <optional>

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/utils/general_functions.h"

#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"

/*
Fuse adaptive_layer_norm + scale + residual into the fused_adaLN_scale_residual_xpu_kernel op.
For example:
graph:
ele_x
| ele_y
| /
elementwise_mul
|
| ele_z
| /
elementwise_add
|
| ele_u
| /
adaptive_layer_norm
|
|
Output
------------------------------------------------------
After the pass is applied:
ele_x
| ele_y
ele_u | /
\ | /
fused_adaLN_scale_residual_xpu_kernel ---- ele_z
|
|
|
Output
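------------------------------------------------------
In terms of the pattern tensor names below, the fused op computes
(a sketch; assumes bias_after_scale == true):
t = x1 * unsqueezed1 + x2
t = layer_norm(t, w, b) // with epsilon, begin_norm_axis
out = t * (unsqueezed2 * scale_weight + scale_bias) + unsqueezed3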
*/

namespace {

class FusedAdalnScaleResidualPattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override { return "FusedAdalnScaleResidualPattern"; }

// Build the source pattern, its constraints, and the result pattern.
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();

// Patterns
const auto &multiply = pat.Op(paddle::dialect::MultiplyOp::name());
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &layernorm =
pat.Op(paddle::dialect::LayerNormOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"begin_norm_axis", pat.Attr("begin_norm_axis")}});
const auto &scale =
pat.Op(paddle::dialect::ScaleOp::name(),
{{"bias", pat.Attr("scale_bias")},
{"bias_after_scale", pat.Attr("bias_after_scale")}});
const auto &full = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("full_shape")},
{"value", pat.Attr("full_value")},
{"dtype", pat.Attr("full_dtype")},
{"place", pat.Attr("full_place")}});
// Pattern topology: mul -> add -> layer_norm -> mul -> add, with the
// second multiply consuming the scaled modulation tensor. Each
// intermediate value needs its own tensor name to keep the pattern SSA.
multiply({&pat.Tensor("x1"), &pat.Tensor("unsqueezed1")},
{&pat.Tensor("mul_out_1")});
add({&pat.Tensor("mul_out_1"), &pat.Tensor("x2")},
{&pat.Tensor("add_out_1")});
layernorm({&pat.Tensor("add_out_1"), &pat.Tensor("w"), &pat.Tensor("b")},
{&pat.Tensor("layernorm_out"),
&pat.Tensor("mean_out_0"),
&pat.Tensor("variance_out_0")});
scale({&pat.Tensor("unsqueezed2"), &full()}, {&pat.Tensor("scale_out")});
multiply({&pat.Tensor("layernorm_out"), &pat.Tensor("scale_out")},
{&pat.Tensor("mul_out_2")});
add({&pat.Tensor("mul_out_2"), &pat.Tensor("unsqueezed3")},
{&pat.Tensor("final_output")});

// Constraints
pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x1"));
auto add_in_shape = pir::GetShapeFromValue(match_ctx.Tensor("x2"));
auto unsqueezed1_shape =
pir::GetShapeFromValue(match_ctx.Tensor("unsqueezed1"));
auto unsqueezed2_shape =
pir::GetShapeFromValue(match_ctx.Tensor("unsqueezed2"));
auto unsqueezed3_shape =
pir::GetShapeFromValue(match_ctx.Tensor("unsqueezed3"));

// TODO: tighten matching, e.g. require equal ranks:
// return (x_shape.size() == add_in_shape.size()) &&
// (add_in_shape.size() == unsqueezed1_shape.size());
// For now every structural match is accepted; the shapes fetched above
// are kept for that pending check.
return true;
});

// result pattern
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &scale_weight = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> float {
return match_ctx.Attr<double>("full_value");
});

const auto &fused_adaLN_scale_residual_xpu_kernel =
res.Op(paddle::dialect::FusedAdalnScaleResidualXpuKernelOp::name(),
{{
{"begin_norm_axis", pat.Attr("begin_norm_axis")},
{"epsilon", pat.Attr("epsilon")},
{"scale_weight", scale_weight},
{"scale_bias", pat.Attr("scale_bias")},
{"bias_after_scale", pat.Attr("bias_after_scale")},
}});
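// Kernel inputs in signature order: x1, x2, the three unsqueezed
// modulation tensors, then the layer_norm weight and bias.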
fused_adaLN_scale_residual_xpu_kernel(
{
&res.Tensor("x1"),
&res.Tensor("x2"),
&res.Tensor("unsqueeze1"),
&res.Tensor("unsqueeze2"),
&res.Tensor("unsqueeze3"),
&res.Tensor("w"),
&res.Tensor("b"),
},
{&res.Tensor("final_output")});
}
};

class FusedAdalnScaleResidualXpuPass : public pir::PatternRewritePass {
public:
FusedAdalnScaleResidualXpuPass()
: pir::PatternRewritePass("fused_adaLN_scale_residual_xpu_pass", 2) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(paddle::drr::Create<FusedAdalnScaleResidualPattern>(context));
return ps;
}
};
} // namespace

namespace pir {
std::unique_ptr<Pass> CreateFusedAdalnScaleResidualXpuPass() {
return std::make_unique<FusedAdalnScaleResidualXpuPass>();
}
} // namespace pir

REGISTER_IR_PASS(fused_adaLN_scale_residual_xpu_pass,
FusedAdalnScaleResidualXpuPass);
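
For illustration, a minimal sketch of running the registered pass standalone, assuming the usual pir::PassManager API (this snippet is not part of the commit):

pir::IrContext *ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
pir::PassManager pm(ctx);
pm.AddPass(pir::CreateFusedAdalnScaleResidualXpuPass());
pm.Run(&program);  // `program` is the pir::Program holding the graph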
25 changes: 25 additions & 0 deletions paddle/fluid/pir/transforms/xpu/fused_adaLN_scale_residual_xpu_pass.h
@@ -0,0 +1,25 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/include/core/dll_decl.h"
namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateFusedAdalnScaleResidualXpuPass();

} // namespace pir
3 changes: 2 additions & 1 deletion paddle/phi/backends/xpu/xpu1_op_list.cc
@@ -377,7 +377,8 @@ XPUOpMap& get_kl1_ops() {
XPUKernelSet({phi::DataType::INT32,
phi::DataType::INT64,
phi::DataType::FLOAT32})},
{"fast_where_xpu",
{"fast_where_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
{"fused_adaLN_scale_residual_xpu_kernel",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64})},
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -1340,6 +1340,10 @@ XPUOpMap& get_kl2_ops() {
{"block_multihead_attention_xpu", XPUKernelSet({phi::DataType::FLOAT16})},
{"blha_get_max_len",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"fused_adaLN_scale_residual_xpu_kernel",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
};

return s_xpu2_kernels;
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -1616,6 +1616,10 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT64})},
{"blha_get_max_len",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"fused_adaLN_scale_residual_xpu_kernel",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
};

return s_xpu3_kernels;
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -6061,4 +6061,21 @@ void ResnetBasicBlockGradInferMeta(const MetaTensor& x,
}
}

void FusedAdalnScaleResidualInferMeta(const MetaTensor& input1,
const MetaTensor& input2,
const MetaTensor& unsqueeze1,
const MetaTensor& unsqueeze2,
const MetaTensor& unsqueeze3,
const MetaTensor& ln_weight,
const MetaTensor& ln_bias,
const int begin_norm_axis,
const float epsilon,
const float scale_op_weight,
const float scale_op_bias,
const bool bias_after_scale,
MetaTensor* out) {
out->set_dims(input1.dims());
out->set_dtype(input1.dtype());
out->set_layout(input1.layout());
}
} // namespace phi
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/fusion.h
@@ -1471,4 +1471,17 @@ void ResnetBasicBlockGradInferMeta(const MetaTensor& x,
MetaTensor* bias3_grad,
MetaConfig config = MetaConfig());

void FusedAdalnScaleResidualInferMeta(const MetaTensor& input1,
const MetaTensor& input2,
const MetaTensor& unsqueeze1,
const MetaTensor& unsqueeze2,
const MetaTensor& unsqueeze3,
const MetaTensor& ln_weight,
const MetaTensor& ln_bias,
const int begin_norm_axis,
const float epsilon,
const float scale_op_weight,
const float scale_op_bias,
const bool bias_after_scale,
MetaTensor* out);
} // namespace phi
