Optimization for StackGradCUDAKernel for last dimension stack case. #48992
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
… optimization_for_stack_grad
…format with pre-commit
```cuda
if (is_valid) {
  if (out_datas[col_idx] != nullptr) {
    out_datas[col_idx][row_idx] = s_buf[s_idx];
  }
```
s_buf[s_idx]; doesn't seem to do anything here. Why not assign data to out_data directly? And if it is assigned directly, the sync should be removed as well.
The intent here was memory coalescing: read all the data in contiguously, store it transposed in shared memory, and then read it back out contiguously. But the implementation came out wrong, so the contiguous reads were never achieved.
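For reference, a minimal sketch of the direct-assignment variant the reviewer describes: no shared-memory staging and no __syncthreads(). The kernel name and the 2D row/column thread mapping are assumptions for illustration; only in_data, out_datas, cols and rows come from the snippet above.

```cuda
template <typename T, typename IndexT>
__global__ void StackGradLastDimDirectSketch(const T* __restrict__ in_data,
                                             const int cols,
                                             const IndexT rows,
                                             T** out_datas) {
  // One thread per (row, col) element of the row-major gradient input.
  IndexT row_idx = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
  int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
  if (row_idx < rows && col_idx < cols) {
    if (out_datas[col_idx] != nullptr) {
      // Strided read from in_data (stride == cols), coalesced write per output.
      out_datas[col_idx][row_idx] = in_data[row_idx * cols + col_idx];
    }
  }
}
```

This trades the shared-memory staging for simplicity; whether it wins depends on how well the cache absorbs the strided reads, which is exactly the trade-off discussed later in this thread.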
```cuda
                                          const IndexT tile_x_num,
                                          T** out_datas) {
  constexpr int buffer_size = 512;
  __shared__ T s_buf[buffer_size];
```
Is 512 the maximum number of threads per block here? If the block size is variable, consider using a dynamic shared memory size.
The configured block size is variable, with a maximum of 512, so the value used inside the __global__ kernel is that upper bound.
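As a side note, a hedged sketch of the dynamic-shared-memory alternative the reviewer mentions: the buffer is declared extern __shared__ and its byte size is supplied as the third launch parameter, so it always matches the block size instead of a compile-time 512-element cap. The kernel below is a generic illustration, not the PR's kernel.

```cuda
template <typename T>
__global__ void CopyThroughDynamicSmem(const T* __restrict__ in, T* out, int n) {
  extern __shared__ unsigned char smem_raw[];
  // Sized at launch time: blockDim.x * blockDim.y elements of T.
  T* s_buf = reinterpret_cast<T*>(smem_raw);

  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int gid = blockIdx.x * blockDim.x * blockDim.y + tid;
  if (gid < n) {
    s_buf[tid] = in[gid];
  }
  __syncthreads();
  if (gid < n) {
    out[gid] = s_buf[tid];
  }
}

// Launch example: the third <<<>>> argument is the dynamic shared-memory size in bytes.
// CopyThroughDynamicSmem<float>
//     <<<grid, block, block.x * block.y * sizeof(float)>>>(d_in, d_out, n);
```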
ok
```cuda
if (is_valid) {
  T data = in_data[row_idx * cols + col_idx];
  s_buf[s_idx] = data;
```
You could follow the transpose implementation here: read contiguously, and do the transpose when writing into shared memory.
This part of the optimization was done in passing; the read path is not coalesced yet and needs further performance tuning. However, it is not the core factor affecting the model at the moment, so further optimization will not continue at this stage.
As for the optimization approach, the transpose tile scheme is indeed feasible. But when this OP stacks along the last dim, the number of tensors is usually something like 3, 5, or 9, so a transpose tile scheme would leave quite a few idle threads and the gain after that optimization might not be very large. Follow-up optimization will therefore build on the current scheme instead.
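For context, a minimal sketch of the transpose-style tile the reviewer refers to, following the standard CUDA shared-memory transpose pattern: coalesced read, transpose inside a padded shared tile, coalesced write. The kernel name and TILE constant are illustrative and are not the PR's actual code.

```cuda
constexpr int TILE = 32;

template <typename T>
__global__ void TransposeTileSketch(const T* __restrict__ in,   // rows x cols
                                    T* __restrict__ out,        // cols x rows
                                    int rows,
                                    int cols) {
  __shared__ T tile[TILE][TILE + 1];            // +1 column avoids bank conflicts
  int x = blockIdx.x * TILE + threadIdx.x;      // column index into `in`
  int y = blockIdx.y * TILE + threadIdx.y;      // row index into `in`
  if (x < cols && y < rows) {
    tile[threadIdx.y][threadIdx.x] = in[y * cols + x];   // coalesced read
  }
  __syncthreads();
  x = blockIdx.y * TILE + threadIdx.x;          // column index into `out`
  y = blockIdx.x * TILE + threadIdx.y;          // row index into `out`
  if (x < rows && y < cols) {
    out[y * rows + x] = tile[threadIdx.x][threadIdx.y];  // coalesced write
  }
}
```

As noted in the reply above, with only 3/5/9 outputs along the stacked axis most of a 32-wide tile would sit idle, which is why the PR keeps the current scheme.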
OK.
```cuda
// Host side: FastDivMod is built from the number of split outputs.
auto col_divmod = phi::funcs::FastDivMod(out_num);

template <typename T, typename IndexT>
__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
                                          const int cols,
                                          const IndexT rows,
                                          const IndexT tile_x_num,
                                          const IndexT numel,
                                          const phi::funcs::FastDivMod divmoder,
                                          T** out_datas) {
  __shared__ T s_buf[1024];
  int share_stride = blockDim.x + 1;  // padded stride, the usual bank-conflict trick
  for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
    int tid_in_block = threadIdx.y * blockDim.x + threadIdx.x;
    auto result = divmoder.Divmod(tid_in_block);
    IndexT tid = tile_x * blockDim.x * blockDim.y + tid_in_block;
    if (tid < numel) {
      // Coalesced global read; the element is stored transposed in shared memory.
      int share_idx = result[1] * share_stride + result[0];
      s_buf[share_idx] = in_data[tid];
    }
    IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
    int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
    __syncthreads();
    if (col_idx < cols && row_idx < rows) {
      // Coalesced write into the split output for this column.
      int share_idx = threadIdx.y * share_stride + threadIdx.x;
      if (out_datas[col_idx] != nullptr) {
        out_datas[col_idx][row_idx] = s_buf[share_idx];
      }
    }
  }
}
```
After making this change there was no obvious performance gain; performance actually dropped somewhat relative to the merged kernel, even though this kernel does achieve coalesced access for both reads and writes.
In the currently merged kernel the shared-memory writes are thread-contiguous. The explanations I can come up with so far are:
- In this case, because the shared-memory writes inside the Dev-branch kernel are thread-contiguous, a broadcast-like effect kicks in, whereas in the code above each shared-memory write is separated by a stride of blockDim.x + 1. So purely from the standpoint of writing to shared memory, the Dev version performs better;
- In the Dev version the global reads are not contiguous, but the stride between threads is usually only 3 (tid_0 reads data[0], tid_1 reads data[3], and so on), so when the reads go through the cache the read performance is not too bad and does not become a bottleneck.
The combination of these two factors produces the current test results. Of course, I still need to dig into this with ncu later to fully understand it. Thanks @zkh2016 for the pointers.
LGTM. Let's merge the current version first and keep optimizing in a follow-up PR.
```cuda
@@ -46,6 +46,9 @@ namespace phi {
namespace backends {
namespace gpu {

// Limitation of the setting in one dimension of cuda grid.
constexpr int kMultiDimslimit = 65536;
```
What exactly is this a limit on? It's hard to tell from the variable name. Please improve it in a later change; something like kMaxGridSize is suggested.
OK. The intent was to express the upper bound on the value that can be set for each dimension when configuring a multi-dimensional grid.
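A hypothetical illustration of how such a per-dimension cap is typically applied when building a launch config: each grid dimension is clamped to its hardware maximum, and the kernel grid-strides over whatever does not fit. The helper name is an assumption; the 2147483647 and 65535 values are CUDA's documented limits for gridDim.x and gridDim.y/z, not the PR's constant.

```cuda
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

constexpr int64_t kMaxGridX = 2147483647;  // CUDA limit for gridDim.x
constexpr int64_t kMaxGridYZ = 65535;      // CUDA limit for gridDim.y / gridDim.z

inline dim3 MakeGrid2D(int64_t rows, int64_t cols, dim3 block) {
  int64_t grid_x = (rows + block.x - 1) / block.x;
  int64_t grid_y = (cols + block.y - 1) / block.y;
  // Clamp each dimension; a grid-stride loop in the kernel covers the remainder.
  grid_x = std::min<int64_t>(grid_x, kMaxGridX);
  grid_y = std::min<int64_t>(grid_y, kMaxGridYZ);
  return dim3(static_cast<unsigned>(grid_x), static_cast<unsigned>(grid_y), 1u);
}
```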
PR types
Performance optimization
PR changes
OPs
Describe
Optimize StackGradCUDAKernel for the last-dimension stack case using a 2D grid-block config.
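A hedged sketch of what the 2D grid-block configuration described above might look like on the host side: block.x and grid.x walk the rows of the flattened gradient, block.y covers the (small) stacked dimension, and the total threads per block stay within the 512-thread cap mentioned earlier in the review. The function name and sizing heuristic are assumptions, not the PR's actual launch code.

```cuda
#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// rows: product of all leading dims; cols: number of tensors stacked on the last dim
// (assumed small, e.g. 3/5/9 as discussed in the review above).
inline void GetLastDimStackGradConfig(int64_t rows, int cols,
                                      dim3* block, dim3* grid) {
  // Grow block.x in powers of two while block.x * cols stays within 512 threads.
  int block_x = 8;
  while (block_x * 2 * cols <= 512) {
    block_x *= 2;
  }
  *block = dim3(block_x, cols, 1);

  int64_t grid_x = (rows + block_x - 1) / block_x;
  grid_x = std::min<int64_t>(grid_x, 65535);  // conservative cap; kernel grid-strides over the rest
  *grid = dim3(static_cast<unsigned>(grid_x), 1, 1);
}
```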