H2D data transfer optimization for stack kernel #48899

Merged

Changes from 3 commits
167 changes: 117 additions & 50 deletions paddle/phi/kernels/gpu/stack_kernel.cu
@@ -18,30 +18,93 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"

namespace phi {

template <typename T, typename IntType>
__global__ void StackCUDAKernel(T** input_ptrs,
IntType split_size,
IntType rows,
IntType cols,
template <typename IndexT>
struct DivmodWarpper {
public:
__host__ void SetDivden(IndexT dividen) {
Contributor:
Are you sure it should be "Divden"? Also, the __host__ qualifier isn't needed here. And how does this differ from the IdxHelper inside the transpose op?

Contributor Author:
"dividend" means the number being divided. DivmodWarpper is defined only to improve the performance of the division and modulo when IndexT is int64_t. The IdxHelper in the transpose op has a different purpose: it maps an index from src_dims to dst_dims.
__host__ was added just to keep the code symmetric and tidy; it will be removed in the next commit.

divmoder = phi::funcs::FastDivMod(dividen);
}
__device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
return divmoder.Divmod(val);
}

private:
phi::funcs::FastDivMod divmoder;
Contributor:
It still feels like DivMod itself should be specialized at definition: fast division/modulo for uint32_t, and ordinary division/modulo for the other integer types.

Contributor Author:
A specialization can be added; I will do that in a separate PR.

Contributor:
If DivMod supported all types, this extra wrapper layer would be unnecessary. Also, one of SetDivden and dividen must be misspelled, and FastDivMod itself calls this value divisor — they all mean the same thing, right?

};
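A small numeric illustration of the mapping described in the reply above (illustration only; the sizes are made up, and plain / and % stand in for FastDivMod, which yields the same {quotient, remainder} pair): the divisor passed to SetDivden is split_size (= x_col), the quotient selects the input tensor, and the remainder is the column offset inside it.

// Host-side sketch with hypothetical sizes; plain / and % play the roles of
// divmod_rslt[0] and divmod_rslt[1] in the kernel below.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t split_size = 6;  // columns per input tensor (x_col)
  const int64_t num_inputs = 4;  // number of stacked tensors (n)
  for (int64_t grid_x = 0; grid_x < split_size * num_inputs; grid_x += 5) {
    int64_t which_input = grid_x / split_size;  // which input the column comes from
    int64_t col_offset = grid_x % split_size;   // column inside that input
    std::printf("output col %2lld -> input %lld, col %lld\n",
                static_cast<long long>(grid_x),
                static_cast<long long>(which_input),
                static_cast<long long>(col_offset));
  }
  return 0;
}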

template <>
struct DivmodWarpper<int64_t> {
public:
using DivModT = phi::AlignedVector<int64_t, 2>;

__host__ void SetDivden(int64_t dividen) { dividen_ = dividen; }
__device__ inline DivModT div_mod(int64_t val) {
DivModT data;
data[0] = val / dividen_;
data[1] = val - data[0] * dividen_;
return data;
}

private:
int64_t dividen_;
};
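One possible shape of the type-specialized DivMod suggested in the review above (hypothetical and not part of this PR; it assumes the fast_divmod.h header already included in this file): a generic divmod that uses plain division, plus a uint32_t specialization that delegates to phi::funcs::FastDivMod, which would make a separate wrapper layer unnecessary.

// Sketch only: type-specialized divmod, so callers need no extra wrapper.
template <typename IndexT>
struct DivModSketch {
  explicit DivModSketch(IndexT divisor) : divisor_(divisor) {}
  __host__ __device__ void operator()(IndexT n, IndexT* q, IndexT* r) const {
    *q = n / divisor_;        // plain path, e.g. for int64_t indices
    *r = n - *q * divisor_;
  }
 private:
  IndexT divisor_;
};

template <>
struct DivModSketch<uint32_t> {
  explicit DivModSketch(uint32_t divisor) : fast_(divisor) {}
  __device__ void operator()(uint32_t n, uint32_t* q, uint32_t* r) const {
    auto qr = fast_.Divmod(n);  // multiply-and-shift fast path
    *q = qr[0];
    *r = qr[1];
  }
 private:
  phi::funcs::FastDivMod fast_;
};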

constexpr int kWarpperSize = 256;
Contributor:
This is a size limit on the struct type, so define it inside the type.

Contributor Author:
Will revise as suggested.
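A minimal sketch of that suggestion (hypothetical, not this PR's final code): the capacity constant becomes a static member of the wrapper type rather than a namespace-scope constant.

// Sketch only: capacity limit kept inside the type itself.
template <typename T, typename IndexT>
struct PackedPointerArraySketch : public DivmodWarpper<IndexT> {
  static constexpr int kCapacity = 256;  // struct-size limit, owned by the type
  const T* data[kCapacity];
};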

template <typename T, typename IndexT, bool IsDataWarpperd>
struct DataWarpper : public DivmodWarpper<IndexT> {
Contributor:
  • The inheritance relationship between DataWarpper and DivmodWarpper feels a bit forced.
  • DataWarpper -> PointerArray
  • 256 is too large; in practice there are rarely that many inputs, so it could be made smaller.
  • Presumably this type is also meant to be reused in operators such as concat?

Contributor Author (@JamesLim-sy, Dec 9, 2022):

Q: The inheritance relationship between DataWarpper and DivmodWarpper feels a bit forced.
A: There are two main optimizations here: first, the data is transferred by packing the input pointers into a struct that is passed as a kernel argument, which reduces GPU time; second, fast division/modulo replaces the original computation pattern.

Q: DataWarpper -> PointerArray
A: Will revise as suggested.

Q: 256 is too large; in practice there are rarely that many inputs, so it could be made smaller.
A: 256 was chosen to minimize the chance of falling back to phi::Copy; judging from the AF2 model's data, I plan to shrink it to 64.

Q: Presumably this type is also meant to be reused in operators such as concat?
A: The optimization scheme for concat is not settled yet; I will finish it in stack first, then adapt it and roll it out more widely.

const T* data[kWarpperSize];
};

template <typename T, typename IndexT>
struct DataWarpper<T, IndexT, false> : public DivmodWarpper<IndexT> {
T** data;
};

template <typename Context, typename T>
T** PackDataAndTransfer(const Context& dev_ctx,
Contributor:
This function could be defined as a member function of the struct.

Contributor Author:
Will revise as suggested.

const std::vector<const DenseTensor*>& x,
int num) {
std::vector<const T*> x_datas(num);
for (int i = 0; i < num; ++i) {
x_datas[i] = x[i]->data<T>();
}
auto byte_len = num * sizeof(T*);
auto tmp_x_data = paddle::memory::Alloc(
dev_ctx.GetPlace(),
byte_len,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_x_data->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(x_datas.data()),
byte_len,
dev_ctx.stream());
return reinterpret_cast<T**>(tmp_x_data->ptr());
}
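To make the packing idea from the review discussion concrete, a standalone CUDA sketch (hypothetical names and sizes, not the PR code): when the input count is small enough, the pointer array travels by value inside the kernel-argument struct, so no separate H2D copy of a pointer table is issued; a device-side pointer table, as built by PackDataAndTransfer above, remains the fallback for larger input counts.

// Standalone illustration: stack the columns of a few device arrays using a
// pointer array passed by value as a kernel argument (rows == 1 for brevity).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kMaxPacked = 8;  // small illustrative capacity

struct PackedPtrs {
  const float* data[kMaxPacked];  // travels with the kernel launch itself
};

__global__ void StackColumns(PackedPtrs ptrs, int num_inputs, int cols,
                             float* out) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= num_inputs * cols) return;
  int which = col / cols;           // which input tensor
  int offset = col - which * cols;  // column inside that tensor
  out[col] = ptrs.data[which][offset];
}

int main() {
  const int num_inputs = 3, cols = 4;
  float host_in[num_inputs][cols];
  for (int i = 0; i < num_inputs; ++i)
    for (int j = 0; j < cols; ++j) host_in[i][j] = 10.f * i + j;

  PackedPtrs packed{};
  float* dev_in[num_inputs];
  for (int i = 0; i < num_inputs; ++i) {
    cudaMalloc(&dev_in[i], cols * sizeof(float));
    cudaMemcpy(dev_in[i], host_in[i], cols * sizeof(float),
               cudaMemcpyHostToDevice);
    packed.data[i] = dev_in[i];  // device pointer stored in the by-value struct
  }
  float* dev_out;
  cudaMalloc(&dev_out, num_inputs * cols * sizeof(float));

  StackColumns<<<1, 64>>>(packed, num_inputs, cols, dev_out);

  float host_out[num_inputs * cols];
  cudaMemcpy(host_out, dev_out, sizeof(host_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < num_inputs * cols; ++i) std::printf("%g ", host_out[i]);
  std::printf("\n");
  return 0;
}

The trade-off is kernel-argument size: CUDA kernel parameters have a size limit (historically 4 KB), which is why the packed capacity is capped and the copy-based fallback is kept.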

template <typename T, typename IndexT, typename WarpT>
__global__ void StackCUDAKernel(WarpT input_warpper,
IndexT split_size,
IndexT rows,
IndexT cols,
T* __restrict__ output) {
IntType grid_x = static_cast<IntType>(blockIdx.x) * blockDim.x + threadIdx.x;
IntType grid_x_stride = static_cast<IntType>(blockDim.x) * gridDim.x;
IntType grid_y_stride = static_cast<IntType>(blockDim.y) * gridDim.y;
IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;

for (; grid_x < cols; grid_x += grid_x_stride) {
IntType grid_y =
static_cast<IntType>(blockIdx.y) * blockDim.y + threadIdx.y;
IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;

IntType split = grid_x / split_size;
const T* input_ptr = input_ptrs[split];
IntType col_offset = grid_x % split_size;
auto divmod_rslt = input_warpper.div_mod(grid_x);
const T* input_ptr = input_warpper.data[divmod_rslt[0]];
#pragma unroll
for (; grid_y < rows; grid_y += grid_y_stride) {
output[grid_y * cols + grid_x] =
input_ptr[grid_y * split_size + col_offset];
input_ptr[grid_y * split_size + divmod_rslt[1]];
}
}
}
@@ -52,24 +52,8 @@ void StackKernel(const Context& dev_ctx,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);

int n = static_cast<int>(x.size());
T* y_data = dev_ctx.template Alloc<T>(out);
std::vector<const T*> x_datas(n);
for (int i = 0; i < n; i++) {
x_datas[i] = x[i]->data<T>();
}

auto tmp_x_data = paddle::memory::Alloc(
dev_ctx.GetPlace(),
x_datas.size() * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
paddle::memory::Copy(dev_ctx.GetPlace(),
tmp_x_data->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(x_datas.data()),
x_datas.size() * sizeof(T*),
dev_ctx.stream());

// Split x dim from axis to matrix
int64_t x_row = 1, x_col = 1;
@@ -78,33 +78,53 @@
}
x_col = x[0]->numel() / x_row;
int64_t out_col = x_col * n;

auto config =
phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);

#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper) \
StackCUDAKernel<T, index_t, decltype(input_warpper)> \
<<<config.block_per_grid, \
config.thread_per_block, \
0, \
dev_ctx.stream()>>>(input_warpper, \
static_cast<index_t>(x_col), \
static_cast<index_t>(x_row), \
static_cast<index_t>(out_col), \
y_data);

if (out->numel() < std::numeric_limits<int32_t>::max()) {
StackCUDAKernel<T, int32_t>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
static_cast<int32_t>(x_col),
static_cast<int32_t>(x_row),
static_cast<int32_t>(out_col),
y_data);
if (n <= kWarpperSize) {
DataWarpper<T, int32_t, true> data_warpper;
for (auto i = 0; i < n; ++i) {
Contributor:
The initialization could be defined as a member function of the struct.

Contributor Author:
Will revise as suggested.

data_warpper.data[i] = x[i]->data<T>();
}
data_warpper.SetDivden(x_col);
IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
} else {
DataWarpper<T, int32_t, false> data_warpper;
T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
data_warpper.data = pack_ptr;
data_warpper.SetDivden(x_col);
IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
}
} else {
StackCUDAKernel<T, int64_t>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
x_col,
x_row,
out_col,
y_data);
if (n <= kWarpperSize) {
DataWarpper<T, int64_t, true> data_warpper;
for (auto i = 0; i < n; ++i) {
data_warpper.data[i] = x[i]->data<T>();
}
data_warpper.SetDivden(x_col);
IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
} else {
DataWarpper<T, int64_t, false> data_warpper;
T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
data_warpper.data = pack_ptr;
data_warpper.SetDivden(x_col);
IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
}
}
#undef IMPL_STACK_CUDA_KERNEL
}

} // namespace phi

PD_REGISTER_KERNEL(stack,