Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】Add utils and simplify corner case #71278

Merged
merged 3 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -812,5 +812,43 @@ std::optional<ir::IndexExpr> SimplifyComplexMod(const ir::IndexExpr &lhs,
}
return std::nullopt;
}

bool CheckPattern(const ir::IndexExpr &expr,
const ir::IndexExpr &pattern,
std::unordered_map<std::string, ir::IndexExpr> *map) {
// pattern may include Var to match any expr.
if (expr.node_type() != pattern.node_type() &&
pattern.node_type() != ir::IrNodeTy::_Var_)
return false;
switch (pattern.node_type()) {
case ir::IrNodeTy::Add:
case ir::IrNodeTy::Sub:
case ir::IrNodeTy::Mul:
case ir::IrNodeTy::Div:
case ir::IrNodeTy::Mod:
case ir::IrNodeTy::Min:
case ir::IrNodeTy::Max: {
return CheckPattern(expr.operand(0), pattern.operand(0), map) &&
CheckPattern(expr.operand(1), pattern.operand(1), map);
}
case ir::IrNodeTy::_Var_: {
auto it = map->find(pattern.As<ir::_Var_>()->name);
if (it != map->end()) {
return expr == it->second;
} else {
map->insert(std::make_pair(pattern.As<ir::_Var_>()->name, expr));
return true;
}
}
case ir::IrNodeTy::IntImm: {
return expr.As<ir::IntImm>()->value == pattern.As<ir::IntImm>()->value;
}
default:
PADDLE_THROW(::common::errors::InvalidArgument(
"Unsupported type of expr in CheckPattern which is: %s", expr));
}

return false;
}
} // namespace common
} // namespace cinn
18 changes: 18 additions & 0 deletions paddle/cinn/common/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,5 +414,23 @@ std::optional<ir::IndexExpr> DivByPartMul(const ir::IndexExpr &lhs,
*/
std::optional<ir::IndexExpr> SimplifyComplexMod(const ir::IndexExpr &lhs,
const ir::IndexExpr &rhs);

/*!
* \brief Check whether the expression matches the pattern.
* \param expr The expression to be checked.
* \param pattern The pattern to be matched. which includes some variables.
* \param map return the matched variables.
* \return A boolean value indicating whether `expr` is matched.
*
* For example:
* 1. (i / S0 * S0 + i % (S0 * S1)) % S0 matched by a / b * b + a % (b * c)
* with map = {a: i, b: S0, c: S1}
* 2. S0 + 5 matched by a + 5 with map = {a: S0, b: 5}
*
* Note: a * b and b * a is two different pattern.
*/
bool CheckPattern(const ir::IndexExpr &expr,
const ir::IndexExpr &pattern,
std::unordered_map<std::string, ir::IndexExpr> *map);
} // namespace common
} // namespace cinn
34 changes: 34 additions & 0 deletions paddle/cinn/common/simplify_special_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,39 @@ std::optional<ir::IndexExpr> AddMulCornerCase(
return res;
}

// S0 / (S1 * S2) * S2 + S0 % (S1 * S2) / S1 ===> S0 / S1
std::optional<ir::IndexExpr> DivMulAddModDivCase(const ir::IndexExpr& lhs,
const ir::IndexExpr& rhs) {
ir::Var a = ir::Var("a");
ir::Var b = ir::Var("b");
ir::Var c = ir::Var("c");
ir::Var f = ir::Var("f");
std::unordered_map<std::string, ir::IndexExpr> map;

ir::IndexExpr pattern = f / c * a + f % c / b;

auto flatten = GetFlattenExprs<ir::Add>(lhs);
ir::IndexExpr res = ir::IndexExpr(rhs->type(), 0);
bool find = false;
for (const auto& expr : flatten) {
if (!find) {
ir::IndexExpr cand = ir::Add::Make(expr, rhs);
map.clear();
// Check if the pattern is matched
if (CheckPattern(cand, pattern, &map) &&
map.at("c") == map.at("a") * map.at("b")) {
ir::IndexExpr simplied = map.at("f") / map.at("b");
res = res.defined() ? res + simplied : simplied;
find = true;
continue;
}
}
res = res.defined() ? ir::Add::Make(res, expr) : expr;
}
if (find) return res;
return std::nullopt;
}

// (S0 + S1 - (S0 + S1) % S2) % S2 == 0
// (S0 + S1 - (S0 + S1) % S2) / S2 == (S0 + S1) / S2
std::optional<ir::IndexExpr> SubModCornerCase(const ir::IndexExpr& lhs,
Expand Down Expand Up @@ -322,6 +355,7 @@ std::optional<ir::IndexExpr> SimplifyAddCornerCase(const ir::IndexExpr& lhs,
const ir::IndexExpr& rhs) {
if (auto res = DivMulAddModCornerCase(lhs, rhs)) return res.value();
if (auto res = AddMulCornerCase(lhs, rhs)) return res.value();
if (auto res = DivMulAddModDivCase(lhs, rhs)) return res.value();
// Add other corner cases
return std::nullopt;
}
Expand Down
38 changes: 38 additions & 0 deletions test/cpp/pir/cinn/adt/index_expr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
ir::Expr q16 =
((S4 * 256 + S5) / S6 / S7 * S7 + (S4 * 256 + S5) / S6 % S7) * S6 +
(S4 * 256 + S5) % S6;
ir::Expr q17 = S4 / (S5 * S6) * S6 + S4 % (S5 * S6) / S5;
ir::Expr q18 = (S4 * 1024 + S5 * 256 + S6) / 2097152 * 32 +
(S4 * 1024 + S5 * 256 + S6) % 2097152 / 65536;

// `Div` corner cases
ir::Expr q6 = (S4 % S5 - S4) / S5;
Expand Down Expand Up @@ -174,6 +177,9 @@ TEST_F(TestIndexExpr, IndexExpr_3) {
ir::IndexExpr((S4 * 256 + S5 + S6 * 1024)) % 25088);
EXPECT_EQ(q16.as_index().Normalize(ir::IndexExpr::OptLevel::Level2),
ir::IndexExpr(S4 * 256 + S5));
EXPECT_EQ(q17.as_index().Normalize(), ir::IndexExpr(S4 / S5));
EXPECT_EQ(q18.as_index().Normalize(),
ir::IndexExpr((S4 * 1024 + S5 * 256 + S6) / 65536));
}

TEST_F(TestIndexExpr, Change_Seq_Of_Div_Mod) {
Expand Down Expand Up @@ -495,5 +501,37 @@ TEST_F(TestIndexExpr, CommonFactor) {
S2) *
(((((S5 + S9) + S21) + S17) + S13) + S1))));
}

TEST_F(TestIndexExpr, TestCheckPattern) {
ir::Var a = ir::Var("a");
ir::Var b = ir::Var("b");
ir::Var f = ir::Var("f");

ir::Var S0 = ir::Var("S0");
ir::Var S1 = ir::Var("S1");
ir::Var S2 = ir::Var("S2");
ir::Var S3 = ir::Var("S3");
ir::Var S4 = ir::Var("S4");
ir::Var S5 = ir::Var("S5");
ir::Var S6 = ir::Var("S6");
ir::Var S7 = ir::Var("S7");
ir::Var S8 = ir::Var("S8");
ir::Var S9 = ir::Var("S9");

ir::IndexExpr pattern = f / (a * b) * b + f % (a * b) / a;
ir::IndexExpr pattern1 = f / (a * b) * a + f % (a * b) / b;
ir::IndexExpr e = (S0 * (S1 + S2) + S1 * S2 + S2) / (S4 * S5) * S5 +
(S0 * (S1 + S2) + S1 * S2 + S2) % (S4 * S5) / S4;
ir::IndexExpr e1 = (S0 * (S1 + S2) + S1 * S2 + S2) / (S4 * S5) * S4 +
(S0 * (S1 + S2) + S1 * S2 + S2) % (S4 * S5) / S5;
std::unordered_map<std::string, ir::IndexExpr> map;
EXPECT_TRUE(common::CheckPattern(e, pattern, &map));
map.clear();
EXPECT_FALSE(common::CheckPattern(e, pattern1, &map));
map.clear();
EXPECT_FALSE(common::CheckPattern(e1, pattern, &map));
map.clear();
EXPECT_TRUE(common::CheckPattern(e1, pattern1, &map));
}
} // namespace common
} // namespace cinn
Loading