cherry-pick 42645 (#43205)
Remove the per-rank instantiation in the broadcast function and the Elementwise call, to reduce compilation time.
Adapted from PR #42645 on the develop branch. Because develop has diverged substantially from the release branch, a direct cherry-pick is not possible, so this PR is resubmitted against release/2.3.
Instantiating the broadcast code for every rank triggers extensive template expansion in the underlying kernels, which bloats reduce_sum_grad_kernel.cu.o; after this change, both the .o size and the compilation time drop.
AnnaTrainingG authored Jun 6, 2022
1 parent 40a7e0a commit 835a188
Showing 4 changed files with 138 additions and 145 deletions.
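
For context on the template-expansion claim in the commit message, here is a simplified sketch (an editor's illustration, not the actual Paddle code) of the pattern this commit removes: a compile-time rank parameter forces one kernel instantiation per supported rank, whereas a runtime rank field needs only one.

```cpp
// Before (sketch): rank as a template parameter. The dispatch switch forces
// one Kernel instantiation per supported rank, for every InT/OutT/Functor
// combination; this is what bloats objects like reduce_sum_grad_kernel.cu.o.
template <int Rank>
void Kernel() { /* index math specialized for a fixed rank */ }

void DispatchCompileTime(int rank) {
  switch (rank) {
    case 1: Kernel<1>(); break;
    case 2: Kernel<2>(); break;
    // ... one case, and one instantiation, per rank up to 8
  }
}

// After (sketch): rank carried in a runtime config, so a single
// instantiation serves tensors of any supported rank.
struct Config { int rank; /* strides, divmoders, ... */ };

void KernelRuntime(const Config& cfg) {
  for (int i = 0; i < cfg.rank; ++i) { /* index math loops to cfg.rank */ }
}
```

The deleted BroadcastKernelForDifferentDimSize switch in broadcast_function.h below has exactly this shape.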
11 changes: 5 additions & 6 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
-phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
-configlists,
+phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(

// load in0
if (use_broadcast[0]) {
-kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
-kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);

-phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
+phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

use_broadcast[0] = false;
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
-configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
+configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);

auto func = AddFunctor<T>();
auto stream = ctx.stream();
139 changes: 40 additions & 99 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -185,19 +185,19 @@ struct DimensionsTransform {
}
};

-template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+template <typename T, int VecSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
-const kps::details::BroadcastConfig<Rank> &config,
+const kps::details::BroadcastConfig &config,
int numel,
int num,
int need_broadcast) {
// numel: total number of output elements
// num: how many elements are handled in this pass
if (need_broadcast) {
-kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
dst, src, block_offset, config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
@@ -210,14 +210,13 @@ template <typename InT,
int Arity,
int NumOuts,
int VecSize,
-int Rank,
bool IsBoundary = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
uint32_t numel,
-const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
+const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
Functor func) {
@@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
-LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
-ins[i],
-block_offset,
-configs[i],
-numel,
-num,
-use_broadcast[i]);
+LoadData<InT, VecSize, IsBoundary>(args[i],
+ins[i],
+block_offset,
+configs[i],
+numel,
+num,
+use_broadcast[i]);
}
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
@@ -254,14 +253,13 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
-int VecSize,
-int Rank>
+int VecSize>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
phi::Array<int, Arity> use_broadcast,
uint32_t numel,
-phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
+phi::Array<kps::details::BroadcastConfig, Arity> configs,
int main_offset,
int tail_tid,
Functor func) {
@@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-Rank,
false>(ins,
outs,
use_broadcast,
@@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-Rank,
true>(
ins, outs, use_broadcast, numel, configs, num, block_offset, func);
}
@@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-Rank,
false>(ins,
outs,
use_broadcast,
@@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-Rank,
true>(
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
@@ -334,15 +328,14 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
-int VecSize,
-int Rank>
+int VecSize>
void LaunchBroadcastKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func,
DimensionsTransform merge_dims) {
int numel = (*outs)[0]->numel();
-phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
+phi::Array<kps::details::BroadcastConfig, Arity> configs;
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
@@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
// get the broadcast config;
// if the data shape is [m, n], then set data_dim = {n, m}
// e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}
-configs[i] = kps::details::BroadcastConfig<Rank>(
+configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
@@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
-VecSize,
-Rank><<<blocks, threads, stream>>>(ins_data,
-outs_data,
-use_broadcast,
-numel,
-configs,
-main_offset,
-tail_tid,
-func);
+VecSize><<<blocks, threads, stream>>>(ins_data,
+outs_data,
+use_broadcast,
+numel,
+configs,
+main_offset,
+tail_tid,
+func);
#else
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
-VecSize,
-Rank><<<blocks, threads, 0, stream>>>(ins_data,
-outs_data,
-use_broadcast,
-numel,
-configs,
-main_offset,
-tail_tid,
-func);
+VecSize><<<blocks, threads, 0, stream>>>(
+ins_data,
+outs_data,
+use_broadcast,
+numel,
+configs,
+main_offset,
+tail_tid,
+func);
#endif
}

-template <typename InT,
-typename OutT,
-typename Functor,
-int Arity,
-int NumOuts,
-int VecSize>
-void BroadcastKernelForDifferentDimSize(
-const KPDevice &ctx,
-const std::vector<const DenseTensor *> &ins,
-std::vector<DenseTensor *> *outs,
-int axis,
-Functor func) {
-const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
-
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
-case rank: { \
-LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
-ctx, ins, outs, func, merge_dims); \
-} break;
-
-switch (merge_dims.dim_size) {
-CALL_BROADCAST_FOR_DIM_SIZE(1);
-CALL_BROADCAST_FOR_DIM_SIZE(2);
-CALL_BROADCAST_FOR_DIM_SIZE(3);
-CALL_BROADCAST_FOR_DIM_SIZE(4);
-CALL_BROADCAST_FOR_DIM_SIZE(5);
-CALL_BROADCAST_FOR_DIM_SIZE(6);
-CALL_BROADCAST_FOR_DIM_SIZE(7);
-CALL_BROADCAST_FOR_DIM_SIZE(8);
-default: {
-PADDLE_THROW(phi::errors::InvalidArgument(
-"The maximum dimension of input tensor is expected to be less than "
-"%d, but recieved %d.",
-merge_dims.dim_size,
-phi::DDim::kMaxRank));
-}
-}
-#undef CALL_BROADCAST_FOR_DIM_SIZE
-}

template <ElementwiseType ET,
typename InT,
typename OutT,
Expand Down Expand Up @@ -506,33 +458,22 @@ void BroadcastKernelForDifferentVecSize(
: in_vec_size;
}
int vec_size = std::min(out_vec_size, in_vec_size);
+const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

switch (vec_size) {
case 4: {
-BroadcastKernelForDifferentDimSize<InT,
-OutT,
-Functor,
-kArity,
-NumOuts,
-4>(ctx, ins, outs, axis, func);
+LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
+ctx, ins, outs, func, merge_dims);
break;
}
case 2: {
-BroadcastKernelForDifferentDimSize<InT,
-OutT,
-Functor,
-kArity,
-NumOuts,
-2>(ctx, ins, outs, axis, func);
+LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
+ctx, ins, outs, func, merge_dims);
break;
}
case 1: {
-BroadcastKernelForDifferentDimSize<InT,
-OutT,
-Functor,
-kArity,
-NumOuts,
-1>(ctx, ins, outs, axis, func);
+LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
+ctx, ins, outs, func, merge_dims);
break;
}
default: {
[Diffs for the remaining two changed files are not shown.]
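
As a closing note on why the Rank template parameter could be dropped: the broadcast index computation can run a loop bounded by the runtime dim_size stored in BroadcastConfig rather than by a compile-time rank. A minimal sketch of that idea (an editor's illustration, not Paddle's actual implementation):

```cpp
#include <cstdint>

constexpr int kMaxRank = 8;  // fixed upper bound, cf. phi::DDim::kMaxRank

// Hypothetical rank-agnostic config: stride arrays are padded to kMaxRank
// and assumed to be in descending (row-major) order; dim_size records the
// rank actually in use.
struct BroadcastConfigSketch {
  uint32_t out_strides[kMaxRank];
  uint32_t in_strides[kMaxRank];  // stride 0 on broadcast dimensions
  int dim_size;                   // runtime rank
};

// Map a flat output index to the matching (possibly broadcast) input index.
// The loop bound is runtime data, so one compiled kernel covers every rank.
inline uint32_t BroadcastIndex(uint32_t out_idx,
                               const BroadcastConfigSketch& cfg) {
  uint32_t in_idx = 0;
  for (int i = 0; i < cfg.dim_size; ++i) {
    uint32_t coord = out_idx / cfg.out_strides[i];  // coordinate along dim i
    out_idx -= coord * cfg.out_strides[i];
    in_idx += coord * cfg.in_strides[i];
  }
  return in_idx;
}
```

The trade is a few extra runtime divides in exchange for far fewer template instantiations; per the commit message, the net effect is a smaller reduce_sum_grad_kernel.cu.o and shorter compile times.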
