From 64be38b56a5c931cfeff710aa69bc6336472fe46 Mon Sep 17 00:00:00 2001
From: zkh2016
Date: Mon, 20 Jun 2022 10:56:50 +0000
Subject: [PATCH] opt conv

---
 paddle/phi/kernels/funcs/sparse/scatter.cu.h  | 34 +++++---
 .../kernels/sparse/gpu/coalesced_kernel.cu    | 39 ++++++---
 .../phi/kernels/sparse/gpu/convolution.cu.h   | 18 +++--
 .../sparse/gpu/convolution_grad_kernel.cu     | 75 ++++++++++++-----
 .../kernels/sparse/gpu/convolution_kernel.cu  | 80 +++++++++++++------
 5 files changed, 172 insertions(+), 74 deletions(-)

diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h
index b9568f1df716d..cd89c916db577 100644
--- a/paddle/phi/kernels/funcs/sparse/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+
+#define VecBytes 16
 
 namespace phi {
 namespace funcs {
@@ -28,33 +31,40 @@ namespace sparse {
  * channels: the output channel size
 * out: the outputs
 **/
-template <typename T>
+template <typename T, int VecSize>
 __global__ void ScatterKernel(const T* input,
                               const int* unique_value,
                               const int* out_index,
                               const int non_zero_num,
                               const int rulebook_len,
                               const int channels,
-                              T* out,
-                              const bool subm = false) {
+                              T* out) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
-    int indices_i = i / channels;
-    int channels_i = i - indices_i * channels;
+  const int vec_channels = channels / VecSize;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using StoreT = phi::AlignedVector<T, VecSize>;
+  for (int i = tid; i < non_zero_num * vec_channels;
+       i += gridDim.x * blockDim.x) {
+    int indices_i = i / vec_channels;
+    int channels_i = i - indices_i * vec_channels;
 
     int start = unique_value[indices_i];
     int end = indices_i == non_zero_num - 1 ? rulebook_len
                                             : unique_value[indices_i + 1];
     // max(end-start) = kernel_size
-    T sum = static_cast<T>(0);
-    if (subm) {
-      sum = out[indices_i * channels + channels_i];
-    }
+    StoreT sums = {static_cast<T>(0)};
     for (int j = start; j < end; j++) {
       const int out_feature_i = out_index[j];
-      sum += input[out_feature_i * channels + channels_i];
+      LoadT vec_in;
+      phi::Load<T, VecSize>(
+          input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+      for (int k = 0; k < VecSize; k++) {
+        sums[k] += vec_in[k];
+      }
     }
-    out[indices_i * channels + channels_i] = sum;
+    phi::Store<T, VecSize>(sums,
+                           out + indices_i * channels + channels_i * VecSize);
   }
 }
 
diff --git a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu
index 7d9e566916add..60d90a18d4633 100644
--- a/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu
@@ -132,16 +132,35 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
   }
 
   // 5. 
scatter the values - config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); - phi::funcs::sparse::ScatterKernel - <<>>( - x_values_ptr, - public_indexs.data(), - values_indexs_ptr, - out_nnz, - nnz, - stride, - out_values.data()); + const int VecSize = VecBytes / sizeof(T); + if (stride % VecSize == 0) { + config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, nnz * stride / VecSize, 1); + phi::funcs::sparse::ScatterKernel + <<>>(x_values_ptr, + public_indexs.data(), + values_indexs_ptr, + out_nnz, + nnz, + stride, + out_values.data()); + } else { + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); + phi::funcs::sparse::ScatterKernel + <<>>(x_values_ptr, + public_indexs.data(), + values_indexs_ptr, + out_nnz, + nnz, + stride, + out_values.data()); + } // 6. convert index to coordinate Dim const_dims; diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 24a7387d4fe19..a08c7931bb4f4 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/funcs/index_impl.cu.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/utils.cu.h" @@ -46,18 +47,23 @@ using Dims4D = phi::funcs::sparse::Dims4D; * index_size: the size of indices * slice_size: slice size corresponding to each index, here is the channel size **/ -template +template __global__ void GatherKernel(const T* params, const IndexT* indices, T* output, size_t index_size, size_t slice_size) { - CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { - int64_t indices_i = i / slice_size; - int64_t slice_i = i - indices_i * slice_size; // offset inside the slice + CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size / VecSize, int64_t) { + const int vec_slice_size = slice_size / VecSize; + int indices_i = i / vec_slice_size; + int slice_i = i - indices_i * vec_slice_size; // offset inside the slice IndexT gather_i = indices[indices_i]; - int64_t params_i = gather_i * slice_size + slice_i; - *(output + i) = *(params + params_i); + int64_t params_i = gather_i * slice_size + slice_i * VecSize; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + LoadT params_vec; + phi::Load(params + params_i, ¶ms_vec); + phi::Store(params_vec, output + i * VecSize); } } diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index d83d064418eec..d91c93fde66fb 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -23,6 +23,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" +#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" @@ -137,28 +138,58 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx, } } - auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - in_features_ptr, - rulebook_len, - in_channels); + const int VecSize = VecBytes / sizeof(T); + if (in_channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * in_channels / VecSize, 1); + GatherKernel + <<>>(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + in_features_ptr, + rulebook_len, + in_channels); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * in_channels, 1); + GatherKernel + <<>>(x.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + in_features_ptr, + rulebook_len, + in_channels); + } - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * out_channels, 1); - GatherKernel - <<>>(out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - out_grad_features_ptr, - rulebook_len, - out_channels); + if (out_channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * out_channels / VecSize, 1); + GatherKernel + <<>>(out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + out_grad_features_ptr, + rulebook_len, + out_channels); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, rulebook_len * out_channels, 1); + GatherKernel + <<>>(out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len * 2, + out_grad_features_ptr, + rulebook_len, + out_channels); + } const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { @@ -203,7 +234,7 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx, } // 4. scatter - config = phi::backends::gpu::GetGpuLaunchConfig1D( + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( dev_ctx, rulebook_len * in_channels, 1); phi::funcs::ScatterCUDAKernel<< set_zero; set_zero(dev_ctx, &out_features, static_cast(0.0f)); - auto config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + n, - in_features_ptr, - n, - in_channels); + const int VecSize = VecBytes / sizeof(T); + if (in_channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, n * in_channels / VecSize, 1); + GatherKernel + <<>>(x.non_zero_elements().data(), + rulebook_ptr + n, + in_features_ptr, + n, + in_channels); + } else { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); + GatherKernel + <<>>(x.non_zero_elements().data(), + rulebook_ptr + n, + in_features_ptr, + n, + in_channels); + } // 3. call gemm for every werght auto blas = phi::funcs::GetBlas(dev_ctx); @@ -155,7 +171,7 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx, // 4. 
scatter if (subm) { set_zero(dev_ctx, out_values, static_cast(0.0f)); - config = + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1); phi::funcs::ScatterCUDAKernel <<nnz() * out_channels, 1); - phi::funcs::sparse::ScatterKernel - <<>>(out_features_ptr, - unique_value.data(), - out_index.data(), - out->nnz(), - n, - out_channels, - out_values_ptr); + if (out_channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, out->nnz() * out_channels / VecSize, 1); + phi::funcs::sparse::ScatterKernel + <<>>(out_features_ptr, + unique_value.data(), + out_index.data(), + out->nnz(), + n, + out_channels, + out_values_ptr); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, out->nnz() * out_channels, 1); + phi::funcs::sparse::ScatterKernel + <<>>(out_features_ptr, + unique_value.data(), + out_index.data(), + out->nnz(), + n, + out_channels, + out_values_ptr); + } } } /**
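
Note on the vectorized memory path: the patch rewrites the sparse-conv gather and scatter kernels so that each thread moves a whole phi::AlignedVector (VecBytes = 16 bytes, i.e. 4 floats or 2 doubles) per load/store instead of a single scalar element. Below is a minimal, self-contained sketch of the gather side of that pattern; GatherRowsVec and its parameter names are illustrative stand-ins rather than the identifiers used in the patch, and the sketch assumes the caller only instantiates VecSize > 1 when channels % VecSize == 0.

#include "paddle/phi/kernels/funcs/aligned_vector.h"

// Gather rows of a [num_rows, channels] matrix selected by `indices` into
// `output`, reading and writing VecSize channels per thread iteration.
template <typename T, typename IndexT, int VecSize>
__global__ void GatherRowsVec(const T* params,
                              const IndexT* indices,
                              T* output,
                              size_t index_size,
                              size_t channels) {
  using VecT = phi::AlignedVector<T, VecSize>;
  const size_t vec_channels = channels / VecSize;
  const size_t total = index_size * vec_channels;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < total;
       i += gridDim.x * blockDim.x) {
    const size_t row = i / vec_channels;        // which gathered row
    const size_t col = i - row * vec_channels;  // which VecSize-wide chunk
    VecT chunk;
    // One vectorized (up to 128-bit) load and store instead of VecSize scalar ones.
    phi::Load<T, VecSize>(params + indices[row] * channels + col * VecSize,
                          &chunk);
    phi::Store<T, VecSize>(chunk, output + i * VecSize);
  }
}

For float features (VecSize = 4) this collapses four 32-bit transactions into one 128-bit transaction per thread, which is where most of the bandwidth win comes from.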
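
The scatter side applies the same vectorization while reducing a contiguous segment of gathered rows into each output row, mirroring ScatterKernel above. A sketch of that segment-sum pattern, again with illustrative names (seg_start plays the role of unique_value, row_index the role of out_index, num_gathered the role of rulebook_len):

#include "paddle/phi/kernels/funcs/aligned_vector.h"

template <typename T, int VecSize>
__global__ void SegmentSumScatterVec(const T* input,        // [num_gathered, channels]
                                     const int* seg_start,  // [num_out] segment offsets
                                     const int* row_index,  // segment entry -> input row
                                     int num_out,
                                     int num_gathered,
                                     int channels,
                                     T* out) {               // [num_out, channels]
  using VecT = phi::AlignedVector<T, VecSize>;
  const int vec_channels = channels / VecSize;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_out * vec_channels;
       i += gridDim.x * blockDim.x) {
    const int r = i / vec_channels;
    const int c = i - r * vec_channels;
    const int start = seg_start[r];
    const int end = (r == num_out - 1) ? num_gathered : seg_start[r + 1];
    VecT sums = {static_cast<T>(0)};  // VecSize running sums, one per channel lane
    for (int j = start; j < end; ++j) {
      VecT chunk;
      phi::Load<T, VecSize>(input + row_index[j] * channels + c * VecSize, &chunk);
#pragma unroll
      for (int k = 0; k < VecSize; ++k) {
        sums[k] += chunk[k];
      }
    }
    phi::Store<T, VecSize>(sums, out + r * channels + c * VecSize);
  }
}

Because each output element is owned by exactly one thread, the accumulation needs no atomics; the inner loop length is bounded by the kernel volume (max(end - start) = kernel_size, as the comment in scatter.cu.h notes).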
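
Launch sites pick the vector width at runtime: VecBytes / sizeof(T) lanes when the channel count divides evenly, otherwise a scalar VecSize = 1 instantiation so odd channel counts still work. The following is a sketch of that dispatch, reusing the hypothetical GatherRowsVec from the first sketch and assuming the usual phi::backends::gpu::GpuLaunchConfig fields (block_per_grid, thread_per_block); it is not the exact launch code of the patch.

#include "paddle/phi/backends/gpu/gpu_launch_config.h"

template <typename T, typename IndexT>
void GatherWithFallback(const phi::GPUContext& dev_ctx,
                        const T* params,
                        const IndexT* indices,
                        T* output,
                        size_t index_size,
                        size_t channels) {
  constexpr int kVecSize = VecBytes / sizeof(T);  // 16-byte transactions
  if (channels % kVecSize == 0) {
    // Vectorized path: each thread covers kVecSize channels, so the launch
    // only needs index_size * channels / kVecSize work items.
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, index_size * channels / kVecSize, 1);
    GatherRowsVec<T, IndexT, kVecSize>
        <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
            params, indices, output, index_size, channels);
  } else {
    // Scalar fallback: identical kernel body instantiated with VecSize = 1.
    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
        dev_ctx, index_size * channels, 1);
    GatherRowsVec<T, IndexT, 1>
        <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
            params, indices, output, index_size, channels);
  }
}

The same if/else shape appears in the coalesce, forward, and backward kernels above; only the element count passed to GetGpuLaunchConfig1D and the VecSize template argument differ between the two branches.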