Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add OptimizeReductionTactic #60661

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
Expand All @@ -27,8 +28,9 @@ namespace ir {
void DynamicShapeGroupScheduler::Init() {
InitBuckets();
tactics_.emplace_back(new AlignIterSpaceTactic());
tactics_.emplace_back(new TileTactic());
tactics_.emplace_back(new ComputeInlineTactic());
tactics_.emplace_back(new TileTactic());
tactics_.emplace_back(new OptimizeReductionTactic());
tactics_.emplace_back(new BindCudaTactic());
tactics_.emplace_back(new ArrangeStorageTactic());
}
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ core_gather_headers()
gather_srcs(cinnapi_src SRCS align_iter_space_tactic.cc)
gather_srcs(cinnapi_src SRCS tile_tactic.cc)
gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc)
gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc)
gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc)
gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc)
126 changes: 126 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"

namespace cinn {
namespace ir {

void OptimizeReductionTactic::Init(ScheduleContext* context) {
context_ = context;
}

bool CanApply(const std::string& block_name, ir::IRSchedule* sch) {
ir::Expr block_expr = sch->GetBlock(block_name);
ir::ScheduleBlockRealize* block_realize =
block_expr.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
ir::ScheduleBlock* sch_block =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(sch_block);
analyzer::AnalyzeScheduleBlockReadWriteBuffer(sch_block);

// 1. The block must have write buffer
if (sch_block->write_buffers.empty()) {
return false;
}

// 2. The block must have at least one reduce axis
const std::vector<ir::Var>& iter_vars = sch_block->iter_vars;
bool find_reduce_axis = false;
for (int i = 0; i < iter_vars.size(); ++i) {
if (iter_vars[i]->is_reduce_axis) {
find_reduce_axis = true;
break;
}
}
if (!find_reduce_axis) {
return false;
}

// 3. Each loop's body only contains one sub loop or block, except reduce_init
// block
std::vector<ir::Expr> loops = sch->GetLoops(block_name);
for (const ir::Expr& loop : loops) {
const ir::Expr& body = loop.As<ir::For>()->body;
if (body.As<ir::Block>()) {
if (body.As<ir::Block>()->stmts.size() == 1) {
if (body.As<ir::Block>()->stmts[0].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else if (body.As<ir::Block>()->stmts.size() == 2) {
if (body.As<ir::Block>()->stmts[0].As<ir::ScheduleBlockRealize>() ==
nullptr ||
!ir::IsReduceInitTensorName(
analyzer::GetBlockName(body.As<ir::Block>()->stmts[0]))) {
return false;
}
if (body.As<ir::Block>()->stmts[1].As<ir::For>() == nullptr &&
body.As<ir::Block>()->stmts[1].As<ir::ScheduleBlockRealize>() ==
nullptr) {
return false;
}
} else {
return false;
}
} else if (body.As<ir::For>() || body.As<ir::ScheduleBlockRealize>()) {
continue;
} else {
return false;
}
}

return true;
}

void OptimizeReductionTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
if (!CanApply(block_id, sch)) return;

std::vector<ir::Expr> loops = sch->GetLoops(block_id);
int first_reduce_loop_idx = context_->iter_space_info.sp_space.size();
CHECK_LT(first_reduce_loop_idx, loops.size())
<< "first_reduce_loop_idx shoud be less than number of loop.";
// Apply FactorizeReduction
VLOG(6) << "before FactorizeReduction: " << sch->GetModule().GetExprs()[0];
sch->FactorizeReduction(loops[first_reduce_loop_idx], first_reduce_loop_idx);
VLOG(6) << "after FactorizeReduction: " << sch->GetModule().GetExprs()[0];

// Loop fusion and cross thread reduction
std::vector<ir::Expr> rb_loops = sch->GetLoops(block_id);
std::string rf_block_id = block_id + "_rf";
ir::Expr rf_block = sch->GetBlock(rf_block_id);
sch->SimpleComputeAt(rf_block, rb_loops.back());

rb_loops = sch->GetLoops(block_id);
ir::Expr rf_init_block =
sch->GetBlock(ir::GenReduceInitTensorNameOf(rf_block_id));
sch->SimpleComputeAt(rf_init_block, rb_loops.back());

if (context_->target == cinn::common::DefaultNVGPUTarget()) {
rb_loops = sch->GetLoops(block_id);
rf_block = sch->GetBlock(rf_block_id);
sch->Bind(rb_loops.back(), "threadIdx.x");
sch->SetBuffer(rf_block, "shared");
}
VLOG(6) << "Loop fusion and cross thread reduction: "
<< sch->GetModule().GetExprs()[0];
}

} // namespace ir
} // namespace cinn
36 changes: 36 additions & 0 deletions paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
namespace ir {

class OptimizeReductionTactic final : public ScheduleTactic {
public:
void Init(ScheduleContext* context) override;

void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

std::string TacticName() const override { return "OptimizeReductionTactic"; }

private:
ScheduleContext* context_;
};

} // namespace ir
} // namespace cinn
70 changes: 70 additions & 0 deletions paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <utility>
#include <vector>

#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
Expand Down Expand Up @@ -440,6 +442,74 @@ bool IsBroadcastSBlock(ir::Expr block) {
return load->indices.size() < store->indices.size();
}

std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
std::vector<ir::Var> result;
for (const ir::Expr& e : indices) {
if (e.is_constant()) {
std::string var_name =
cinn::UniqName("constant" + static_cast<int>(e.get_constant()));
result.emplace_back(e, e, var_name, /* is_reduce = */ false);
} else if (e.As<ir::_Var_>() != nullptr) {
ir::Expr copy_e = ir::ir_utils::IRCopy(e);
ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
result.emplace_back(ir::Var(var_ref));
} else {
std::string var_name = cinn::UniqName("expr");
common::cas_intervals_t var_intervals;
bool is_reduce = false;
ir::ir_utils::CollectIRNodes(e, [&](const ir::Expr* x) {
if (x->As<ir::_Var_>() != nullptr) {
ir::Var var = x->as_var_ref();
var_intervals.insert(
{var->name,
common::CasInterval{var->lower_bound, var->upper_bound}});
if (var->is_reduce_axis) is_reduce = true;
}
return false;
});
common::SymbolicExprAnalyzer analyzer(var_intervals);
result.emplace_back(
analyzer.LowerBound(e), analyzer.UpperBound(e), var_name, is_reduce);
}
}
return result;
}

void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block) {
if (!sche_block->read_buffers.empty() || !sche_block->write_buffers.empty()) {
return;
}

ir::ir_utils::CollectIRNodesWithoutTensor(
sche_block->body, [&](const Expr* x) {
const ir::Load* load_expr = x->As<ir::Load>();
if (load_expr != nullptr) {
const ir::Tensor t = load_expr->tensor.as_tensor_ref();
sche_block->read_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(load_expr->indices)));
return false;
}
const ir::Store* store_expr = x->As<ir::Store>();
if (store_expr != nullptr) {
const ir::Tensor t = store_expr->tensor.as_tensor_ref();
sche_block->write_buffers.emplace_back(
ir::BufferRange(t->buffer, IndicesToVars(store_expr->indices)));
return false;
}
return false;
});
}

std::string GetBlockName(const ir::Expr block) {
const ir::ScheduleBlockRealize* block_realize =
block.As<ir::ScheduleBlockRealize>();
CHECK_NOTNULL(block_realize);
const ir::ScheduleBlock* block_node =
block_realize->schedule_block.As<ir::ScheduleBlock>();
CHECK_NOTNULL(block_node);
return block_node->name;
}

} // namespace analyzer
} // namespace ir
} // namespace cinn
6 changes: 6 additions & 0 deletions paddle/cinn/ir/ir_analyzer/ir_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ bool IsReductionSBlock(ir::Expr block);

bool IsBroadcastSBlock(ir::Expr block);

std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices);

void AnalyzeScheduleBlockReadWriteBuffer(ir::ScheduleBlock* sche_block);

std::string GetBlockName(const ir::Expr block);

} // namespace analyzer
} // namespace ir
} // namespace cinn