opt conv
zkh2016 committed Jun 20, 2022
1 parent 97493af commit 64be38b
Showing 5 changed files with 172 additions and 74 deletions.
34 changes: 22 additions & 12 deletions paddle/phi/kernels/funcs/sparse/scatter.cu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/phi/kernels/funcs/aligned_vector.h"

#define VecBytes 16

namespace phi {
namespace funcs {
@@ -28,33 +31,40 @@ namespace sparse {
* channels: the output channel size
* out: the outputs
**/
-template <typename T>
+template <typename T, int VecSize>
__global__ void ScatterKernel(const T* input,
const int* unique_value,
const int* out_index,
const int non_zero_num,
const int rulebook_len,
const int channels,
-T* out,
-const bool subm = false) {
+T* out) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
-for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
-int indices_i = i / channels;
-int channels_i = i - indices_i * channels;
+const int vec_channels = channels / VecSize;
+using LoadT = phi::AlignedVector<T, VecSize>;
+using StoreT = phi::AlignedVector<T, VecSize>;
+for (int i = tid; i < non_zero_num * vec_channels;
+i += gridDim.x * blockDim.x) {
+int indices_i = i / vec_channels;
+int channels_i = i - indices_i * vec_channels;

int start = unique_value[indices_i];
int end = indices_i == non_zero_num - 1 ? rulebook_len
: unique_value[indices_i + 1];
// max(end-start) = kernel_size
-T sum = static_cast<T>(0);
-if (subm) {
-sum = out[indices_i * channels + channels_i];
-}
+StoreT sums = {static_cast<T>(0)};
for (int j = start; j < end; j++) {
const int out_feature_i = out_index[j];
-sum += input[out_feature_i * channels + channels_i];
+LoadT vec_in;
+phi::Load<T, VecSize>(
+input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+for (int k = 0; k < VecSize; k++) {
+sums[k] += vec_in[k];
+}
}
-out[indices_i * channels + channels_i] = sum;
+phi::Store<T, VecSize>(sums,
+out + indices_i * channels + channels_i * VecSize);
}
}

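The rewritten ScatterKernel uses phi::AlignedVector together with phi::Load and phi::Store so that each thread accumulates VecSize contiguous channels through a single 16-byte (VecBytes) memory transaction instead of scalar loads and stores. Below is a minimal standalone sketch of the same access pattern with a hand-rolled aligned vector type; the Vec struct, the RowSumVectorized kernel, and the launch shape are illustrative assumptions for this note, not code from the commit.

#include <cuda_runtime.h>

// Stand-in for phi::AlignedVector: VecSize elements packed and aligned so the
// compiler can emit one 16-byte load/store when sizeof(T) * VecSize == 16.
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) Vec {
  T val[VecSize];
};

// Sum `rows` rows of `in` (each `channels` wide) into one output row,
// VecSize channels per loop iteration. Assumes channels % VecSize == 0 and
// 16-byte aligned pointers, mirroring the divisibility check at the call sites.
template <typename T, int VecSize>
__global__ void RowSumVectorized(const T* __restrict__ in,
                                 T* __restrict__ out,
                                 int rows,
                                 int channels) {
  const int vec_channels = channels / VecSize;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < vec_channels;
       i += gridDim.x * blockDim.x) {
    Vec<T, VecSize> sums;
#pragma unroll
    for (int k = 0; k < VecSize; k++) sums.val[k] = static_cast<T>(0);
    for (int r = 0; r < rows; r++) {
      Vec<T, VecSize> v =
          reinterpret_cast<const Vec<T, VecSize>*>(in + r * channels)[i];
#pragma unroll
      for (int k = 0; k < VecSize; k++) sums.val[k] += v.val[k];
    }
    reinterpret_cast<Vec<T, VecSize>*>(out)[i] = sums;
  }
}

int main() {
  const int rows = 8, channels = 64;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, rows * channels * sizeof(float));
  cudaMalloc(&out, channels * sizeof(float));
  cudaMemset(in, 0, rows * channels * sizeof(float));
  constexpr int kVecSize = 16 / sizeof(float);  // VecBytes / sizeof(T) == 4
  RowSumVectorized<float, kVecSize><<<1, 128>>>(in, out, rows, channels);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}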
39 changes: 29 additions & 10 deletions paddle/phi/kernels/sparse/gpu/coalesced_kernel.cu
@@ -132,16 +132,35 @@ void CoalescedGPUKernel(const GPUContext& dev_ctx,
}

// 5. scatter the values
-config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
-phi::funcs::sparse::ScatterKernel<T>
-<<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
-x_values_ptr,
-public_indexs.data<int>(),
-values_indexs_ptr,
-out_nnz,
-nnz,
-stride,
-out_values.data<T>());
+const int VecSize = VecBytes / sizeof(T);
+if (stride % VecSize == 0) {
+config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, nnz * stride / VecSize, 1);
+phi::funcs::sparse::ScatterKernel<T, VecSize>
+<<<config.block_per_grid,
+config.thread_per_block,
+0,
+dev_ctx.stream()>>>(x_values_ptr,
+public_indexs.data<int>(),
+values_indexs_ptr,
+out_nnz,
+nnz,
+stride,
+out_values.data<T>());
+} else {
+config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
+phi::funcs::sparse::ScatterKernel<T, 1>
+<<<config.block_per_grid,
+config.thread_per_block,
+0,
+dev_ctx.stream()>>>(x_values_ptr,
+public_indexs.data<int>(),
+values_indexs_ptr,
+out_nnz,
+nnz,
+stride,
+out_values.data<T>());
+}

// 6. convert index to coordinate
Dim<DDim::kMaxRank> const_dims;
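Each call site in this commit follows the same dispatch rule seen above: compute VecSize = VecBytes / sizeof(T) (4 for float, 2 for double), launch the vectorized kernel only when the channel stride is a multiple of VecSize, and otherwise fall back to the original element-per-thread launch with VecSize = 1. A hypothetical helper capturing that rule, written here purely for illustration:

// Assumed illustrative helper, not part of the commit.
constexpr int kVecBytes = 16;  // mirrors the VecBytes macro in scatter.cu.h

template <typename T>
int ChooseVecSize(int channels) {
  const int vec_size = kVecBytes / static_cast<int>(sizeof(T));
  return (channels % vec_size == 0) ? vec_size : 1;  // scalar fallback
}

// ChooseVecSize<float>(64) == 4, ChooseVecSize<float>(3) == 1,
// ChooseVecSize<double>(6) == 2.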
18 changes: 12 additions & 6 deletions paddle/phi/kernels/sparse/gpu/convolution.cu.h
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
@@ -46,18 +47,23 @@ using Dims4D = phi::funcs::sparse::Dims4D;
* index_size: the size of indices
* slice_size: slice size corresponding to each index, here is the channel size
**/
-template <typename T, typename IndexT = int>
+template <typename T, typename IndexT = int, int VecSize>
__global__ void GatherKernel(const T* params,
const IndexT* indices,
T* output,
size_t index_size,
size_t slice_size) {
-CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
-int64_t indices_i = i / slice_size;
-int64_t slice_i = i - indices_i * slice_size; // offset inside the slice
+CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size / VecSize, int64_t) {
+const int vec_slice_size = slice_size / VecSize;
+int indices_i = i / vec_slice_size;
+int slice_i = i - indices_i * vec_slice_size; // offset inside the slice
IndexT gather_i = indices[indices_i];
-int64_t params_i = gather_i * slice_size + slice_i;
-*(output + i) = *(params + params_i);
+int64_t params_i = gather_i * slice_size + slice_i * VecSize;
+using LoadT = phi::AlignedVector<T, VecSize>;
+using StoreT = phi::AlignedVector<T, VecSize>;
+LoadT params_vec;
+phi::Load<T, VecSize>(params + params_i, &params_vec);
+phi::Store<T, VecSize>(params_vec, output + i * VecSize);
}
}

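The vectorized GatherKernel path carries the same implicit precondition as the scatter path: every slice must start at a VecBytes-aligned address, which holds when the base pointer is at least 16-byte aligned (device allocations from cudaMalloc are 256-byte aligned) and slice_size is a multiple of VecSize. A small sketch that spells out that assumption; the helper name is hypothetical:

#include <cstdint>

template <typename T, int VecSize>
bool CanUseVectorizedPath(const T* base, int slice_size) {
  const bool base_aligned =
      reinterpret_cast<std::uintptr_t>(base) % (sizeof(T) * VecSize) == 0;
  return base_aligned && (slice_size % VecSize == 0);
}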
75 changes: 53 additions & 22 deletions paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/convolution_grad_kernel.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"

@@ -137,28 +138,58 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
}
}

-auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-dev_ctx, rulebook_len * in_channels, 1);
-GatherKernel<T, IntT><<<config.block_per_grid.x,
-config.thread_per_block.x,
-0,
-dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
-rulebook_ptr + rulebook_len,
-in_features_ptr,
-rulebook_len,
-in_channels);
+const int VecSize = VecBytes / sizeof(T);
+if (in_channels % VecSize == 0) {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, rulebook_len * in_channels / VecSize, 1);
+GatherKernel<T, IntT, VecSize>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
+rulebook_ptr + rulebook_len,
+in_features_ptr,
+rulebook_len,
+in_channels);
+} else {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, rulebook_len * in_channels, 1);
+GatherKernel<T, IntT, 1>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
+rulebook_ptr + rulebook_len,
+in_features_ptr,
+rulebook_len,
+in_channels);
+}

-config = phi::backends::gpu::GetGpuLaunchConfig1D(
-dev_ctx, rulebook_len * out_channels, 1);
-GatherKernel<T, IntT>
-<<<config.block_per_grid.x,
-config.thread_per_block.x,
-0,
-dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(),
-rulebook_ptr + rulebook_len * 2,
-out_grad_features_ptr,
-rulebook_len,
-out_channels);
+if (out_channels % VecSize == 0) {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, rulebook_len * out_channels / VecSize, 1);
+GatherKernel<T, IntT, VecSize>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(),
+rulebook_ptr + rulebook_len * 2,
+out_grad_features_ptr,
+rulebook_len,
+out_channels);
+} else {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, rulebook_len * out_channels, 1);
+GatherKernel<T, IntT, 1>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(),
+rulebook_ptr + rulebook_len * 2,
+out_grad_features_ptr,
+rulebook_len,
+out_channels);
+}

const T* kernel_ptr = kernel.data<T>();
for (int i = 0; i < kernel_size; i++) {
Expand Down Expand Up @@ -203,7 +234,7 @@ void Conv3dGradGPUKernel(const GPUContext& dev_ctx,
}

// 4. scatter
-config = phi::backends::gpu::GetGpuLaunchConfig1D(
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
dev_ctx, rulebook_len * in_channels, 1);

phi::funcs::ScatterCUDAKernel<<<config.block_per_grid,
80 changes: 56 additions & 24 deletions paddle/phi/kernels/sparse/gpu/convolution_kernel.cu
@@ -110,16 +110,32 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx,
phi::funcs::SetConstant<GPUContext, T> set_zero;
set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));

-auto config =
-phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
-GatherKernel<T, IntT><<<config.block_per_grid.x,
-config.thread_per_block.x,
-0,
-dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
-rulebook_ptr + n,
-in_features_ptr,
-n,
-in_channels);
+const int VecSize = VecBytes / sizeof(T);
+if (in_channels % VecSize == 0) {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, n * in_channels / VecSize, 1);
+GatherKernel<T, IntT, VecSize>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
+rulebook_ptr + n,
+in_features_ptr,
+n,
+in_channels);
+} else {
+auto config =
+phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
+GatherKernel<T, IntT, 1>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
+rulebook_ptr + n,
+in_features_ptr,
+n,
+in_channels);
+}

// 3. call gemm for every weight
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
@@ -155,7 +171,7 @@ void Conv3dGPUKernel(const GPUContext& dev_ctx,
// 4. scatter
if (subm) {
set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
-config =
+auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1);
phi::funcs::ScatterCUDAKernel<T, IntT>
<<<config.block_per_grid,
@@ -168,19 +184,35 @@ out_channels,
out_channels,
false);
} else {
-config = phi::backends::gpu::GetGpuLaunchConfig1D(
-dev_ctx, out->nnz() * out_channels, 1);
-phi::funcs::sparse::ScatterKernel<T>
-<<<config.block_per_grid.x,
-config.thread_per_block.x,
-0,
-dev_ctx.stream()>>>(out_features_ptr,
-unique_value.data<int>(),
-out_index.data<int>(),
-out->nnz(),
-n,
-out_channels,
-out_values_ptr);
+if (out_channels % VecSize == 0) {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, out->nnz() * out_channels / VecSize, 1);
+phi::funcs::sparse::ScatterKernel<T, VecSize>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(out_features_ptr,
+unique_value.data<int>(),
+out_index.data<int>(),
+out->nnz(),
+n,
+out_channels,
+out_values_ptr);
+} else {
+auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+dev_ctx, out->nnz() * out_channels, 1);
+phi::funcs::sparse::ScatterKernel<T, 1>
+<<<config.block_per_grid.x,
+config.thread_per_block.x,
+0,
+dev_ctx.stream()>>>(out_features_ptr,
+unique_value.data<int>(),
+out_index.data<int>(),
+out->nnz(),
+n,
+out_channels,
+out_values_ptr);
+}
}
}
/**
