From 4bfa2d0eac373a1be8608b049c61b28f65c3479c Mon Sep 17 00:00:00 2001
From: AlbertVan <17612250325@163.com>
Date: Mon, 26 Sep 2022 09:46:18 +0800
Subject: [PATCH] [XPU] add seq_softmax, seq_expand, lod_reset op in xpu (#9453)

---
 lite/kernels/xpu/CMakeLists.txt | 3 +
 lite/kernels/xpu/gru_compute.cc | 13 +-
 lite/kernels/xpu/gru_unit_compute.cc | 22 ++--
 lite/kernels/xpu/layer_norm_compute.cc | 45 ++++---
 lite/kernels/xpu/layer_norm_compute.h | 3 +-
 lite/kernels/xpu/lod_reset_compute.cc | 77 +++++++++++
 lite/kernels/xpu/lod_reset_compute.h | 36 ++++++
 lite/kernels/xpu/sequence_expand_compute.cc | 121 ++++++++++++++++++
 lite/kernels/xpu/sequence_expand_compute.h | 43 +++++++
 lite/kernels/xpu/sequence_pool_compute.cc | 16 +++
 lite/kernels/xpu/sequence_softmax_compute.cc | 71 ++++++++++
 lite/kernels/xpu/sequence_softmax_compute.h | 40 ++++++
 .../fill_constant_batch_size_like_op.cc | 3 +-
 lite/tests/kernels/expand_compute_test.cc | 4 +-
 .../kernels/sequence_expand_compute_test.cc | 5 +-
 .../kernels/sequence_softmax_compute_test.cc | 4 +-
 16 files changed, 471 insertions(+), 35 deletions(-)
 create mode 100644 lite/kernels/xpu/lod_reset_compute.cc
 create mode 100644 lite/kernels/xpu/lod_reset_compute.h
 create mode 100644 lite/kernels/xpu/sequence_expand_compute.cc
 create mode 100644 lite/kernels/xpu/sequence_expand_compute.h
 create mode 100644 lite/kernels/xpu/sequence_softmax_compute.cc
 create mode 100644 lite/kernels/xpu/sequence_softmax_compute.h

diff --git a/lite/kernels/xpu/CMakeLists.txt b/lite/kernels/xpu/CMakeLists.txt
index 350804e61c1..61aa1871db3 100644
--- a/lite/kernels/xpu/CMakeLists.txt
+++ b/lite/kernels/xpu/CMakeLists.txt
@@ -67,6 +67,8 @@ add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.cc)
 add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc)
 add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc)
 add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc)
+add_kernel(sequence_expand_compute_xpu XPU extra SRCS sequence_expand_compute.cc)
+add_kernel(sequence_softmax_compute_xpu XPU extra SRCS sequence_softmax_compute.cc)
 add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc)
 add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc)
 add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc)
@@ -101,6 +103,7 @@ add_kernel(is_empty_compute_xpu XPU extra SRCS is_empty_compute.cc)
 add_kernel(shape_compute_xpu XPU extra SRCS shape_compute.cc)
 add_kernel(lod_array_length_compute_xpu XPU extra SRCS lod_array_length_compute.cc)
 add_kernel(multiclass_nms_compute_xpu XPU extra SRCS multiclass_nms_compute.cc)
+add_kernel(lod_reset_compute_xpu XPU extra SRCS lod_reset_compute.cc)
 
 # extra(fused kernel)
 add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)
diff --git a/lite/kernels/xpu/gru_compute.cc b/lite/kernels/xpu/gru_compute.cc
index f3edaa364d4..46ece0f1c46 100644
--- a/lite/kernels/xpu/gru_compute.cc
+++ b/lite/kernels/xpu/gru_compute.cc
@@ -56,15 +56,18 @@ void GRUCompute::PrepareForRun() {
       paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
   weight_s2_abs_max_ =
       paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
-  std::vector<float> weight_max_vector(8);
-  for (int i = 0; i < 4; i++) {
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
+  std::vector<float> weight_max_vector(max_ptr_size * 2);
+  for (int i = 0; i < max_ptr_size; i++) {
     weight_max_vector[i] = weight_s1_abs_max_;
-    weight_max_vector[i + 4] = weight_s2_abs_max_;
+    weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
   }
-  weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
+  weight_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
   XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
                       weight_max_vector.data(),
-                      8 * sizeof(float),
+                      max_ptr_size * 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   // quant
   quant_weight_guard_ =
diff --git a/lite/kernels/xpu/gru_unit_compute.cc b/lite/kernels/xpu/gru_unit_compute.cc
index c674ce57f7c..0b28e7d3e25 100644
--- a/lite/kernels/xpu/gru_unit_compute.cc
+++ b/lite/kernels/xpu/gru_unit_compute.cc
@@ -42,15 +42,19 @@ void GRUUnitCompute::PrepareForRun() {
       paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
   weight_s2_abs_max_ =
       paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
-  std::vector<float> weight_max_vector(8);
-  for (int i = 0; i < 4; i++) {
+
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
+  std::vector<float> weight_max_vector(max_ptr_size * 2);
+  for (int i = 0; i < max_ptr_size; i++) {
     weight_max_vector[i] = weight_s1_abs_max_;
-    weight_max_vector[i + 4] = weight_s2_abs_max_;
+    weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
   }
-  weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
+  weight_max_guard_ =
+      TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
   XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
                       weight_max_vector.data(),
-                      8 * sizeof(float),
+                      max_ptr_size * 2 * sizeof(float),
                       XPUMemcpyKind::XPU_HOST_TO_DEVICE));
   // quant
   quant_weight_guard_ =
@@ -103,14 +107,14 @@ void GRUUnitCompute::Run() {
   const float* bias_ptr =
       (bias == nullptr) ? nullptr : bias->data<float>();
   float* hidden_ptr = hidden->mutable_data<float>(TARGET(kXPU));
-
-  int ret = xdnn::gru_unit(
+  int ret = xdnn::gru_core(
       ctx.GetRawContext(),
       input_ptr,
       hidden_prev_ptr,
       weight_ptr,
       hidden_ptr,
       batch_size,
+      1,
      frame_size,
       nullptr,
       nullptr,
@@ -119,7 +123,9 @@
       bias_ptr,
       xdnn::Activation_t::TANH,
       xdnn::Activation_t::SIGMOID,
-      origin_mode);
+      origin_mode,
+      false,
+      false);
   CHECK_EQ(ret, 0) << "call xdnn::gru_unit failed!";
 }
 
diff --git a/lite/kernels/xpu/layer_norm_compute.cc b/lite/kernels/xpu/layer_norm_compute.cc
index 4a46c7aeab2..95fdf35f6dd 100644
--- a/lite/kernels/xpu/layer_norm_compute.cc
+++ b/lite/kernels/xpu/layer_norm_compute.cc
@@ -21,7 +21,8 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-void LayerNormCompute::Run() {
+template <typename T, PrecisionType PType>
+void LayerNormCompute<T, PType>::Run() {
   auto& param = this->template Param<param_t>();
   auto& ctx = this->ctx_->template As<XPUContext>();
 
@@ -30,16 +31,17 @@ void LayerNormCompute::Run() {
   auto matrix_dim = x_dims.Flatten2D(axis);
   float epsilon = param.epsilon;
 
-  int r = xdnn::layer_norm(ctx.GetRawContext(),          /* context */
-                           param.X->data<float>(),       /* in */
-                           param.Y->mutable_data<float>(TARGET(kXPU)), /* out */
-                           matrix_dim[0],                /* m */
-                           matrix_dim[1],                /* n */
-                           epsilon,                      /* epsilon */
-                           param.Scale->data<float>(),   /* scale */
-                           param.Bias->data<float>(),    /* bias */
-                           nullptr,
-                           nullptr);
+  int r = xdnn::layer_norm(
+      ctx.GetRawContext(),                              /* context */
+      param.X->template data<T>(),                      /* in */
+      param.Y->template mutable_data<T>(TARGET(kXPU)),  /* out */
+      matrix_dim[0],                                    /* m */
+      matrix_dim[1],                                    /* n */
+      epsilon,                                          /* epsilon */
+      param.Scale->template data<float>(),              /* scale */
+      param.Bias->template data<float>(),               /* bias */
+      nullptr,
+      nullptr);
   CHECK_EQ(r, 0);
 }
 
@@ -49,12 +51,11 @@
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_LITE_KERNEL(layer_norm,
-                     kXPU,
-                     kFloat,
-                     kNCHW,
-                     paddle::lite::kernels::xpu::LayerNormCompute,
-                     def)
+namespace xpu = paddle::lite::kernels::xpu;
+
+using LayerNorm_FP32 = xpu::LayerNormCompute<float, PRECISION(kFloat)>;
+using LayerNorm_FP16 = xpu::LayerNormCompute<float16, PRECISION(kFP16)>;
+REGISTER_LITE_KERNEL(layer_norm, kXPU, kFloat, kNCHW, LayerNorm_FP32, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -62,3 +63,13 @@ REGISTER_LITE_KERNEL(layer_norm,
     .BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU))})
     .BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kXPU))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(layer_norm, kXPU, kFP16, kNCHW, LayerNorm_FP16, fp16)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .BindOutput("Variance",
+                {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
+    .Finalize();
diff --git a/lite/kernels/xpu/layer_norm_compute.h b/lite/kernels/xpu/layer_norm_compute.h
index 9eeb5924c51..beb1c0ec1a7 100644
--- a/lite/kernels/xpu/layer_norm_compute.h
+++ b/lite/kernels/xpu/layer_norm_compute.h
@@ -21,7 +21,8 @@ namespace lite {
 namespace kernels {
 namespace xpu {
 
-class LayerNormCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+template <typename T, PrecisionType PType>
+class LayerNormCompute : public KernelLite<TARGET(kXPU), PType> {
  public:
   using param_t = operators::LayerNormParam;
 
diff --git a/lite/kernels/xpu/lod_reset_compute.cc b/lite/kernels/xpu/lod_reset_compute.cc
new file mode 100644
index 00000000000..384e8d76da6
--- /dev/null
+++ b/lite/kernels/xpu/lod_reset_compute.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/lod_reset_compute.h"
+#include <algorithm>
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void LodResetCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+
+  auto x = param.X;
+  auto output = param.Out;
+  output->mutable_data(TARGET(kXPU), x->memory_size());
+  int r = xdnn::copy(ctx.GetRawContext(),
+                     x->data<int8_t>(),
+                     reinterpret_cast<int8_t*>(output->raw_data()),
+                     x->memory_size());
+  CHECK_EQ(r, 0);
+  auto lod = output->mutable_lod();
+  if (param.Y) {
+    if (param.Y->lod().size()) {
+      *lod = param.Y->lod();
+    } else {
+      const auto* y_data = param.Y->data<int>();
+      std::vector<int> y_cpu(param.Y->numel());
+      TargetWrapperXPU::MemcpySync(y_cpu.data(),
+                                   y_data,
+                                   param.Y->numel() * sizeof(int),
+                                   IoDirection::DtoH);
+      (*lod).resize(1);
+      (*lod)[0].resize(param.Y->numel());
+      for (int i = 0; i < param.Y->numel(); i++) {
+        (*lod)[0][i] = y_cpu[i];
+      }
+    }
+  } else {
+    (*lod).resize(1);
+    for (auto id : param.target_lod) {
+      (*lod)[0].push_back(id);
+    }
+  }
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(lod_reset,
+                     kXPU,
+                     kAny,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::LodResetCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
+    .Finalize();
diff --git a/lite/kernels/xpu/lod_reset_compute.h b/lite/kernels/xpu/lod_reset_compute.h
new file mode 100644
index 00000000000..675ae648696
--- /dev/null
+++ b/lite/kernels/xpu/lod_reset_compute.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class LodResetCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
+ public:
+  using param_t = operators::LodResetParam;
+
+  void Run() override;
+
+  virtual ~LodResetCompute() = default;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/xpu/sequence_expand_compute.cc b/lite/kernels/xpu/sequence_expand_compute.cc
new file mode 100644
index 00000000000..ee102a30df3
--- /dev/null
+++ b/lite/kernels/xpu/sequence_expand_compute.cc
@@ -0,0 +1,121 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/sequence_expand_compute.h"
+#include <algorithm>
+#include <numeric>
+#include <vector>
+
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+template <typename T, PrecisionType PType>
+void SequenceExpandCompute<T, PType>::PrepareForRun() {
+  lodx_cpu_.reset(new int[XPU_MAX_LOD_SIZE_64]);
+  lody_cpu_.reset(new int[XPU_MAX_LOD_SIZE_64]);
+  lodref_cpu_.reset(new int[XPU_MAX_LOD_SIZE_64]);
+}
+
+template <typename T, PrecisionType PType>
+void SequenceExpandCompute<T, PType>::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  auto* x = param.X;
+  auto* y = param.Y;
+  auto* out = param.Out;
+  auto x_lod = x->lod();
+  auto y_lod = y->lod();
+  int ref_level = param.ref_level;
+  if (ref_level == -1) {
+    ref_level = y_lod.size() - 1;
+  }
+
+  auto* x_data = x->template data<T>();
+  auto* out_data = out->template mutable_data<T>(TARGET(kXPU));
+
+  if (y_lod[ref_level].size() <= 1 ||
+      (y_lod[ref_level].size() == 2 && y_lod[ref_level][1] == 1)) {
+    int r = xdnn::copy(ctx.GetRawContext(),
+                       reinterpret_cast<const int8_t*>(x_data),
+                       reinterpret_cast<int8_t*>(out_data),
+                       x->numel() * sizeof(T));
+    CHECK_EQ(r, 0) << "seqence_expand do copy failed.";
+    return;
+  }
+
+  int dims = x->numel() / x->dims()[0];
+  std::vector<uint64_t> ref_y_lod = y_lod[ref_level];
+  // create ref_x_lod;
+  std::vector<uint64_t> ref_x_lod;
+  if (x->lod().size() == 1) {
+    ref_x_lod = x->lod()[0];
+  } else {
+    ref_x_lod.resize(x->dims()[0] + 1);
+    std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
+  }
+
+  std::vector<uint64_t> ref_out_lod(ref_y_lod.size(), 0);
+  std::vector<uint64_t> out_lod;
+  out_lod.push_back(0);
+
+  for (size_t i = 1; i < ref_y_lod.size(); ++i) {
+    int repeat_num = ref_y_lod[i] - ref_y_lod[i - 1];
+    int seq_len = ref_x_lod[i] - ref_x_lod[i - 1];
+    for (int j = 0; j < repeat_num; ++j) {
+      out_lod.push_back(out_lod.back() + seq_len);
+    }
+    ref_out_lod[i] = ref_out_lod[i - 1] + seq_len * repeat_num;
+  }
+
+  auto& ref_lod = *out->mutable_lod();
+  ref_lod[0] = out_lod;
+  int lod_len = ref_y_lod.size();
+  for (int i = 0; i < lod_len; ++i) {
+    lodx_cpu_[i] = ref_x_lod[i];
+    lodref_cpu_[i] = ref_y_lod[i];
+    lody_cpu_[i] = ref_out_lod[i];
+  }
+
+  int r =
+      xdnn::sequence_expand(ctx.GetRawContext(),
+                            reinterpret_cast<const float*>(x_data),
+                            reinterpret_cast<float*>(out_data),
+                            {lodx_cpu_.get(), lod_len, nullptr},
+                            {lody_cpu_.get(), lod_len, nullptr},
+                            {lodref_cpu_.get(), lod_len, nullptr},
+                            dims);
+
+  CHECK_EQ(r, 0);
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+namespace xpu = paddle::lite::kernels::xpu;
+
+using sequence_expand_float32 =
+    paddle::lite::kernels::xpu::SequenceExpandCompute<float, PRECISION(kFloat)>;
+REGISTER_LITE_KERNEL(
+    sequence_expand, kXPU, kFloat, kNCHW, sequence_expand_float32, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
diff --git a/lite/kernels/xpu/sequence_expand_compute.h b/lite/kernels/xpu/sequence_expand_compute.h
new file mode 100644
index 00000000000..3ce84b8347c
--- /dev/null
+++ b/lite/kernels/xpu/sequence_expand_compute.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+template <typename T, PrecisionType PType>
+class SequenceExpandCompute : public KernelLite<TARGET(kXPU), PType> {
+ public:
+  using param_t = operators::SequenceExpandParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+ private:
+  std::unique_ptr<int[]> lodx_cpu_;
+  std::unique_ptr<int[]> lody_cpu_;
+  std::unique_ptr<int[]> lodref_cpu_;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/xpu/sequence_pool_compute.cc b/lite/kernels/xpu/sequence_pool_compute.cc
index 88d07ec550b..b5d86b40104 100644
--- a/lite/kernels/xpu/sequence_pool_compute.cc
+++ b/lite/kernels/xpu/sequence_pool_compute.cc
@@ -14,6 +14,7 @@
 
 #include "lite/kernels/xpu/sequence_pool_compute.h"
 #include <string>
+#include <vector>
 #include "lite/backends/xpu/xpu_header_sitter.h"
 #include "lite/core/op_registry.h"
 
@@ -42,6 +43,21 @@ void XPUSequencePoolCompute::Run() {
   for (size_t i = 0; i < in_lod.size(); ++i) {
     lod_cpu[i] = in_lod[i];
   }
+
+  int batch_size = in_lod.size() - 1;
+  std::vector<uint64_t> offset_new;
+  if (in->lod().size() == 2) {
+    offset_new.resize(in->lod()[0].size());
+    offset_new = in->lod()[0];
+  } else {
+    offset_new.resize(batch_size + 1);
+    for (int i = 0; i <= batch_size; i++) {
+      offset_new[i] = i;
+    }
+  }
+  out->mutable_lod()->clear();
+  out->mutable_lod()->push_back(offset_new);
+
   int lod_len = in_lod.size();
   int r = 0;
   if (pool_type_str == "MAX") {
diff --git a/lite/kernels/xpu/sequence_softmax_compute.cc b/lite/kernels/xpu/sequence_softmax_compute.cc
new file mode 100644
index 00000000000..ff997bad51a
--- /dev/null
+++ b/lite/kernels/xpu/sequence_softmax_compute.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/xpu/sequence_softmax_compute.h"
+#include <algorithm>
+#include <memory>
+#include <vector>
+#include "lite/backends/xpu/xpu_header_sitter.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+void SequenceSoftmaxCompute::PrepareForRun() {
+  lod_cpu_.reset(new int[XPU_MAX_LOD_SIZE]);
+}
+
+void SequenceSoftmaxCompute::Run() {
+  auto& param = this->template Param<param_t>();
+  auto& ctx = this->ctx_->template As<XPUContext>();
+  auto* in = param.X;
+  auto* out = param.Out;
+  // get lod
+  auto seq_offset = in->lod()[0];
+  for (size_t i = 0; i < seq_offset.size(); ++i) {
+    lod_cpu_[i] = seq_offset[i];
+  }
+  // get shape
+  auto input_dims = in->dims();
+  std::vector<int> xshape;
+  for (size_t i = 0; i < input_dims.size(); i++) {
+    xshape.push_back(input_dims[i]);
+  }
+  int seq_num = seq_offset.size();
+  int r = 0;
+  r = xdnn::sequence_softmax(ctx.GetRawContext(),
+                             in->data<float>(),
+                             out->mutable_data<float>(TARGET(kXPU)),
+                             xshape,
+                             0,
+                             {lod_cpu_.get(), seq_num, nullptr});
+  CHECK_EQ(r, 0);
+}
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_softmax,
+                     kXPU,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::xpu::SequenceSoftmaxCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+    .Finalize();
diff --git a/lite/kernels/xpu/sequence_softmax_compute.h b/lite/kernels/xpu/sequence_softmax_compute.h
new file mode 100644
index 00000000000..8ed2b9753a4
--- /dev/null
+++ b/lite/kernels/xpu/sequence_softmax_compute.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+
+class SequenceSoftmaxCompute
+    : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequenceSoftmaxParam;
+
+  void PrepareForRun() override;
+  void Run() override;
+
+ private:
+  std::unique_ptr<int[]> lod_cpu_;
+};
+
+}  // namespace xpu
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/fill_constant_batch_size_like_op.cc b/lite/operators/fill_constant_batch_size_like_op.cc
index b14d8c59a4e..ac67ac4e2f6 100644
--- a/lite/operators/fill_constant_batch_size_like_op.cc
+++ b/lite/operators/fill_constant_batch_size_like_op.cc
@@ -30,7 +30,8 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
 
 bool FillConstantBatchSizeLikeOp::InferShapeImpl() const {
   std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()};
-  if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) {
+  if (param_.input_dim_idx == 0 && !param_.input->lod().empty() &&
+      param_.input->lod().back().size() > 1) {
     output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1;
   } else {
     output_dim[param_.output_dim_idx] =
diff --git a/lite/tests/kernels/expand_compute_test.cc b/lite/tests/kernels/expand_compute_test.cc
index cda6ea3e101..f006cf2aa22 100644
--- a/lite/tests/kernels/expand_compute_test.cc
+++ b/lite/tests/kernels/expand_compute_test.cc
@@ -169,6 +169,8 @@ TEST(Expand, precision) {
 #if defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
 #elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = Place(TARGET(kHost), PRECISION(kAny));
 #else
@@ -177,7 +179,7 @@
 
   test_expand_3dim(place, abs_error);
   test_expand_4dim(place, abs_error);
-#ifndef LITE_WITH_NPU
+#if !defined(LITE_WITH_NPU) && !defined(LITE_WITH_XPU)
   test_expand_3dim(place, abs_error);
   test_expand_4dim(place, abs_error);
   test_expand_4dim(place, abs_error);
diff --git a/lite/tests/kernels/sequence_expand_compute_test.cc b/lite/tests/kernels/sequence_expand_compute_test.cc
index dc10964586e..6f480ccf113 100644
--- a/lite/tests/kernels/sequence_expand_compute_test.cc
+++ b/lite/tests/kernels/sequence_expand_compute_test.cc
@@ -174,7 +174,10 @@ void test_sequence_expand(Place place) {
 
 TEST(SequenceExpand, precision) {
   Place place;
-#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
+// only supported xpu2 now, xpu1 will be supported later.
+#if 0 && defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = TARGET(kHost);
 #else
   return;
diff --git a/lite/tests/kernels/sequence_softmax_compute_test.cc b/lite/tests/kernels/sequence_softmax_compute_test.cc
index 7e8069c177c..3d7a8e98928 100644
--- a/lite/tests/kernels/sequence_softmax_compute_test.cc
+++ b/lite/tests/kernels/sequence_softmax_compute_test.cc
@@ -112,7 +112,9 @@ void test_sequence_softmax(Place place) {
 
 TEST(SequenceSoftmax, precision) {
   Place place;
-#if defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
+#if defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#elif defined(LITE_WITH_ARM) || defined(LITE_WITH_X86)
   place = TARGET(kHost);
 #else
   return;