[xpu]: add matmul int8_t #9764

Merged · 1 commit · Dec 1, 2022
29 changes: 25 additions & 4 deletions lite/kernels/xpu/__xpu__fc_compute.cc
@@ -237,6 +237,15 @@ using XPUFC_Int8_FP32_FP32 =
using XPUFC_FP32_LOCAL_QUANT =
xpu::XPUFcCompute<float, float, float, float, PRECISION(kFloat)>;

using XPUFC_Int8_Int8_FP32 =
Contributor comment: FC now registers kernels with int8 and fp32 inputs/outputs, but the kernel-pick logic that runs before kernel selection needs to be updated by Lin Wei first; otherwise these newly registered kernels still cannot be used.

xpu::XPUFcCompute<int8_t, int8_t, int8_t, float, PRECISION(kInt8)>;

using XPUFC_Int8_Int8_Int8 =
xpu::XPUFcCompute<int8_t, int8_t, int8_t, int8_t, PRECISION(kInt8)>;

using XPUFC_Int8_Int8_FP32_Int8 =
xpu::XPUFcCompute<int8_t, int8_t, float, int8_t, PRECISION(kInt8)>;

REGISTER_LITE_KERNEL(
__xpu__fc, kXPU, kFloat, kNCHW, XPUFC_FP32, XPU_Real_kFloat)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -302,8 +311,6 @@ REGISTER_LITE_KERNEL(
.BindOutput("OutputMax", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

using XPUFC_Int8_Int8_Int8 =
xpu::XPUFcCompute<int8_t, int8_t, int8_t, int8_t, PRECISION(kInt8)>;
REGISTER_LITE_KERNEL(
__xpu__fc, kXPU, kInt8, kNCHW, XPUFC_Int8_Int8_Int8, XPU_Int8_Int8_Int8)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))})
@@ -315,8 +322,6 @@ REGISTER_LITE_KERNEL(
.BindOutput("OutputMax", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

using XPUFC_Int8_Int8_FP32 =
xpu::XPUFcCompute<int8_t, int8_t, int8_t, float, PRECISION(kInt8)>;
REGISTER_LITE_KERNEL(
__xpu__fc, kXPU, kInt8, kNCHW, XPUFC_Int8_Int8_FP32, XPU_Int8_Int8_FP32)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))})
@@ -341,3 +346,19 @@ REGISTER_LITE_KERNEL(__xpu__fc,
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("OutputMax", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();

REGISTER_LITE_KERNEL(__xpu__fc,
kXPU,
kInt8,
kNCHW,
XPUFC_Int8_Int8_FP32_Int8,
XPU_Int8_Int8_FP32_Int8)
.BindInput("Input",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("InputMax", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Output",
{LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kInt8))})
.BindOutput("OutputMax", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
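
A note on the template arguments in these aliases: going by the "// tgemm w x y" comment in matmul_compute.cc below, the order appears to be <TGEMM, TW, DX, DY, PType>, i.e. GEMM compute type, weight type, input type, output type, and registered precision. A hedged reading of the new alias follows; the parameter names are assumptions, not taken from the XPUFcCompute definition:

// Illustrative annotation only; parameter names are assumed.
using XPUFC_Int8_Int8_FP32_Int8 =
    xpu::XPUFcCompute<int8_t,            // TGEMM: int8 GEMM pipeline
                      int8_t,            // TW: int8 weights (Filter)
                      float,             // DX: fp32 input, matching the PRECISION(kFloat) Input binding
                      int8_t,            // DY: int8 output, matching the PRECISION(kInt8) Output binding
                      PRECISION(kInt8)>; // registered kernel precision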
130 changes: 81 additions & 49 deletions lite/kernels/xpu/matmul_compute.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "lite/kernels/xpu/matmul_compute.h"
#include <vector>
#include "lite/backends/xpu/math.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
@@ -24,17 +25,38 @@ namespace xpu {

namespace math = paddle::lite::xpu::math;

void MatMulCompute::Run() {
template <typename TGEMM,
typename TW,
typename DX,
typename DY,
PrecisionType PType>
void MatMulCompute<TGEMM, TW, DX, DY, PType>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<XPUContext>();

auto* x = param.X;
auto* y = param.Y;
auto* out = param.Out;

if (param.enable_int8) {
LOG(FATAL) << "xpu don't support matmul int8 outside encoder";
// max
int max_ptr_size = ctx.GetRawContext()->max_ptr_size();
XPUScratchPadGuard input_max_guard_ =
TargetWrapperXPU::MallocScratchPad(max_ptr_size * sizeof(float));
if (param.enable_int8) { // for paddle slim int8 quant
std::vector<float> cpu_input_max(max_ptr_size, 127 * param.input_scale);
lite::TargetWrapperXPU::MemcpySync(input_max_guard_->addr_,
cpu_input_max.data(),
sizeof(float) * max_ptr_size,
IoDirection::HtoD);
}

const float* x_maxptr = nullptr;
const float* w_maxptr = nullptr;
if (param.enable_int8 && x == y) {
x_maxptr = reinterpret_cast<float*>(input_max_guard_->addr_);
w_maxptr = reinterpret_cast<float*>(input_max_guard_->addr_);
}

auto& x_dims = x->dims();
auto& y_dims = y->dims();
auto mat_dim_a = math::CreateMatrixDescriptor(
@@ -69,46 +91,46 @@ void MatMulCompute::Run() {

int r = 0;
if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) {
r = xdnn::fc_fusion<float, float, float, int16_t>(
ctx.GetRawContext(), // ctx
x->data<float>(), // x
y->data<float>(), // w
out->mutable_data<float>(TARGET(kXPU)), // y
mat_dim_a.height_, // m
mat_dim_b.width_, // n
mat_dim_a.width_, // k
mat_dim_a.trans_, // x_trans
mat_dim_b.trans_, // w_trans
nullptr, // x_maxptr
nullptr, // w_maxptr
nullptr, // y_maxptr
lda, // ldx
ldb, // ldw
ldc, // ldy
param.alpha, // alpha
0.0f, // beta
nullptr, // bias
xdnn::Activation_t::LINEAR); // act
r = xdnn::fc_fusion<DX, TW, DY, TGEMM>(
ctx.GetRawContext(), // ctx
x->template data<DX>(), // x
y->template data<TW>(), // w
out->template mutable_data<DY>(TARGET(kXPU)), // y
mat_dim_a.height_, // m
mat_dim_b.width_, // n
mat_dim_a.width_, // k
mat_dim_a.trans_, // x_trans
mat_dim_b.trans_, // w_trans
x_maxptr, // x_maxptr
w_maxptr, // w_maxptr
nullptr, // y_maxptr
lda, // ldx
ldb, // ldw
ldc, // ldy
param.alpha, // alpha
0.0f, // beta
nullptr, // bias
xdnn::Activation_t::LINEAR); // act
} else {
// batch matmul
r = xdnn::fc_batched<float, float, float, int16_t>(
ctx.GetRawContext(), /* context */
mat_dim_a.batch_size_, /* batch_size */
mat_dim_a.trans_, /* TransA */
mat_dim_b.trans_, /* TransB */
mat_dim_a.height_, /* M */
mat_dim_b.width_, /* N */
mat_dim_a.width_, /* K */
param.alpha, /* alpha */
x->data<float>(), /* A */
mat_dim_a.stride_, /* stride_a */
y->data<float>(), /* B */
mat_dim_b.stride_, /* stride_b */
0.0f, /* beta */
out->mutable_data<float>(TARGET(kXPU)), /* C */
mat_dim_a.height_ * mat_dim_b.width_, /* stride_c */
nullptr, /* x_maxptr */
nullptr /* w_maxptr */);
r = xdnn::fc_batched<DX, TW, DY, TGEMM>(
ctx.GetRawContext(), /* context */
mat_dim_a.batch_size_, /* batch_size */
mat_dim_a.trans_, /* TransA */
mat_dim_b.trans_, /* TransB */
mat_dim_a.height_, /* M */
mat_dim_b.width_, /* N */
mat_dim_a.width_, /* K */
param.alpha, /* alpha */
x->template data<DX>(), /* A */
mat_dim_a.stride_, /* stride_a */
y->template data<TW>(), /* B */
mat_dim_b.stride_, /* stride_b */
0.0f, /* beta */
out->template mutable_data<DY>(TARGET(kXPU)), /* C */
mat_dim_a.height_ * mat_dim_b.width_, /* stride_c */
x_maxptr, /* x_maxptr */
w_maxptr); /* w_maxptr */
}
CHECK_EQ(r, 0);
}
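
For context on the quantization path above: PaddleSlim's symmetric int8 scheme stores real_value = int8_value * input_scale with |int8_value| <= 127, so the largest representable magnitude of the tensor is 127 * input_scale. The kernel replicates that value max_ptr_size times, presumably so every XPU compute unit reads the same per-tensor max, and it only wires the max pointers up when X and Y are the same tensor, in which case one scale covers both operands. A minimal sketch of the derivation (the helper name is hypothetical):

// Hypothetical helper mirroring the cpu_input_max setup in Run() above.
// Assumes symmetric int8 quantization: real = q * scale, |q| <= 127.
std::vector<float> MakeInputMax(float input_scale, int max_ptr_size) {
  return std::vector<float>(max_ptr_size, 127.0f * input_scale);
}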
@@ -118,18 +140,28 @@ void MatMulCompute::Run() {
} // namespace lite
} // namespace paddle

REGISTER_LITE_KERNEL(
matmul, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::MatMulCompute, def)
namespace xpu = paddle::lite::kernels::xpu;

// tgemm w x y
using XPUMATMUL_FP32 =
xpu::MatMulCompute<int16_t, float, float, float, PRECISION(kFloat)>;

using XPUMATMUL_Int8_FP32_FP32 =
xpu::MatMulCompute<int8_t, float, float, float, PRECISION(kInt8)>;

REGISTER_LITE_KERNEL(matmul, kXPU, kFloat, kNCHW, XPUMATMUL_FP32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
REGISTER_LITE_KERNEL(matmul_v2,
kXPU,
kFloat,
kNCHW,
paddle::lite::kernels::xpu::MatMulCompute,
def)

REGISTER_LITE_KERNEL(matmul, kXPU, kInt8, kNCHW, XPUMATMUL_Int8_FP32_FP32, int8)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU), PRECISION(kFloat))})
.Finalize();

REGISTER_LITE_KERNEL(matmul_v2, kXPU, kFloat, kNCHW, XPUMATMUL_FP32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
8 changes: 6 additions & 2 deletions lite/kernels/xpu/matmul_compute.h
Expand Up @@ -20,8 +20,12 @@ namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {

class MatMulCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
template <typename TGEMM,
typename TW,
typename DX,
typename DY,
PrecisionType PType>
class MatMulCompute : public KernelLite<TARGET(kXPU), PType> {
public:
using param_t = operators::MatMulParam;

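
Since the header now only declares the template, the concrete type combinations live at the registration sites. A sketch of the instantiations, mirroring the aliases in matmul_compute.cc (template order assumed to be TGEMM, TW, DX, DY, PType):

// fp32 path: fp32 tensors computed through an int16 GEMM.
using XPUMATMUL_FP32 =
    paddle::lite::kernels::xpu::MatMulCompute<int16_t, float, float, float,
                                              PRECISION(kFloat)>;
// quantized path: fp32 tensors, int8 GEMM driven by PaddleSlim scales.
using XPUMATMUL_Int8_FP32_FP32 =
    paddle::lite::kernels::xpu::MatMulCompute<int8_t, float, float, float,
                                              PRECISION(kInt8)>;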