Commit: Fix

co63oc committed Apr 25, 2024
2 parents 746a260 + 0455cd9 commit 9ab9aec
Showing 196 changed files with 5,483 additions and 2,127 deletions.
3 changes: 3 additions & 0 deletions cmake/simd.cmake
@@ -123,6 +123,9 @@ int main()
return 0;
}"
AVX512F_FOUND)
+if(AVX512F_FOUND)
+  add_definitions(-DPADDLE_WITH_AVX512F)
+endif()

set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_RETAINED})
mark_as_advanced(MMX_FOUND SSE2_FOUND SSE3_FOUND AVX_FOUND AVX2_FOUND
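Note: the three added lines define PADDLE_WITH_AVX512F whenever the AVX-512F compile probe succeeds. Below is a minimal sketch of how such a build guard is typically consumed in source; the kernel and its names are illustrative, not part of this commit.

// Illustrative only: a vectorized add guarded by the new build flag.
#include <cstddef>
#ifdef PADDLE_WITH_AVX512F
#include <immintrin.h>
#endif

void AddFloats(const float* a, const float* b, float* out, size_t n) {
#ifdef PADDLE_WITH_AVX512F
  size_t i = 0;
  for (; i + 16 <= n; i += 16) {  // 16 floats per 512-bit register
    __m512 va = _mm512_loadu_ps(a + i);
    __m512 vb = _mm512_loadu_ps(b + i);
    _mm512_storeu_ps(out + i, _mm512_add_ps(va, vb));
  }
  for (; i < n; ++i) out[i] = a[i] + b[i];  // scalar tail
#else
  for (size_t i = 0; i < n; ++i) out[i] = a[i] + b[i];
#endif
}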
@@ -25,6 +25,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"
@@ -260,7 +261,8 @@ class BlockDimExprsAsserter {
input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
return builder_
-        .Build<paddle::dialect::CastOp>(out_shape_value, phi::DataType::INT32)
+        .Build<cinn::dialect::GenerateShapeOp>(
+            input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
}

@@ -275,7 +277,18 @@
auto opt_shape_tensor_from_dim_exprs =
BuildShapeTensorFromDataDimExprs(inputs, output, OpDimExprs4Value);
if (!opt_shape_tensor_from_dim_exprs.has_value()) return;
-    AddAssertEqual(op, opt_shape_tensor_from_dim_exprs.value(), output);
+    pir::Value flatten_output = [&] {
+      const auto& output_dims =
+          output.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
+      if (output_dims.size() > 1) {
+        return builder_
+            .Build<paddle::dialect::FlattenOp>(
+                output, 0, output_dims.size() - 1)
+            .out();
+      }
+      return output;
+    }();
+    AddAssertEqual(op, opt_shape_tensor_from_dim_exprs.value(), flatten_output);
}
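Note: the rank-1 shape tensor built from the data dim exprs is compared elementwise against the op's output, so a rank-N output is flattened to rank 1 first; FlattenOp(output, 0, rank - 1) collapses the inclusive axis range [0, rank - 1]. A standalone sketch of the assumed shape effect:

// Sketch only: flatten(start_axis=0, stop_axis=rank-1) collapses all
// dimensions into one, e.g. {2, 3, 4} -> {24}.
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

std::vector<int64_t> FlattenAllDims(const std::vector<int64_t>& dims) {
  const int64_t numel = std::accumulate(
      dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
  return {numel};  // single collapsed axis
}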

size_t GetNumel(pir::Value value) {
@@ -302,12 +315,28 @@
"received lhs's numel is [%d], rhs's numel is [%d]",
lhs_numel,
rhs_numel));

+    pir::Value rhs_value = [&] {
+      const auto& lhs_dtype =
+          lhs.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
+      const auto& rhs_dtype =
+          rhs.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
+      if (lhs_dtype != rhs_dtype) {
+        return builder_
+            .Build<paddle::dialect::CastOp>(
+                rhs, paddle::dialect::TransToPhiDataType(lhs_dtype))
+            .out();
+      }
+      return rhs;
+    }();

pir::Value lhs_eq_rhs =
-        builder_.Build<paddle::dialect::EqualOp>(lhs, rhs).out();
+        builder_.Build<paddle::dialect::EqualOp>(lhs, rhs_value).out();
pir::Value all_eq =
builder_.Build<paddle::dialect::AllOp>(lhs_eq_rhs).out();
pir::Value assert_data =
-        builder_.Build<pir::CombineOp>(std::vector<pir::Value>{lhs, rhs}).out();
+        builder_.Build<pir::CombineOp>(std::vector<pir::Value>{lhs, rhs_value})
+            .out();
auto assert_op = builder_.Build<paddle::dialect::AssertOp>(
all_eq, assert_data, lhs_numel);
const std::string error_msg = "Check [" + op->name() + "_" +
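Note: the hunk above casts rhs to lhs's dtype before building EqualOp, since elementwise equality assumes matching operand dtypes. A standalone analogy of that cast-before-compare guard, with hypothetical names:

// Analogy only: unify element types before elementwise comparison,
// mirroring the CastOp the hunk inserts on the right-hand side.
#include <cstdint>
#include <vector>

std::vector<bool> ElementwiseEqual(const std::vector<int64_t>& lhs,
                                   const std::vector<int32_t>& rhs) {
  std::vector<bool> eq(lhs.size());
  for (size_t i = 0; i < lhs.size() && i < rhs.size(); ++i) {
    eq[i] = (lhs[i] == static_cast<int64_t>(rhs[i]));  // cast rhs up first
  }
  return eq;
}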
245 changes: 131 additions & 114 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -470,95 +470,122 @@ class ElementwisePowOpPattern
}
};

-static void ReplaceSliceOp(const cinn::dialect::SplitOp &cinn_split,
-                           pir::Operation *slice_op,
-                           pir::PatternRewriter &rewriter) {  // NOLINT
-  const int index = slice_op->dyn_cast<::pir::SliceOp>()
-                        .attribute("index")
-                        .dyn_cast<::pir::Int32Attribute>()
-                        .data();
-  rewriter.ReplaceAllUsesWith(slice_op->result(0), cinn_split.result(index));
-  rewriter.EraseOp(slice_op);
-}
-
-static void ReplaceSplitOp(const cinn::dialect::SplitOp &cinn_split,
-                           pir::Operation *split_op,
-                           pir::PatternRewriter &rewriter) {  // NOLINT
-  const size_t num_results = cinn_split.num_results();
-  CHECK(split_op->num_results() == num_results);
-  for (size_t i = 0; i < num_results; ++i) {
-    rewriter.ReplaceAllUsesWith(split_op->result(i), cinn_split.result(i));
-  }
-  rewriter.EraseOp(split_op);
-}

class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SplitOp>::OpRewritePattern;

bool Match(paddle::dialect::SplitOp op) const override {
-    const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
auto sections_gen_op = op->operand_source(1)
.defining_op()
->dyn_cast<paddle::dialect::FullIntArrayOp>();
auto axis_gen_op = op->operand_source(2)
.defining_op()
->dyn_cast<paddle::dialect::FullOp>();
-    return !is_denied && sections_gen_op && axis_gen_op;
+    return sections_gen_op && axis_gen_op;
}

void Rewrite(paddle::dialect::SplitOp op,
pir::PatternRewriter &rewriter) const override {
-    const std::vector<int> sections = [&]() -> std::vector<int> {
-      std::vector<int> result;
-      auto sections_gen_op = op->operand_source(1)
-                                 .defining_op()
-                                 ->dyn_cast<paddle::dialect::FullIntArrayOp>();
-      auto section_attr = sections_gen_op.attribute("value")
-                              .dyn_cast<pir::ArrayAttribute>()
-                              .AsVector();
-      if (section_attr.size() > 0) {
-        for (size_t i = 0; i < section_attr.size(); ++i) {
-          result.push_back(
-              section_attr[i].dyn_cast<::pir::Int64Attribute>().data());
-        }
-      }
-      return result;
-    }();
-
-    const int axis = [&]() -> int {
-      auto axis_gen_op = op->operand_source(2)
-                             .defining_op()
-                             ->dyn_cast<paddle::dialect::FullOp>();
-      int axis = static_cast<int>(axis_gen_op.attribute("value")
-                                      .dyn_cast<::pir::FloatAttribute>()
-                                      .data());
-      auto input_ele = op->operand_source(0)
-                           .type()
-                           .dyn_cast<paddle::dialect::DenseTensorType>();
-      if (axis < 0) {
-        axis += input_ele.dims().size();
-      }
-      return axis;
-    }();
-
-    auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
-        op->operand_source(0), sections, axis);

-    auto orig_out = op.result(0);
-    for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
+    for (auto it = op.out().use_begin(); it != op.out().use_end();) {
auto downstream_op = (it++)->owner();
if (downstream_op->isa<::pir::SliceOp>()) {
-        ReplaceSliceOp(cinn_split, downstream_op, rewriter);
+        ReplaceSplitSliceBySlice(
+            op, downstream_op->dyn_cast<::pir::SliceOp>(), rewriter);
} else if (downstream_op->isa<::pir::SplitOp>()) {
-        ReplaceSplitOp(cinn_split, downstream_op, rewriter);
+        ReplaceSplitSplitBySlice(
+            op, downstream_op->dyn_cast<::pir::SplitOp>(), rewriter);
} else {
CHECK(false) << "Currently only support pir::slice/split as downstream "
"op, but got: "
<< downstream_op->name();
}
}
-    rewriter.EraseOp(op);
}
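Note: GetAxis below normalizes a negative split axis by adding the input rank. A runnable check of that arithmetic under assumed inputs:

// Worked check of the axis normalization used by GetAxis: for a rank-3
// input, axis = -1 maps to -1 + 3 = 2 (the last dimension).
#include <cassert>

int NormalizeAxis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;
}

int main() {
  assert(NormalizeAxis(-1, 3) == 2);
  assert(NormalizeAxis(1, 3) == 1);
  return 0;
}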

+ private:
+  int GetAxis(paddle::dialect::SplitOp op) const {
+    auto axis_gen_op = op->operand_source(2).defining_op();
+    auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
+    int axis = static_cast<int>(
+        full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());
+    if (axis < 0) {
+      axis += op.x()
+                  .type()
+                  .dyn_cast<paddle::dialect::DenseTensorType>()
+                  .dims()
+                  .size();
+    }
+    return axis;
+  }
+
+  std::vector<int64_t> GetSections(paddle::dialect::SplitOp op) const {
+    std::vector<int64_t> result;
+    auto sections_gen_op = op->operand_source(1)
+                               .defining_op()
+                               ->dyn_cast<paddle::dialect::FullIntArrayOp>();
+    auto section_attr =
+        sections_gen_op.attribute<pir::ArrayAttribute>("value").AsVector();
+    if (section_attr.size() > 0) {
+      for (size_t i = 0; i < section_attr.size(); ++i) {
+        result.push_back(
+            section_attr[i].dyn_cast<::pir::Int64Attribute>().data());
+      }
+    }
+    return result;
+  }
+
+  void ReplaceSplitSliceBySlice(
+      paddle::dialect::SplitOp split,
+      ::pir::SliceOp slice,
+      pir::PatternRewriter &rewriter) const {  // NOLINT
+    const int axis = GetAxis(split);
+    const std::vector<int64_t> &sections = GetSections(split);
+    for (auto section : sections) {
+      VLOG(0) << " " << section;
+    }
+    const int index = slice->attribute<::pir::Int32Attribute>("index").data();
+    int64_t start =
+        std::accumulate(sections.begin(), sections.begin() + index, 0);
+    int64_t end = start + sections[index];
+    auto paddle_slice =
+        rewriter.Build<paddle::dialect::SliceOp>(split.x(),
+                                                 std::vector<int64_t>({axis}),
+                                                 std::vector<int64_t>({start}),
+                                                 std::vector<int64_t>({end}),
+                                                 std::vector<int64_t>({}),
+                                                 std::vector<int64_t>({}));
+
+    rewriter.ReplaceAllUsesWith(slice->result(0), paddle_slice.result(0));
+    rewriter.EraseOp(slice);
+    if (split->use_empty()) {
+      rewriter.EraseOp(split);
+    }
+  }
+
+  void ReplaceSplitSplitBySlice(
+      paddle::dialect::SplitOp split,
+      ::pir::SplitOp pir_split,
+      pir::PatternRewriter &rewriter) const {  // NOLINT
+    const int axis = GetAxis(split);
+    const std::vector<int64_t> &sections = GetSections(split);
+    int64_t start = 0, end = 0;
+    for (size_t i = 0; i < pir_split->num_results(); ++i) {
+      start = end;
+      end += sections.at(i);
+      auto paddle_slice = rewriter.Build<paddle::dialect::SliceOp>(
+          split.x(),
+          std::vector<int64_t>({axis}),
+          std::vector<int64_t>({start}),
+          std::vector<int64_t>({end}),
+          std::vector<int64_t>({}),
+          std::vector<int64_t>({}));
+      rewriter.ReplaceAllUsesWith(pir_split->result(i),
+                                  paddle_slice->result(0));
+    }
+    rewriter.EraseOp(pir_split);
+    if (split->use_empty()) {
+      rewriter.EraseOp(split);
+    }
+  }
};
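Note: ReplaceSplitSliceBySlice derives the slice window for section `index` by summing the preceding section sizes; with sections = {3, 4, 5} and index = 2, start = 3 + 4 = 7 and end = 12. The committed code seeds std::accumulate with the int literal 0, so the running sum is an int; the standalone sketch below seeds with int64_t{0}, which avoids narrowing for very large dimensions.

// Sketch only: compute the [start, end) window along the split axis for a
// given section index, as the pattern above does before building SliceOp.
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> SliceWindow(const std::vector<int64_t>& sections,
                                        int index) {
  const int64_t start = std::accumulate(
      sections.begin(), sections.begin() + index, int64_t{0});
  return {start, start + sections[index]};
}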

@@ -569,62 +596,51 @@ class SplitWithNumOpPattern
paddle::dialect::SplitWithNumOp>::OpRewritePattern;

bool Match(paddle::dialect::SplitWithNumOp op) const override {
-    const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation());
auto axis_gen_op = op->operand_source(1).defining_op();
-    auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
-    return !is_denied && full_op;
+    return axis_gen_op->isa<paddle::dialect::FullOp>();
}

void Rewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
-    const auto input_ele = op->operand_source(0)
-                               .type()
-                               .dyn_cast<paddle::dialect::DenseTensorType>();
-
-    const int axis = [&]() -> int {
-      auto axis_gen_op = op->operand_source(1).defining_op();
-      auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
-      int axis = static_cast<int>(
-          full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data());
-      if (axis < 0) {
-        axis += input_ele.dims().size();
-      }
-      return axis;
-    }();
-
-    const auto sections = [&]() -> std::vector<int> {
-      std::vector<int> result;
-      auto split_dim = input_ele.dims()[axis];
-      auto split_num =
-          op->attribute("num").dyn_cast<::pir::Int32Attribute>().data();
-      auto part_ele = (split_dim + split_num - 1) / split_num;
-      int total_split_num = 0;
-      for (int i = 0; i < split_num - 1; ++i) {
-        result.push_back(part_ele);
-        total_split_num += part_ele;
-      }
-
-      result.push_back(split_dim - total_split_num);
-      return result;
-    }();
+    const int axis = GetAxis(op);
+    const std::vector<int64_t> &sections = GetSections(op, axis);
+    auto split_op =
+        rewriter.Build<paddle::dialect::SplitOp>(op.x(), sections, axis);
+    rewriter.ReplaceAllUsesWith(op.out(), split_op.out());
+    rewriter.EraseOp(op);
}

-    auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
-        op->operand_source(0), sections, axis);
+ protected:
+  int GetAxis(paddle::dialect::SplitWithNumOp op) const {
+    auto axis_gen_op = op->operand_source(1).defining_op();
+    auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>();
+    int axis = static_cast<int>(
+        full_op.attribute<::pir::FloatAttribute>("value").data());
+    if (axis < 0) {
+      axis += op.x()
+                  .type()
+                  .dyn_cast<paddle::dialect::DenseTensorType>()
+                  .dims()
+                  .size();
+    }
+    return axis;
+  }

-    auto orig_out = op.result(0);
-    for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
-      auto downstream_op = (it++)->owner();
-      if (downstream_op->isa<::pir::SliceOp>()) {
-        ReplaceSliceOp(cinn_split, downstream_op, rewriter);
-      } else if (downstream_op->isa<::pir::SplitOp>()) {
-        ReplaceSplitOp(cinn_split, downstream_op, rewriter);
-      } else {
-        CHECK(false) << "Currently only support pir::slice/split as downstream "
-                        "op, but got: "
-                     << downstream_op->name();
-      }
+  std::vector<int64_t> GetSections(paddle::dialect::SplitWithNumOp op,
+                                   int axis) const {
+    std::vector<int64_t> result;
+    auto split_dim =
+        op.x().type().dyn_cast<paddle::dialect::DenseTensorType>().dims()[axis];
+    auto split_num = op->attribute<::pir::Int32Attribute>("num").data();
+    auto part_ele = (split_dim + split_num - 1) / split_num;
+    int total_split_num = 0;
+    for (int i = 0; i < split_num - 1; ++i) {
+      result.push_back(part_ele);
+      total_split_num += part_ele;
+    }
-    rewriter.EraseOp(op);
+
+    result.push_back(split_dim - total_split_num);
+    return result;
+  }
};
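Note: GetSections splits dimension split_dim into num parts by ceiling division and gives the remainder to the last part; for example, split_dim = 10 and num = 3 gives part_ele = (10 + 3 - 1) / 3 = 4 and sections {4, 4, 2}. A runnable standalone check of that arithmetic:

// Sketch only: mirrors the GetSections arithmetic above; ceil-divide,
// then let the last section absorb the remainder.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> SectionsFor(int64_t split_dim, int num) {
  std::vector<int64_t> sections;
  const int64_t part_ele = (split_dim + num - 1) / num;  // ceiling division
  int64_t taken = 0;
  for (int i = 0; i + 1 < num; ++i) {
    sections.push_back(part_ele);
    taken += part_ele;
  }
  sections.push_back(split_dim - taken);  // remainder
  return sections;
}

int main() {
  const auto s = SectionsFor(10, 3);
  assert(s.size() == 3 && s[0] == 4 && s[1] == 4 && s[2] == 2);
  return 0;
}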

@@ -985,7 +1001,8 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<ConcatOpPattern>(context);
ps.Add<SliceOpPattern>(context);
ps.Add<AddNOpPattern>(context);
-  // ps.Add<SplitWithNumOpPattern>(context);
+  ps.Add<SplitWithNumOpPattern>(context);
+  ps.Add<SplitOpPattern>(context);
ps.Add<ExpandOpPattern>(context);
ps.Add<IsCloseOpPattern>(context);
ps.Add<ElementwisePowOpPattern>(context);