Skip to content

Commit

Permalink
[cherry-pick] Optimize sparse kernel and fix some bug (#50118)
Browse files Browse the repository at this point in the history
cherry-pick some PR about optimize sparse kernel and fix some bug:
#47736 #47703 #47604 #46679 #48439 #49009 #49734
  • Loading branch information
zhangkaihuo authored Feb 2, 2023
1 parent e32ff65 commit 8c5e432
Show file tree
Hide file tree
Showing 12 changed files with 1,322 additions and 197 deletions.
43 changes: 43 additions & 0 deletions cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)

set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)

include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/")
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/")
include_directories(
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")

add_definitions("-DPADDLE_WITH_CUTLASS")

ExternalProject_Add(
extern_cutlass
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${CUTLASS_REPOSITORY}
GIT_TAG "${CUTLASS_TAG}"
PREFIX ${CUTLASS_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND "")

add_library(cutlass INTERFACE)

add_dependencies(cutlass extern_cutlass)
10 changes: 10 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,14 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt)
endif()

if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
endif()
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
4 changes: 4 additions & 0 deletions paddle/phi/kernels/funcs/norm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ limitations under the License. */

namespace phi {
namespace funcs {
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535

inline void ExtractNCWHD(const phi::DDim &dims,
const DataLayout &data_layout,
int *N,
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/kernels/funcs/sparse/utils.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
}
}

inline __device__ bool SetBits(const int value, int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
const int old = atomicOr(ptr + index, mask);
return (mask & old) != 0;
}

inline __device__ bool TestBits(const int value, const int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
return (mask & ptr[index]) != 0;
}

} // namespace sparse
} // namespace funcs
} // namespace phi
20 changes: 11 additions & 9 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx,
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) {
if (x_dims.size() == 2) {
}
// CUDNN only support small batch size
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
dim3 block;
dim3 grid;
const int block_size = 512;
Expand Down
92 changes: 74 additions & 18 deletions paddle/phi/kernels/gpu/batch_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
}
}

template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
const double epsilon,
const int C,
BatchNormParamType<T> *inv_variance) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < C) {
inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
}
}

template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
const T *x,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const int C,
const int N,
const int HxW,
const double epsilon,
T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
}
}

template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x,
Expand Down Expand Up @@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,

auto handle = ctx.cudnn_handle();

const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;

// Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not reference AND not using pre-trained
// model.
Expand Down Expand Up @@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
// epsilon));
#else
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
if (use_native_kernel) {
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
Expand All @@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
if (x_dims.size() == 2) {
DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
const int threads = 512 > C ? C : 512;
const int blocks = (C + 511) / 512;
InverseVariance<T><<<blocks, threads>>>(
est_var->template data<BatchNormParamType<T>>(),
epsilon,
C,
inv_var_ptr);
BN1DForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
// est_var->template data<BatchNormParamType<T>>(),
inv_var_ptr,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
Expand Down Expand Up @@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_kernel) {
dim3 block;
dim3 grid;
Expand Down
Loading

0 comments on commit 8c5e432

Please sign in to comment.