Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace any by variant in infermeta context #42181

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion paddle/phi/core/infermeta_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void InferMetaContext::EmplaceBackOutput(MetaTensor output) {
outputs_.emplace_back(std::move(output));
output_range_.emplace_back(std::pair<int, int>(index, index + 1));
}
void InferMetaContext::EmplaceBackAttr(paddle::any attr) {
void InferMetaContext::EmplaceBackAttr(Attribute attr) {
attrs_.emplace_back(std::move(attr));
}

Expand Down Expand Up @@ -120,6 +120,38 @@ std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
return result;
}

template <typename AttrType>
const AttrType& InferMetaContext::AttrAt(size_t idx) const {
try {
return paddle::get<AttrType>(attrs_.at(idx));
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`.",
std::type_index(typeid(AttrType)).name()));
}
}

template const bool& InferMetaContext::AttrAt(size_t idx) const;
template const int& InferMetaContext::AttrAt(size_t idx) const;
template const int64_t& InferMetaContext::AttrAt(size_t idx) const;
template const float& InferMetaContext::AttrAt(size_t idx) const;
template const double& InferMetaContext::AttrAt(size_t idx) const;
template const std::string& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<bool>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<int64_t>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<float>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<double>& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<std::string>& InferMetaContext::AttrAt(
size_t idx) const;
template const Scalar& InferMetaContext::AttrAt(size_t idx) const;
template const std::vector<Scalar>& InferMetaContext::AttrAt(size_t idx) const;
template const IntArray& InferMetaContext::AttrAt(size_t idx) const;
template const DataType& InferMetaContext::AttrAt(size_t idx) const;
template const DataLayout& InferMetaContext::AttrAt(size_t idx) const;
template const Place& InferMetaContext::AttrAt(size_t idx) const;

MetaFnFactory& MetaFnFactory::Instance() {
static MetaFnFactory g_meta_fn_map;
return g_meta_fn_map;
Expand Down
60 changes: 33 additions & 27 deletions paddle/phi/core/infermeta_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/attribute.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/meta_tensor.h"
Expand All @@ -41,7 +42,7 @@ class InferMetaContext {

void EmplaceBackInput(MetaTensor input);
void EmplaceBackOutput(MetaTensor output);
void EmplaceBackAttr(paddle::any attr);
void EmplaceBackAttr(Attribute attr);

void EmplaceBackInputs(
paddle::SmallVector<MetaTensor, phi::kInputSmallVectorSize> inputs);
Expand All @@ -61,17 +62,7 @@ class InferMetaContext {
size_t end);

template <typename AttrType>
AttrType AttrAt(size_t idx) {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferMeta Context, the expected attribute "
"type is `%s`, but actual attribute type is `%s`.",
std::type_index(typeid(AttrType)).name(),
std::type_index(attrs_.at(idx).type()).name()));
}
}
const AttrType& AttrAt(size_t idx) const;

const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::pair<int, int>& OutputRangeAt(size_t idx) const;
Expand All @@ -81,7 +72,7 @@ class InferMetaContext {
protected:
MetaConfig config_;

paddle::SmallVector<paddle::any, kAttrSmallVectorSize> attrs_;
paddle::SmallVector<Attribute, kAttrSmallVectorSize> attrs_;

paddle::SmallVector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
Expand Down Expand Up @@ -111,6 +102,21 @@ class InferMetaContext {
} \
}

#define PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \
template <typename... Tail> \
struct InferMetaFnCallHelper<const attr_type&, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) { \
static_assert(out_idx == 0, \
"InferMeta's Attributes should appear before Outputs."); \
const attr_type& arg = ctx->AttrAt<attr_type>(attr_idx); \
InferMetaFnCallHelper< \
Tail...>::template Call<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}

template <typename T>
struct InferMetaTypeTag {};

Expand Down Expand Up @@ -201,27 +207,27 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
}
};

// TODO(chenweihang): support other attr type later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(bool);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(int64_t);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(float);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::string&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<bool>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<int64_t>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<float>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const std::vector<double>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(
const std::vector<std::string>&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataType);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(Backend);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(DataLayout);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const Scalar&);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_ATTRIBUTE(const IntArray&);

// TODO(chenweihang): support vector<MetaTensor> input later
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<bool>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector<int>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<int64_t>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<float>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<double>);
PD_SPECIALIZE_InferMetaFnCallHelper_FOR_CONST_ATTRIBUTE_REF(
std::vector<std::string>);

template <typename... Tail>
struct InferMetaFnCallHelper<MetaTensor*, Tail...> {
Expand Down
29 changes: 0 additions & 29 deletions paddle/phi/core/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,8 @@
#include <string>
#include <vector>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/scalar.h"

#include "paddle/utils/variant.h"

namespace phi {

class Place;

// NOTE: Add needed type in the future
using Attribute = paddle::variant<bool,
int,
int64_t,
float,
double,
std::string,
std::vector<bool>,
std::vector<int>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>,
std::vector<std::string>,
Scalar,
std::vector<Scalar>,
IntArray,
DataType,
DataLayout,
Place>;

class Kernel;
class KernelKey;
class KernelArgsDef;
Expand Down
8 changes: 0 additions & 8 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,6 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
out->set_dtype(x.dtype());
}

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out) {
UnchangedInferMeta(x, out);
}

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(dtype == DataType::UNDEFINED ? x.dtype() : dtype);
Expand Down Expand Up @@ -3008,6 +3001,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {

} // namespace phi

PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
5 changes: 0 additions & 5 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CopyToInferMeta(const MetaTensor& x,
Backend backend,
bool blocking,
MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumsumInferMeta(const MetaTensor& x,
Expand Down
26 changes: 0 additions & 26 deletions paddle/phi/tests/core/test_meta_fn_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,6 @@ TEST(MetaFnFactory, InferMetaFnExists) {
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, CopyInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({3, 4});

phi::MetaTensor meta_x(&dense_x);
phi::DenseTensor dense_out1;
phi::MetaTensor meta_out(&dense_out1);
phi::UnchangedInferMeta(meta_x, &meta_out);

auto shared_meat_x = phi::MetaTensor(&dense_x);
phi::DenseTensor dense_out2;
auto shared_meta_out = phi::MetaTensor(&dense_out2);

phi::InferMetaContext ctx;
ctx.EmplaceBackInput(shared_meat_x);
ctx.EmplaceBackAttr(Backend::CPU);
ctx.EmplaceBackAttr(false);
ctx.EmplaceBackOutput(shared_meta_out);
ctx.SetMetaConfig({/*is_runtime =*/true, /*is_run_mkldnn_kernel=*/false});
phi::MetaFnFactory::Instance().Get("copy_to")(&ctx);

EXPECT_EQ(dense_out1.dims().size(), dense_out2.dims().size());
EXPECT_EQ(dense_out1.dims()[0], dense_out2.dims()[0]);
EXPECT_EQ(dense_out1.dims()[1], dense_out2.dims()[1]);
}

TEST(MetaFnFactory, SplitInferMetaFn) {
phi::DenseTensor dense_x;
dense_x.Resize({4, 10});
Expand Down