Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Optimize sparse kernel and fix some bug #50118

Merged
merged 7 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cmake/external/cutlass.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

set(CUTLASS_PREFIX_DIR ${THIRD_PARTY_PATH}/cutlass)

set(CUTLASS_REPOSITORY https://github.com/NVIDIA/cutlass.git)
set(CUTLASS_TAG v2.9.1)

include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/")
include_directories("${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/include/")
include_directories(
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")

add_definitions("-DPADDLE_WITH_CUTLASS")

ExternalProject_Add(
extern_cutlass
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${CUTLASS_REPOSITORY}
GIT_TAG "${CUTLASS_TAG}"
PREFIX ${CUTLASS_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND "")

add_library(cutlass INTERFACE)

add_dependencies(cutlass extern_cutlass)
10 changes: 10 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,14 @@ if(WITH_CUSPARSELT)
list(APPEND third_party_deps extern_cusparselt)
endif()

if(WITH_GPU
AND NOT WITH_ARM
AND NOT WIN32
AND NOT APPLE)
if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
include(external/cutlass) # download, build, install cusparselt
list(APPEND third_party_deps extern_cutlass)
endif()
endif()

add_custom_target(third_party ALL DEPENDS ${third_party_deps})
4 changes: 4 additions & 0 deletions paddle/phi/kernels/funcs/norm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ limitations under the License. */

namespace phi {
namespace funcs {
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535

inline void ExtractNCWHD(const phi::DDim &dims,
const DataLayout &data_layout,
int *N,
Expand Down
13 changes: 13 additions & 0 deletions paddle/phi/kernels/funcs/sparse/utils.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
}
}

inline __device__ bool SetBits(const int value, int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
const int old = atomicOr(ptr + index, mask);
return (mask & old) != 0;
}

inline __device__ bool TestBits(const int value, const int* ptr) {
const int index = value >> 5;
const int mask = 1 << (value & 31);
return (mask & ptr[index]) != 0;
}

} // namespace sparse
} // namespace funcs
} // namespace phi
20 changes: 11 additions & 9 deletions paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -852,15 +852,17 @@ void BatchNormGradRawKernel(const Context &ctx,
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_kernel) {
if (x_dims.size() == 2) {
}
// CUDNN only support small batch size
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
dim3 block;
dim3 grid;
const int block_size = 512;
Expand Down
92 changes: 74 additions & 18 deletions paddle/phi/kernels/gpu/batch_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
}
}

template <typename T>
static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
const double epsilon,
const int C,
BatchNormParamType<T> *inv_variance) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < C) {
inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
}
}

template <typename T, phi::DataLayout layout>
static __global__ void BN1DForwardInference(
const T *x,
const BatchNormParamType<T> *mean,
const BatchNormParamType<T> *inv_variance,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *bias,
const int C,
const int N,
const int HxW,
const double epsilon,
T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int num = N * C * HxW;
for (int i = gid; i < num; i += stride) {
const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> x_sub_mean =
static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
}
}

template <typename T, int BlockDim, phi::DataLayout layout>
static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
const T *x,
Expand Down Expand Up @@ -691,9 +725,6 @@ void BatchNormKernel(const Context &ctx,

auto handle = ctx.cudnn_handle();

const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;

// Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not reference AND not using pre-trained
// model.
Expand Down Expand Up @@ -797,8 +828,8 @@ void BatchNormKernel(const Context &ctx,
// epsilon));
#else
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
if (use_native_kernel) {
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
Expand All @@ -816,18 +847,43 @@ void BatchNormKernel(const Context &ctx,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
if (x_dims.size() == 2) {
DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
const int threads = 512 > C ? C : 512;
const int blocks = (C + 511) / 512;
InverseVariance<T><<<blocks, threads>>>(
est_var->template data<BatchNormParamType<T>>(),
epsilon,
C,
inv_var_ptr);
BN1DForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
// est_var->template data<BatchNormParamType<T>>(),
inv_var_ptr,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
Expand Down Expand Up @@ -949,7 +1005,7 @@ void BatchNormKernel(const Context &ctx,
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_kernel) {
dim3 block;
dim3 grid;
Expand Down
Loading