H2D data transfer optimization for stack kernel #48899
Changes from 3 commits
@@ -18,30 +18,93 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/fast_divmod.h"

 namespace phi {

-template <typename T, typename IntType>
-__global__ void StackCUDAKernel(T** input_ptrs,
-                                IntType split_size,
-                                IntType rows,
-                                IntType cols,
+template <typename IndexT>
+struct DivmodWarpper {
+ public:
+  __host__ void SetDivden(IndexT dividen) {
+    divmoder = phi::funcs::FastDivMod(dividen);
+  }
+  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
+    return divmoder.Divmod(val);
+  }
+
+ private:
+  phi::funcs::FastDivMod divmoder;

Review comment: I still feel it should … [truncated]
Reply: A specialization can be added; I will open a separate PR to finish that work.

+};
+
+template <>
+struct DivmodWarpper<int64_t> {
+ public:
+  using DivModT = phi::AlignedVector<int64_t, 2>;
+
+  __host__ void SetDivden(int64_t dividen) { dividen_ = dividen; }
+  __device__ inline DivModT div_mod(int64_t val) {
+    DivModT data;
+    data[0] = val / dividen_;
+    data[1] = val - data[0] * dividen_;
+    return data;
+  }
+
+ private:
+  int64_t dividen_;
+};
+
+constexpr int kWarpperSize = 256;

Review comment: Is this for … [truncated]
Reply: Modified as suggested.

+template <typename T, typename IndexT, bool IsDataWarpperd>
+struct DataWarpper : public DivmodWarpper<IndexT> {

Review comment:
Q: The inheritance relationship between DataWarpper and DivmodWarpper feels a bit forced.
Q: DataWarpper -> PointerArray
Q: 256 is also too large a value; in practice there are rarely this many inputs, so it could be set smaller.
Q: This type is presumably also intended to be reused in operators such as concat?

+  const T* data[kWarpperSize];
+};
+
+template <typename T, typename IndexT>
+struct DataWarpper<T, IndexT, false> : public DivmodWarpper<IndexT> {
+  T** data;
+};

+template <typename Context, typename T>
+T** PackDataAndTransfer(const Context& dev_ctx,

Review comment: This function could be defined as a member function of the struct.
Reply: Modified as suggested.

+                        const std::vector<const DenseTensor*>& x,
+                        int num) {
+  std::vector<const T*> x_datas(num);
+  for (int i = 0; i < num; ++i) {
+    x_datas[i] = x[i]->data<T>();
+  }
+  auto byte_len = num * sizeof(T*);
+  auto tmp_x_data = paddle::memory::Alloc(
+      dev_ctx.GetPlace(),
+      byte_len,
+      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       tmp_x_data->ptr(),
+                       phi::CPUPlace(),
+                       reinterpret_cast<void*>(x_datas.data()),
+                       byte_len,
+                       dev_ctx.stream());
+  return reinterpret_cast<T**>(tmp_x_data->ptr());
+}

+template <typename T, typename IndexT, typename WarpT>
+__global__ void StackCUDAKernel(WarpT input_warpper,
+                                IndexT split_size,
+                                IndexT rows,
+                                IndexT cols,
                                 T* __restrict__ output) {
-  IntType grid_x = static_cast<IntType>(blockIdx.x) * blockDim.x + threadIdx.x;
-  IntType grid_x_stride = static_cast<IntType>(blockDim.x) * gridDim.x;
-  IntType grid_y_stride = static_cast<IntType>(blockDim.y) * gridDim.y;
+  IndexT grid_x = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IndexT grid_x_stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
+  IndexT grid_y_stride = static_cast<IndexT>(blockDim.y) * gridDim.y;

   for (; grid_x < cols; grid_x += grid_x_stride) {
-    IntType grid_y =
-        static_cast<IntType>(blockIdx.y) * blockDim.y + threadIdx.y;
+    IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;

-    IntType split = grid_x / split_size;
-    const T* input_ptr = input_ptrs[split];
-    IntType col_offset = grid_x % split_size;
+    auto divmod_rslt = input_warpper.div_mod(grid_x);
+    const T* input_ptr = input_warpper.data[divmod_rslt[0]];
 #pragma unroll
     for (; grid_y < rows; grid_y += grid_y_stride) {
       output[grid_y * cols + grid_x] =
-          input_ptr[grid_y * split_size + col_offset];
+          input_ptr[grid_y * split_size + divmod_rslt[1]];
     }
   }
 }
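
A hand-worked check of the index arithmetic above (numbers chosen for illustration, not taken from the PR): with n = 3 inputs and split_size = 4, cols = 12; for output column grid_x = 9, div_mod(9) yields quotient 2 and remainder 1, so the thread reads from input 2 at column 1, i.e. input_ptr[grid_y * 4 + 1], and writes it to output[grid_y * 12 + 9].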

@@ -52,24 +115,8 @@ void StackKernel(const Context& dev_ctx,
                  int axis,
                  DenseTensor* out) {
   if (axis < 0) axis += (x[0]->dims().size() + 1);

   int n = static_cast<int>(x.size());
   T* y_data = dev_ctx.template Alloc<T>(out);
-  std::vector<const T*> x_datas(n);
-  for (int i = 0; i < n; i++) {
-    x_datas[i] = x[i]->data<T>();
-  }

-  auto tmp_x_data = paddle::memory::Alloc(
-      dev_ctx.GetPlace(),
-      x_datas.size() * sizeof(T*),
-      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
-  paddle::memory::Copy(dev_ctx.GetPlace(),
-                       tmp_x_data->ptr(),
-                       phi::CPUPlace(),
-                       reinterpret_cast<void*>(x_datas.data()),
-                       x_datas.size() * sizeof(T*),
-                       dev_ctx.stream());

   // Split x dim from axis to matrix
   int64_t x_row = 1, x_col = 1;
@@ -78,33 +125,53 @@ void StackKernel(const Context& dev_ctx,
   }
   x_col = x[0]->numel() / x_row;
   int64_t out_col = x_col * n;

   auto config =
       phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);

+#define IMPL_STACK_CUDA_KERNEL(index_t, input_warpper)      \
+  StackCUDAKernel<T, index_t, decltype(input_warpper)>      \
+      <<<config.block_per_grid,                              \
+         config.thread_per_block,                            \
+         0,                                                  \
+         dev_ctx.stream()>>>(input_warpper,                  \
+                             static_cast<index_t>(x_col),    \
+                             static_cast<index_t>(x_row),    \
+                             static_cast<index_t>(out_col),  \
+                             y_data);

   if (out->numel() < std::numeric_limits<int32_t>::max()) {
-    StackCUDAKernel<T, int32_t>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
-                               static_cast<int32_t>(x_col),
-                               static_cast<int32_t>(x_row),
-                               static_cast<int32_t>(out_col),
-                               y_data);
+    if (n <= kWarpperSize) {
+      DataWarpper<T, int32_t, true> data_warpper;
+      for (auto i = 0; i < n; ++i) {

Review comment: The initialization could be defined as a member function of the struct.
Reply: Modified as suggested.

+        data_warpper.data[i] = x[i]->data<T>();
+      }
+      data_warpper.SetDivden(x_col);
+      IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
+    } else {
+      DataWarpper<T, int32_t, false> data_warpper;
+      T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
+      data_warpper.data = pack_ptr;
+      data_warpper.SetDivden(x_col);
+      IMPL_STACK_CUDA_KERNEL(int32_t, data_warpper);
+    }
   } else {
-    StackCUDAKernel<T, int64_t>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(reinterpret_cast<T**>(tmp_x_data->ptr()),
-                               x_col,
-                               x_row,
-                               out_col,
-                               y_data);
+    if (n <= kWarpperSize) {
+      DataWarpper<T, int64_t, true> data_warpper;
+      for (auto i = 0; i < n; ++i) {
+        data_warpper.data[i] = x[i]->data<T>();
+      }
+      data_warpper.SetDivden(x_col);
+      IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
+    } else {
+      DataWarpper<T, int64_t, false> data_warpper;
+      T** pack_ptr = PackDataAndTransfer<Context, T>(dev_ctx, x, n);
+      data_warpper.data = pack_ptr;
+      data_warpper.SetDivden(x_col);
+      IMPL_STACK_CUDA_KERNEL(int64_t, data_warpper);
+    }
   }
+#undef IMPL_STACK_CUDA_KERNEL
 }

 }  // namespace phi

 PD_REGISTER_KERNEL(stack,

Review comment: Are you sure Divden is the word you want? Also, __host__ does not need to be added here. What is the difference between this and the IdxHelper in transpose?
Reply: "dividend" means the number being divided. The sole purpose of defining DivmodWarpper is to improve the performance of the division and modulo computation when IndexT is not int64_t, whereas the IdxHelper in the transpose op exists to map data from src_dims to dst_dims. __host__ was added only to keep the code tidy and symmetric; it will be removed in the next commit.