Skip to content

Commit

Permalink
【CINN】Add utils and simplify corner case (PaddlePaddle#71278)
Browse files Browse the repository at this point in the history
* add corner case

* polish code

* polish code
  • Loading branch information
liuruyan authored and Enigmatisms committed Mar 5, 2025
1 parent 3a09fde commit 6d1580b
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 0 deletions.
38 changes: 38 additions & 0 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,5 +812,43 @@ std::optional<ir::IndexExpr> SimplifyComplexMod(const ir::IndexExpr &lhs,
}
return std::nullopt;
}

bool CheckPattern(const ir::IndexExpr &expr,
const ir::IndexExpr &pattern,
std::unordered_map<std::string, ir::IndexExpr> *map) {
// pattern may include Var to match any expr.
if (expr.node_type() != pattern.node_type() &&
pattern.node_type() != ir::IrNodeTy::_Var_)
return false;
switch (pattern.node_type()) {
case ir::IrNodeTy::Add:
case ir::IrNodeTy::Sub:
case ir::IrNodeTy::Mul:
case ir::IrNodeTy::Div:
case ir::IrNodeTy::Mod:
case ir::IrNodeTy::Min:
case ir::IrNodeTy::Max: {
return CheckPattern(expr.operand(0), pattern.operand(0), map) &&
CheckPattern(expr.operand(1), pattern.operand(1), map);
}
case ir::IrNodeTy::_Var_: {
auto it = map->find(pattern.As<ir::_Var_>()->name);
if (it != map->end()) {
return expr == it->second;
} else {
map->insert(std::make_pair(pattern.As<ir::_Var_>()->name, expr));
return true;
}
}
case ir::IrNodeTy::IntImm: {
return expr.As<ir::IntImm>()->value == pattern.As<ir::IntImm>()->value;
}
default:
PADDLE_THROW(::common::errors::InvalidArgument(
"Unsupported type of expr in CheckPattern which is: %s", expr));
}

return false;
}
} // namespace common
} // namespace cinn
18 changes: 18 additions & 0 deletions paddle/cinn/common/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,5 +414,23 @@ std::optional<ir::IndexExpr> DivByPartMul(const ir::IndexExpr &lhs,
*/
std::optional<ir::IndexExpr> SimplifyComplexMod(const ir::IndexExpr &lhs,
const ir::IndexExpr &rhs);

/*!
* \brief Check whether the expression matches the pattern.
* \param expr The expression to be checked.
* \param pattern The pattern to be matched. which includes some variables.
* \param map return the matched variables.
* \return A boolean value indicating whether `expr` is matched.
*
* For example:
* 1. (i / S0 * S0 + i % (S0 * S1)) % S0 matched by a / b * b + a % (b * c)
* with map = {a: i, b: S0, c: S1}
* 2. S0 + 5 matched by a + 5 with map = {a: S0, b: 5}
*
* Note: a * b and b * a is two different pattern.
*/
bool CheckPattern(const ir::IndexExpr &expr,
const ir::IndexExpr &pattern,
std::unordered_map<std::string, ir::IndexExpr> *map);
} // namespace common
} // namespace cinn
34 changes: 34 additions & 0 deletions paddle/cinn/common/simplify_special_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,39 @@ std::optional<ir::IndexExpr> AddMulCornerCase(
return res;
}

// S0 / (S1 * S2) * S2 + S0 % (S1 * S2) / S1 ===> S0 / S1
std::optional<ir::IndexExpr> DivMulAddModDivCase(const ir::IndexExpr& lhs,
const ir::IndexExpr& rhs) {
ir::Var a = ir::Var("a");
ir::Var b = ir::Var("b");
ir::Var c = ir::Var("c");
ir::Var f = ir::Var("f");
std::unordered_map<std::string, ir::IndexExpr> map;

ir::IndexExpr pattern = f / c * a + f % c / b;

auto flatten = GetFlattenExprs<ir::Add>(lhs);
ir::IndexExpr res = ir::IndexExpr(rhs->type(), 0);
bool find = false;
for (const auto& expr : flatten) {
if (!find) {
ir::IndexExpr cand = ir::Add::Make(expr, rhs);
map.clear();
// Check if the pattern is matched
if (CheckPattern(cand, pattern, &map) &&
map.at("c") == map.at("a") * map.at("b")) {
ir::IndexExpr simplied = map.at("f") / map.at("b");
res = res.defined() ? res + simplied : simplied;
find = true;
continue;
}
}
res = res.defined() ? ir::Add::Make(res, expr) : expr;
}
if (find) return res;
return std::nullopt;
}

// (S0 + S1 - (S0 + S1) % S2) % S2 == 0
// (S0 + S1 - (S0 + S1) % S2) / S2 == (S0 + S1) / S2
std::optional<ir::IndexExpr> SubModCornerCase(const ir::IndexExpr& lhs,
Expand Down Expand Up @@ -322,6 +355,7 @@ std::optional<ir::IndexExpr> SimplifyAddCornerCase(const ir::IndexExpr& lhs,
const ir::IndexExpr& rhs) {
if (auto res = DivMulAddModCornerCase(lhs, rhs)) return res.value();
if (auto res = AddMulCornerCase(lhs, rhs)) return res.value();
if (auto res = DivMulAddModDivCase(lhs, rhs)) return res.value();
// Add other corner cases
return std::nullopt;
}
Expand Down
38 changes: 38 additions & 0 deletions test/cpp/pir/cinn/adt/index_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
ir::Expr q16 =
((S4 * 256 + S5) / S6 / S7 * S7 + (S4 * 256 + S5) / S6 % S7) * S6 +
(S4 * 256 + S5) % S6;
ir::Expr q17 = S4 / (S5 * S6) * S6 + S4 % (S5 * S6) / S5;
ir::Expr q18 = (S4 * 1024 + S5 * 256 + S6) / 2097152 * 32 +
(S4 * 1024 + S5 * 256 + S6) % 2097152 / 65536;

// `Div` corner cases
ir::Expr q6 = (S4 % S5 - S4) / S5;
Expand Down Expand Up @@ -174,6 +177,9 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
ir::IndexExpr((S4 * 256 + S5 + S6 * 1024)) % 25088);
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
ir::IndexExpr(S4 * 256 + S5));
EXPECT_EQ(q17.as_index().Normalize(), ir::IndexExpr(S4 / S5));
EXPECT_EQ(q18.as_index().Normalize(),
ir::IndexExpr((S4 * 1024 + S5 * 256 + S6) / 65536));
}

TEST_F(TestIndexExpr, Change_Seq_Of_Div_Mod) {
Expand Down Expand Up @@ -495,5 +501,37 @@ TEST_F(TestIndexExpr, CommonFactor) {
S2) *
(((((S5 + S9) + S21) + S17) + S13) + S1))));
}

TEST_F(TestIndexExpr, TestCheckPattern) {
ir::Var a = ir::Var("a");
ir::Var b = ir::Var("b");
ir::Var f = ir::Var("f");

ir::Var S0 = ir::Var("S0");
ir::Var S1 = ir::Var("S1");
ir::Var S2 = ir::Var("S2");
ir::Var S3 = ir::Var("S3");
ir::Var S4 = ir::Var("S4");
ir::Var S5 = ir::Var("S5");
ir::Var S6 = ir::Var("S6");
ir::Var S7 = ir::Var("S7");
ir::Var S8 = ir::Var("S8");
ir::Var S9 = ir::Var("S9");

ir::IndexExpr pattern = f / (a * b) * b + f % (a * b) / a;
ir::IndexExpr pattern1 = f / (a * b) * a + f % (a * b) / b;
ir::IndexExpr e = (S0 * (S1 + S2) + S1 * S2 + S2) / (S4 * S5) * S5 +
(S0 * (S1 + S2) + S1 * S2 + S2) % (S4 * S5) / S4;
ir::IndexExpr e1 = (S0 * (S1 + S2) + S1 * S2 + S2) / (S4 * S5) * S4 +
(S0 * (S1 + S2) + S1 * S2 + S2) % (S4 * S5) / S5;
std::unordered_map<std::string, ir::IndexExpr> map;
EXPECT_TRUE(common::CheckPattern(e, pattern, &map));
map.clear();
EXPECT_FALSE(common::CheckPattern(e, pattern1, &map));
map.clear();
EXPECT_FALSE(common::CheckPattern(e1, pattern, &map));
map.clear();
EXPECT_TRUE(common::CheckPattern(e1, pattern1, &map));
}
} // namespace common
} // namespace cinn

0 comments on commit 6d1580b

Please sign in to comment.