Skip to content

Commit

Permalink
[OpTestPy] fix expand/expand_v2, fc,flatten_contiguous_range, gather,…
Browse files Browse the repository at this point in the history
… generate_proposals_v2,greater_equal diff! (PaddlePaddle#8339)
  • Loading branch information
zhoutianzi666 committed Feb 7, 2022
1 parent eb8c523 commit f070416
Show file tree
Hide file tree
Showing 16 changed files with 113 additions and 1,076 deletions.
27 changes: 23 additions & 4 deletions lite/kernels/host/compare_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,8 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
using greater_than_bool = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterThanFunctor<bool>>;
REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_bool, bool)
REGISTER_LITE_KERNEL(
greater_than, kHost, kFloat, kAny, greater_than_bool, def_bool)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kBool),
Expand All @@ -602,9 +603,10 @@ REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_bool, bool)
.Finalize();

using greater_than_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt32),
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterThanFunctor<int32_t>>;
REGISTER_LITE_KERNEL(greater_than, kHost, kInt32, kAny, greater_than_int32, def)
REGISTER_LITE_KERNEL(
greater_than, kHost, kFloat, kAny, greater_than_int32, def_int32)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
Expand Down Expand Up @@ -644,7 +646,7 @@ using greater_than_int64_f = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterThanFunctor<int64_t>>;
REGISTER_LITE_KERNEL(
greater_than, kHost, kFloat, kAny, greater_than_int64_f, int64)
greater_than, kHost, kFloat, kAny, greater_than_int64_f, def_int64)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt64),
Expand Down Expand Up @@ -699,3 +701,20 @@ REGISTER_LITE_KERNEL(
DATALAYOUT(kAny))})
.BindPaddleOpVersion("greater_equal", 1)
.Finalize();

using greater_equal_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterEqualFunctor<int32_t>>;
REGISTER_LITE_KERNEL(
greater_equal, kHost, kFloat, kAny, greater_equal_int32, def_int32)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.BindPaddleOpVersion("greater_equal", 1)
.Finalize();
4 changes: 2 additions & 2 deletions lite/kernels/host/expand_v2_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ REGISTER_LITE_KERNEL(expand_v2, kHost, kFloat, kAny, expand_v2_float, def)
.Finalize();

using expand_v2_int32 =
paddle::lite::kernels::host::ExpandV2Compute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(expand_v2, kHost, kInt32, kAny, expand_v2_int32, def)
paddle::lite::kernels::host::ExpandV2Compute<int, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(expand_v2, kHost, kFloat, kAny, expand_v2_int32, def_int32)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kInt32),
Expand Down
5 changes: 2 additions & 3 deletions lite/kernels/host/gather_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,14 @@ void GatherFunc(const operators::GatherParam& param) {

template <typename IndexType, typename AxisType, typename DataType>
void GatherV2Func(const operators::GatherParam& param) {
auto* axis_data = param.Axis->data<AxisType>();
auto* index_data = param.Index->data<IndexType>();
auto* input_data = param.X->data<DataType>();
auto* out_data = param.Out->mutable_data<DataType>();

int index_size = param.Index->numel();
int input_size = param.X->numel();
auto input_dim = param.X->dims();
int axis_index = axis_data[0];
int axis_index = param.Axis ? param.Axis->data<AxisType>()[0] : param.axis;
int inner_dim_size = 1;
int outer_dim_size = 1;
int input_index_dim_size = input_dim[axis_index];
Expand Down Expand Up @@ -81,7 +80,7 @@ void GatherV2Func(const operators::GatherParam& param) {
template <typename IndexType, typename AxisType>
void GatherCompute<IndexType, AxisType>::Run() {
auto& param = this->template Param<operators::GatherParam>();
if (param.Axis != nullptr) {
if (param.Axis != nullptr || param.axis != -1) {
switch (param.X->precision()) {
case PRECISION(kFloat):
GatherV2Func<IndexType, AxisType, float>(param);
Expand Down
Loading

0 comments on commit f070416

Please sign in to comment.