Skip to content

Commit

Permalink
[XPU] add seq_softmax, seq_expand, lod_reset op in xpu (#9453)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbertVan authored Sep 26, 2022
1 parent 6b04c22 commit 4bfa2d0
Show file tree
Hide file tree
Showing 16 changed files with 471 additions and 35 deletions.
3 changes: 3 additions & 0 deletions lite/kernels/xpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ add_kernel(sequence_reverse_compute_xpu XPU extra SRCS sequence_reverse_compute.
add_kernel(sequence_concat_compute_xpu XPU extra SRCS sequence_concat_compute.cc)
add_kernel(sequence_arithmetic_compute_xpu XPU extra SRCS sequence_arithmetic_compute.cc)
add_kernel(sequence_pool_compute_xpu XPU extra SRCS sequence_pool_compute.cc)
add_kernel(sequence_expand_compute_xpu XPU extra SRCS sequence_expand_compute.cc)
add_kernel(sequence_softmax_compute_xpu XPU extra SRCS sequence_softmax_compute.cc)
add_kernel(match_matrix_tensor_compute_xpu XPU extra SRCS match_matrix_tensor_compute.cc)
add_kernel(var_conv_2d_compute_xpu XPU extra SRCS var_conv_2d_compute.cc)
add_kernel(search_grnn_compute_xpu XPU extra SRCS search_grnn_compute.cc)
Expand Down Expand Up @@ -101,6 +103,7 @@ add_kernel(is_empty_compute_xpu XPU extra SRCS is_empty_compute.cc)
add_kernel(shape_compute_xpu XPU extra SRCS shape_compute.cc)
add_kernel(lod_array_length_compute_xpu XPU extra SRCS lod_array_length_compute.cc)
add_kernel(multiclass_nms_compute_xpu XPU extra SRCS multiclass_nms_compute.cc)
add_kernel(lod_reset_compute_xpu XPU extra SRCS lod_reset_compute.cc)

# extra(fused kernel)
add_kernel(__xpu__resnet50_compute_xpu XPU extra SRCS __xpu__resnet50_compute.cc)
Expand Down
13 changes: 8 additions & 5 deletions lite/kernels/xpu/gru_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,18 @@ void GRUCompute::PrepareForRun() {
paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
weight_s2_abs_max_ =
paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
std::vector<float> weight_max_vector(8);
for (int i = 0; i < 4; i++) {
auto& ctx = this->ctx_->template As<XPUContext>();
int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
std::vector<float> weight_max_vector(max_ptr_size * 2);
for (int i = 0; i < max_ptr_size; i++) {
weight_max_vector[i] = weight_s1_abs_max_;
weight_max_vector[i + 4] = weight_s2_abs_max_;
weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
}
weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
weight_max_vector.data(),
8 * sizeof(float),
max_ptr_size * 2 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
// quant
quant_weight_guard_ =
Expand Down
22 changes: 14 additions & 8 deletions lite/kernels/xpu/gru_unit_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,19 @@ void GRUUnitCompute::PrepareForRun() {
paddle::lite::xpu::math::FindMaxAbs(weight_s1_ptr, weight_s1_len);
weight_s2_abs_max_ =
paddle::lite::xpu::math::FindMaxAbs(weight_s2_ptr, weight_s2_len);
std::vector<float> weight_max_vector(8);
for (int i = 0; i < 4; i++) {

auto& ctx = this->ctx_->template As<XPUContext>();
int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
std::vector<float> weight_max_vector(max_ptr_size * 2);
for (int i = 0; i < max_ptr_size; i++) {
weight_max_vector[i] = weight_s1_abs_max_;
weight_max_vector[i + 4] = weight_s2_abs_max_;
weight_max_vector[i + max_ptr_size] = weight_s2_abs_max_;
}
weight_max_guard_ = TargetWrapperXPU::MallocScratchPad(8 * sizeof(float));
weight_max_guard_ =
TargetWrapperXPU::MallocScratchPad(max_ptr_size * 2 * sizeof(float));
XPU_CALL(xpu_memcpy(reinterpret_cast<float*>(weight_max_guard_->addr_),
weight_max_vector.data(),
8 * sizeof(float),
max_ptr_size * 2 * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE));
// quant
quant_weight_guard_ =
Expand Down Expand Up @@ -103,14 +107,14 @@ void GRUUnitCompute::Run() {
const float* bias_ptr = (bias == nullptr) ? nullptr : bias->data<float>();

float* hidden_ptr = hidden->mutable_data<float>(TARGET(kXPU));

int ret = xdnn::gru_unit<float, int16_t, float, int16_t>(
int ret = xdnn::gru_core<float, int16_t, float, int16_t>(
ctx.GetRawContext(),
input_ptr,
hidden_prev_ptr,
weight_ptr,
hidden_ptr,
batch_size,
1,
frame_size,
nullptr,
nullptr,
Expand All @@ -119,7 +123,9 @@ void GRUUnitCompute::Run() {
bias_ptr,
xdnn::Activation_t::TANH,
xdnn::Activation_t::SIGMOID,
origin_mode);
origin_mode,
false,
false);
CHECK_EQ(ret, 0) << "call xdnn::gru_unit failed!";
}

Expand Down
45 changes: 28 additions & 17 deletions lite/kernels/xpu/layer_norm_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace lite {
namespace kernels {
namespace xpu {

void LayerNormCompute::Run() {
template <typename InType, PrecisionType PType>
void LayerNormCompute<InType, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

Expand All @@ -30,16 +31,17 @@ void LayerNormCompute::Run() {
auto matrix_dim = x_dims.Flatten2D(axis);
float epsilon = param.epsilon;

int r = xdnn::layer_norm(ctx.GetRawContext(), /* context */
param.X->data<float>(), /* in */
param.Y->mutable_data<float>(TARGET(kXPU)), /* out */
matrix_dim[0], /* m */
matrix_dim[1], /* n */
epsilon, /* epsilon */
param.Scale->data<float>(), /* scale */
param.Bias->data<float>(), /* bias */
nullptr,
nullptr);
int r = xdnn::layer_norm<InType>(
ctx.GetRawContext(), /* context */
param.X->template data<InType>(), /* in */
param.Y->template mutable_data<InType>(TARGET(kXPU)), /* out */
matrix_dim[0], /* m */
matrix_dim[1], /* n */
epsilon, /* epsilon */
param.Scale->template data<float>(), /* scale */
param.Bias->template data<float>(), /* bias */
nullptr,
nullptr);

CHECK_EQ(r, 0);
}
Expand All @@ -49,16 +51,25 @@ void LayerNormCompute::Run() {
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(layer_norm,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::LayerNormCompute,
def)
namespace xpu = paddle::lite::kernels::xpu;

using LayerNorm_FP32 = xpu::LayerNormCompute<float, PRECISION(kFloat)>;
using LayerNorm_FP16 = xpu::LayerNormCompute<float16, PRECISION(kFP16)>;
REGISTER_LITE_KERNEL(layer_norm, kXPU, kFloat, kNCHW, LayerNorm_FP32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(layer_norm, kXPU, kFP16, kNCHW, LayerNorm_FP16, fp16)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.BindOutput("Variance",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFP16))})
.Finalize();
3 changes: 2 additions & 1 deletion lite/kernels/xpu/layer_norm_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace lite {
namespace kernels {
namespace xpu {

class LayerNormCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
template <typename InType, PrecisionType PType>
class LayerNormCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::LayerNormParam;

Expand Down
77 changes: 77 additions & 0 deletions lite/kernels/xpu/lod_reset_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/xpu/lod_reset_compute.h"
#include <algorithm>
#include <vector>
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

void LodResetCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

auto x = param.X;
auto output = param.Out;
output->mutable_data(TARGET(kXPU), x->memory_size());
int r = xdnn::copy<int8_t>(ctx.GetRawContext(),
x->data<int8_t>(),
reinterpret_cast<int8_t*>(output->raw_data()),
x->memory_size());
CHECK_EQ(r, 0);
auto lod = output->mutable_lod();
if (param.Y) {
if (param.Y->lod().size()) {
*lod = param.Y->lod();
} else {
const auto* y_data = param.Y->data<int>();
std::vector<int> y_cpu(param.Y->numel());
TargetWrapperXPU::MemcpySync(y_cpu.data(),
y_data,
param.Y->numel() * sizeof(int),
IoDirection::DtoH);
(*lod).resize(1);
(*lod)[0].resize(param.Y->numel());
for (int i = 0; i < param.Y->numel(); i++) {
(*lod)[0][i] = y_cpu[i];
}
}
} else {
(*lod).resize(1);
for (auto id : param.target_lod) {
(*lod)[0].push_back(id);
}
}
}

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(lod_reset,
kXPU,
kAny,
kNCHW,
paddle::lite::kernels::xpu::LodResetCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kAny))})
.Finalize();
36 changes: 36 additions & 0 deletions lite/kernels/xpu/lod_reset_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

class LodResetCompute : public KernelLite<TARGET(kXPU), PRECISION(kAny)> {
public:
using param_t = operators::LodResetParam;

void Run() override;

virtual ~LodResetCompute() = default;
};

} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
Loading

0 comments on commit 4bfa2d0

Please sign in to comment.