diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h index 0047100ebcfdfc..c333448d029ae0 100644 --- a/paddle/cinn/ir/ir_base.h +++ b/paddle/cinn/ir/ir_base.h @@ -110,23 +110,16 @@ class Dim; macro__(Product) \ macro__(Sum) \ macro__(PrimitiveNode) \ + macro__(IntrinsicOp) \ macro__(_BufferRange_) \ macro__(ScheduleBlock) \ macro__(ScheduleBlockRealize) \ macro__(_Dim_) \ -#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \ - macro__(IntrinsicOp) \ #define NODETY_FORALL(__m) \ NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ NODETY_OP_FOR_EACH(__m) \ - NODETY_CONTROL_OP_FOR_INTRINSIC(__m) \ - NODETY_CONTROL_OP_FOR_EACH(__m) - -#define NODETY_FORALL_EXCEPT_INTRINSIC(__m) \ - NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \ - NODETY_OP_FOR_EACH(__m) \ NODETY_CONTROL_OP_FOR_EACH(__m) // clang-format on diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index e4ebaca653bae9..ac2f0317e9213f 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -15,8 +15,6 @@ #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include -#include "paddle/cinn/ir/intrinsic_ops.h" -#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" @@ -73,71 +71,8 @@ struct IrNodesCollector : public IRVisitorRequireReImpl { } \ } - NODETY_FORALL_EXCEPT_INTRINSIC(__m) + NODETY_FORALL(__m) #undef __m - - void Visit(const ir::IntrinsicOp* op) { - switch (op->getKind()) { -#define __(x) \ - case ir::IntrinsicKind::k##x: \ - Visit(llvm::dyn_cast(op)); \ - break; - - INTRINSIC_KIND_FOR_EACH(__) -#undef __ - } - } - - void Visit(const ir::intrinsics::GetAddr* x) { - if (x->data.defined()) { - Visit(&(x->data)); - } - } - - void Visit(const ir::intrinsics::BufferGetDataHandle* x) { - if (x->buffer.defined()) { - Visit(&(x->buffer)); - } - } - - void Visit(const ir::intrinsics::BufferGetDataConstHandle* x) { - if (x->buffer.defined()) { - Visit(&(x->buffer)); - } - } - - void Visit(const ir::intrinsics::PodValueToX* x) { - if (x->pod_value_ptr.defined()) { - Visit(&(x->pod_value_ptr)); - } - } - - void Visit(const ir::intrinsics::BufferCreate* x) { - if (x->buffer.defined()) { - Visit(&(x->buffer)); - } - } - - void Visit(const ir::intrinsics::ArgsConstruct* x) { - if (x->var.defined()) { - Expr convert = Expr(x->var); - Visit(&convert); - } - for (int i = 0; i < x->args.size(); ++i) { - if (x->args[i].defined()) { - Visit(&(x->args[i])); - } - } - } - - void Visit(const ir::intrinsics::BuiltinIntrin* x) { - for (int i = 0; i < x->args.size(); ++i) { - if (x->args[i].defined()) { - Visit(&(x->args[i])); - } - } - } - std::set visited_; };