Skip to content

Commit

Permalink
[PIR] remove pir::Value::GetDefiningOp interface
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang committed Sep 18, 2023
1 parent d9cd979 commit ebba76d
Show file tree
Hide file tree
Showing 17 changed files with 57 additions and 72 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
pir::CombineOp combine_op_obj =
op_obj.{input_name}().GetDefiningOp()->dyn_cast<pir::CombineOp>();
op_obj.{input_name}().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
std::vector<Tensor> {input_name};
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
{input_name}.emplace_back(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/phi_kernel_adaptor/phi_kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ paddle::framework::Variable* CreateVar(
variable_2_var_name,
std::map<std::string, int>* var_name_2_id,
std::vector<paddle::framework::Variable*>* variable_list) {
Operation* def_op = value.GetDefiningOp();
Operation* def_op = value.dyn_cast<OpResult>().owner();
bool is_persisable = false;
if (def_op->isa<::pir::SetParameterOp>()) {
is_persisable = true;
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/pir/transforms/inplace_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ static bool CanBeDeleted(pir::Value value) {
!value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
return false;
}
if (value.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
return !(value.GetDefiningOp()
->attribute(kAttrIsPersisable)
.dyn_cast<pir::ArrayAttribute>()
.AsVector()[value.dyn_cast<pir::OpResult>().index()]
.dyn_cast<pir::BoolAttribute>()
.data());
if (auto op_result = value.dyn_cast<pir::OpResult>()) {
auto def_op = op_result.owner();
if (def_op->HasAttribute(kAttrIsPersisable)) {
return !(def_op->attribute<pir::ArrayAttribute>(kAttrIsPersisable)
.AsVector()[op_result.index()]
.dyn_cast<pir::BoolAttribute>()
.data());
}
}
return true;
}
Expand Down
48 changes: 25 additions & 23 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ pir::OpResult AddPlaceTransferOp(pir::OpResult in,
pir::Operation* op =
pir::Operation::Create({in}, op_attribute, {out_type}, op_info);

if (in.GetDefiningOp()->HasAttribute(kAttrIsPersisable)) {
if (in.owner()->HasAttribute(kAttrIsPersisable)) {
op->set_attribute(kAttrIsPersisable,
in.GetDefiningOp()->attribute(kAttrIsPersisable));
in.owner()->attribute(kAttrIsPersisable));
}
block->push_back(op);

Expand Down Expand Up @@ -527,11 +527,10 @@ phi::KernelKey GetKernelKey(
if (op->isa<paddle::dialect::UniformOp>()) {
// try to process uniform, use shape to determine backend
// TODO(phlrain): should support other initialize ops
auto define_op = op->operand_source(0).GetDefiningOp();
auto define_op =
op->operand_source(0).dyn_cast<pir::OpResult>().owner();
if (define_op->isa<paddle::dialect::FullIntArrayOp>()) {
auto shape = define_op->attributes()
.at("value")
.dyn_cast<dialect::IntArrayAttribute>()
auto shape = define_op->attribute<dialect::IntArrayAttribute>("value")
.data()
.GetData();

Expand Down Expand Up @@ -577,13 +576,12 @@ phi::KernelKey GetKernelKey(
// uses data op output as inputs. So, we need to set the kernel backend
// manually.
if (op->operand_source(i)
.GetDefiningOp()
.dyn_cast<pir::OpResult>()
.owner()
->isa<paddle::dialect::DataOp>()) {
auto data_op = op->operand_source(i).GetDefiningOp();
auto data_place = data_op->attributes()
.at("place")
.dyn_cast<dialect::PlaceAttribute>()
.data();
auto data_op = op->operand_source(i).dyn_cast<pir::OpResult>().owner();
auto data_place =
data_op->attribute<dialect::PlaceAttribute>("place").data();

auto data_op_backend = paddle::experimental::ParseBackend(data_place);
if (data_op_backend == phi::Backend::UNDEFINED) {
Expand All @@ -592,17 +590,21 @@ phi::KernelKey GetKernelKey(
kernel_key_parser.key_set.backend_set =
kernel_key_parser.key_set.backend_set |
paddle::experimental::BackendSet(data_op_backend);
} else if (op->operand_source(i).GetDefiningOp()->name() ==
"builtin.combine") {
auto combine_op = op->operand_source(i).GetDefiningOp();
} else if (op->operand_source(i)
.dyn_cast<pir::OpResult>()
.owner()
->isa<pir::CombineOp>()) {
auto combine_op =
op->operand_source(i).dyn_cast<pir::OpResult>().owner();
for (size_t j = 0; j < combine_op->num_operands(); ++j) {
if (combine_op->operand_source(j).GetDefiningOp()->name() ==
"pd_op.data") {
auto data_op = combine_op->operand_source(j).GetDefiningOp();
auto data_place = data_op->attributes()
.at("place")
.dyn_cast<dialect::PlaceAttribute>()
.data();
if (combine_op->operand_source(j)
.dyn_cast<pir::OpResult>()
.owner()
->isa<DataOp>()) {
auto data_op =
combine_op->operand_source(j).dyn_cast<pir::OpResult>().owner();
auto data_place =
data_op->attribute<PlaceAttribute>("place").data();

auto data_op_backend =
paddle::experimental::ParseBackend(data_place);
Expand Down Expand Up @@ -981,7 +983,7 @@ std::vector<pir::Value> BuildOpInputList(
} else if (new_in_type.isa<pir::VectorType>()) {
// [ todo need update here, support combine data transformer]
// deal with pre combine op
auto pre_define_op = cur_in.GetDefiningOp();
auto pre_define_op = cur_in.dyn_cast<pir::OpResult>().owner();

if (pre_define_op->isa<::pir::CombineOp>()) {
std::vector<pir::Value> inner_inputs;
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/transform_general_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace pir {
std::pair<std::string, pir::Parameter*> GetParameterFromValue(
pir::Value value) {
pir::GetParameterOp op =
value.GetDefiningOp()->dyn_cast<pir::GetParameterOp>();
value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
op,
phi::errors::InvalidArgument(
Expand Down Expand Up @@ -66,7 +66,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand_source(index).GetDefiningOp();
return op->operand_source(index).dyn_cast<OpResult>().owner();
}

Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/pir/core/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class IR_API Attribute {

bool operator!() const { return storage_ == nullptr; }

operator const void *() const { return storage_; }

///
/// \brief Some Attribute attribute acquisition interfaces.
///
Expand Down Expand Up @@ -85,8 +87,6 @@ class IR_API Attribute {
return pir::dyn_cast<U>(*this);
}

friend struct std::hash<Attribute>;

protected:
const Storage *storage_{nullptr};
};
Expand All @@ -98,7 +98,7 @@ namespace std {
template <>
struct hash<pir::Attribute> {
std::size_t operator()(const pir::Attribute &obj) const {
return std::hash<const pir::Attribute::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
2 changes: 1 addition & 1 deletion paddle/pir/core/builtin_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) {
for (uint32_t i = 0; i < defining_op->num_operands(); ++i) {
auto value = defining_op->operand_source(i);
if (!value) continue;
auto *oprand_defining_op = value.GetDefiningOp();
auto *oprand_defining_op = value.dyn_cast<OpResult>().owner();
if (oprand_defining_op->HasAttribute(kStopGradientAttrName)) {
auto attrs = oprand_defining_op->attribute(kStopGradientAttrName)
.dyn_cast<pir::ArrayAttribute>()
Expand Down
4 changes: 2 additions & 2 deletions paddle/pir/core/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ class IR_API OpInfo {
template <typename InterfaceT>
typename InterfaceT::Concept *GetInterfaceImpl() const;

operator const void *() const { return impl_; }
void *AsOpaquePointer() const { return impl_; }
static OpInfo RecoverFromOpaquePointer(void *pointer) {
return OpInfo(static_cast<OpInfoImpl *>(pointer));
}

friend class OpInfoImpl;
friend struct std::hash<OpInfo>;

private:
explicit OpInfo(OpInfoImpl *impl) : impl_(impl) {}
Expand Down Expand Up @@ -105,7 +105,7 @@ namespace std {
template <>
struct hash<pir::OpInfo> {
std::size_t operator()(const pir::OpInfo &obj) const {
return std::hash<const pir::OpInfoImpl *>()(obj.impl_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
14 changes: 3 additions & 11 deletions paddle/pir/core/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class IR_API Type {
Type() = default;

Type(const Storage *storage) // NOLINT
: storage_(const_cast<Storage *>(storage)) {}
: storage_(storage) {}

Type(const Type &other) = default;

Expand All @@ -74,10 +74,7 @@ class IR_API Type {
///
/// \brief Support PointerLikeTypeTraits.
///
///
const void *AsOpaquePointer() const {
return static_cast<const void *>(storage_);
}
operator const void *() const { return storage_; }
static Type RecoverFromOpaquePointer(const void *pointer) {
return Type(reinterpret_cast<Storage *>(const_cast<void *>(pointer)));
}
Expand Down Expand Up @@ -120,11 +117,6 @@ class IR_API Type {

static Type Parse(std::istream &is, IrContext *ctx);

///
/// \brief Enable hashing Type.
///
friend struct std::hash<Type>;

template <typename U>
U cast() const {
return pir::cast<U>(*this);
Expand Down Expand Up @@ -185,7 +177,7 @@ namespace std {
template <>
struct hash<pir::Type> {
std::size_t operator()(const pir::Type &obj) const {
return std::hash<const pir::Type::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
8 changes: 2 additions & 6 deletions paddle/pir/core/type_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TypeId {
///
/// \brief Support PointerLikeTypeTraits.
///
operator const void *() const { return storage_; }
void *AsOpaquePointer() const { return storage_; }
static TypeId RecoverFromOpaquePointer(void *pointer) {
return TypeId(static_cast<Storage *>(pointer));
Expand All @@ -71,11 +72,6 @@ class TypeId {
return storage_ < other.storage_;
}

///
/// \brief Enable hashing TypeId instances.
///
friend struct std::hash<TypeId>;

private:
///
/// \brief Construct a TypeId and initialize storage.
Expand Down Expand Up @@ -150,7 +146,7 @@ namespace std {
template <>
struct hash<pir::TypeId> {
std::size_t operator()(const pir::TypeId &obj) const {
return std::hash<const pir::TypeId::Storage *>()(obj.storage_);
return std::hash<const void *>()(obj);
}
};
} // namespace std
5 changes: 0 additions & 5 deletions paddle/pir/core/value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@ void Value::set_type(pir::Type type) {
impl_->set_type(type);
}

Operation *Value::GetDefiningOp() const {
if (auto result = dyn_cast<OpResult>()) return result.owner();
return nullptr;
}

std::string Value::PrintUdChain() {
CHECK_VALUE_NULL_IMPL(PrintUdChain);
return impl()->PrintUdChain();
Expand Down
2 changes: 0 additions & 2 deletions paddle/pir/core/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class IR_API Value {

void set_type(Type type);

Operation *GetDefiningOp() const;

std::string PrintUdChain();

///
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/shape/utils/shape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ bool SymbolicDimMgr::LoadShapeConstraintGraph() {
auto build_sym_product = [&](std::vector<Value> range,
SymbolicDimProduct& product) {
for (Value v : range) {
auto definingOp = v.GetDefiningOp();
auto definingOp = v.dyn_cast<OpResult>().owner();
if (auto constOp = definingOp->dyn_cast<ConstantOp>()) {
product.factor *= constOp.value().dyn_cast<Int32Attribute>().data();
continue;
Expand Down
3 changes: 2 additions & 1 deletion paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ class GreedyPatternRewriteDriver : public pir::PatternRewriter {
// that single use values often have more canonicalization opportunities.
if (!operand || (!operand.use_empty() && !operand.HasOneUse())) return;

if (auto* def_op = operand.GetDefiningOp()) AddToWorklist(def_op);
if (auto* def_op = operand.dyn_cast<pir::OpResult>().owner())
AddToWorklist(def_op);
}

void AddOperandsToWorklist(const std::vector<pir::Value> operands) {
Expand Down
8 changes: 4 additions & 4 deletions test/cpp/pir/core/ir_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ TEST(value_test, value_test) {
op4->Print(std::cout);

// Test 1:
EXPECT_EQ(op1->result(0).GetDefiningOp(), op1);
EXPECT_EQ(op2->result(0).GetDefiningOp(), op2);
EXPECT_EQ(op3->result(0).GetDefiningOp(), op3);
EXPECT_EQ(op4->result(6).GetDefiningOp(), op4);
EXPECT_EQ(op1->result(0).owner(), op1);
EXPECT_EQ(op2->result(0).owner(), op2);
EXPECT_EQ(op3->result(0).owner(), op3);
EXPECT_EQ(op4->result(6).owner(), op4);

// Test 2: op1_first_output -> op4_first_input
pir::OpResult op1_first_output = op1->result(0);
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class Conv2dBnFusePattern
pir::Value conv2d_filter = conv2d_op.filter();

// pir::GetParameterOp filter_parameter_op =
// conv2d_filter.GetDefiningOp()->dyn_cast<pir::GetParameterOp>();
// conv2d_filter.dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
// if (!filter_parameter_op) return false;

pir::OpResult conv2d_filter_result =
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/pir/shape_dialect/symbolic_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ TEST(shape_op, dim) {
EXPECT_EQ(dimOp.getName(), "S0");
dimOp.setName("S1");
EXPECT_EQ(dimOp.getName(), "S1");
EXPECT_EQ(res.GetDefiningOp(), dimOp.operation());
EXPECT_EQ(res.owner(), dimOp.operation());
EXPECT_EQ(res.type(), pir::IndexType::get(ctx));
}

Expand Down

0 comments on commit ebba76d

Please sign in to comment.