From b6cedc3e2dbc38faeb93ad78d59a0317f9fcbaca Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Mon, 20 Nov 2023 08:50:59 +0000 Subject: [PATCH 01/14] fix behavior of put_along_axis and take_along_axis --- .../kernels/cpu/put_along_axis_grad_kernel.cc | 4 +- .../kernels/funcs/gather_scatter_functor.cc | 100 +++++++- .../kernels/funcs/gather_scatter_functor.cu | 217 +++++++++++++++--- .../kernels/funcs/gather_scatter_functor.h | 14 ++ .../kernels/gpu/put_along_axis_grad_kernel.cu | 10 +- python/paddle/tensor/manipulation.py | 79 ++++--- test/legacy_test/test_put_along_axis_op.py | 146 ++++++++++-- test/legacy_test/test_take_along_axis_op.py | 41 +++- 8 files changed, 501 insertions(+), 110 deletions(-) diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc index d44af05357a9a4..f0a1118ca92d7d 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc @@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - phi::funcs::cpu_gather_kernel( + phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } else if (index_type == DataType::INT64) { - phi::funcs::cpu_gather_kernel( + phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc index 597b8f231760bf..e4f7864cb34e1f 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -80,8 +80,10 @@ struct cpu_gather_scatter_functor { } int64_t select_dim_size = index_dims[dim]; // index matrix has different shape with self matrix or src matrix. - int replaced_select_dim_size = - is_scatter_like ? self_dims[dim] : src_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int src_select_dim_size = src_dims[dim]; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_src = 1; int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; for (int i = 0; i < dim; ++i) { @@ -90,10 +92,10 @@ struct cpu_gather_scatter_functor { for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_src *= src_dims[i]; } int64_t index_idx = 0; - int64_t self_idx = 0, src_idx = 0; - // N layer loop squeezed into 3 layers loop for (int64_t i = 0; i < inner_dim_size; i++) { for (int64_t j = 0; j < select_dim_size; j++) { @@ -117,13 +119,21 @@ struct cpu_gather_scatter_functor { // This index might out of bound of index matrix's index, so here // multiply the replaced_select_dim_size. - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * replaced_select_dim_size; + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = index_idx; - self_idx = is_scatter_like ? replace_index : index_idx; - src_idx = is_scatter_like ? index_idx : replace_index; - reduce_op((tensor_t*)(self_data + self_idx), // NOLINT - (tensor_t*)(src_data + src_idx)); // NOLINT + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + reduce_op((tensor_t*)(self_data + replace_index_self), // NOLINT + (tensor_t*)(src_data + replace_index_src)); // NOLINT index_idx++; } } @@ -193,6 +203,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; int64_t select_dim_size = index_dims[dim]; int64_t output_select_dim_size = output_dims[dim]; for (int i = 0; i < dim; ++i) { @@ -201,6 +212,7 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_data *= output_dims[i]; } int64_t index_idx = 0; @@ -208,8 +220,9 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, for (int64_t j = 0; j < select_dim_size; j++) { for (int64_t k = 0; k < outer_dim_size; k++) { int64_t index = index_data[index_idx]; - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * output_select_dim_size; + int64_t replace_index = + k + index * outer_dim_size_data + + i * outer_dim_size_data * output_select_dim_size; output_data[replace_index] = 0; index_idx++; } @@ -217,11 +230,74 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, } } +template +void cpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor output, + const phi::DeviceContext& ctx UNUSED) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* output_data = output.data(); + + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto output_dims = output.dims(); + + int64_t self_size = self.numel(); + bool* is_self_grad_used = new bool[self_size]; + + for (int i = 0; i < self_size; i++) { + is_self_grad_used[i] = false; + } + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_output = 1; + int64_t select_dim_size = index_dims[dim]; + int64_t self_select_dim_size = self_dims[dim]; + int64_t output_select_dim_size = output_dims[dim]; + for (int i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_output *= output_dims[i]; + } + + int64_t index_idx = index.numel() - 1; + for (int64_t i = inner_dim_size - 1; i >= 0; i--) { + for (int64_t j = select_dim_size - 1; j >= 0; j--) { + for (int64_t k = outer_dim_size - 1; k >= 0; k--) { + int64_t index = index_data[index_idx]; + int64_t replace_index_self = + k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + int64_t replace_index_output = + k + j * outer_dim_size_output + + i * outer_dim_size_output * output_select_dim_size; + if (!is_self_grad_used[replace_index_self]) { + output_data[replace_index_output] = self_data[replace_index_self]; + is_self_grad_used[replace_index_self] = true; + } else { + output_data[replace_index_output] = 0; + } + index_idx--; + } + } + } + delete[] is_self_grad_used; +} + Instantiate_Template_Function(cpu_gather_kernel) Instantiate_Template_Function(cpu_scatter_assign_kernel) Instantiate_Template_Function(cpu_scatter_add_kernel) Instantiate_Template_Function(cpu_scatter_mul_kernel) Instantiate_Template_Function(cpu_scatter_input_grad_kernel) + Instantiate_Template_Function(cpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index b53de3beef9aa4..b75a1888d42cec 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -62,14 +62,25 @@ __global__ void GatherScatterGPUKernel(tensor_t* self_data, int dim, const index_t* index_data, tensor_t* src_data, - int64_t inner_dim_size, int select_dim_size, - int replaced_select_dim_size, + int self_select_dim_size, + int src_select_dim_size, int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, int64_t numel, + int64_t numel_data, const func_t& reduce_op) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop // squeezed from the N layers loop. /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */ @@ -93,12 +104,27 @@ __global__ void GatherScatterGPUKernel(tensor_t* self_data, */ // index matrix has different shape with self matrix or src matrix. - int64_t replace_index = k + index * outer_dim_size + - i * outer_dim_size * replaced_select_dim_size; - int64_t self_idx = is_scatter_like ? replace_index : tid; - int64_t src_idx = is_scatter_like ? tid : replace_index; - reduce_op(static_cast(self_data + self_idx), - static_cast(src_data + src_idx)); + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = tid; + + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index_self]) { + reduce_op(static_cast(self_data + replace_index_self), + static_cast(src_data + replace_index_src)); + } } template (ctx).stream(); + int shared_mem_size = + is_scatter_like ? sizeof(int) * self_size : sizeof(int) * index_size; GatherScatterGPUKernel - <<>>(self_data, - dim, - index_data, - src_data, - inner_dim_size, - select_dim_size, - replaced_select_dim_size, - outer_dim_size, - index_size, - reduce_op); + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op); } }; // struct gpu_gather_scatter_functor @@ -211,22 +246,37 @@ template __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, int dim, const index_t* index_data, - int64_t inner_dim_size, int select_dim_size, int grad_select_dim_size, int64_t outer_dim_size, - int64_t numel) { + int64_t outer_dim_size_data, + int64_t numel, + int64_t numel_data) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); int64_t i, j, k; i = tid / (select_dim_size * outer_dim_size); int64_t remind = tid % (select_dim_size * outer_dim_size); j = remind / outer_dim_size; k = remind % outer_dim_size; index_t index = index_data[tid]; - int64_t replace_index = - k + index * outer_dim_size + i * outer_dim_size * grad_select_dim_size; - grad_data[replace_index] = 0; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; + + atomicMax(thread_ids + replace_index, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index]) { + grad_data[replace_index] = 0; + } } template void gpu_scatter_input_grad_kernel(phi::DenseTensor self, @@ -240,9 +290,11 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, auto index_dims = index.dims(); auto grad_dims = grad.dims(); int64_t index_size = index.numel(); + int64_t grad_size = grad.numel(); int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; int select_dim_size = index_dims[dim]; int grad_select_dim_size = grad_dims[dim]; for (int64_t i = 0; i < dim; ++i) { @@ -251,28 +303,125 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; + outer_dim_size_data *= grad_dims[i]; } int block = 512; int64_t n = inner_dim_size * select_dim_size * outer_dim_size; int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); - + int shared_mem_size = sizeof(int) * grad_size; ScatterInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - inner_dim_size, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - index_size); + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_data, + index_size, + grad_size); +} + +template +__global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, + int dim, + const tensor_t* self_data, + const index_t* index_data, + int select_dim_size, + int self_select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_grad, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index_self]) { + int64_t replace_index_grad = k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; + grad_data[replace_index_grad] = self_data[replace_index_self]; + } +} +template +void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* grad_data = grad.data(); + + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t self_size = self.numel(); + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_grad = 1; + int select_dim_size = index_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_grad *= grad_dims[i]; + } + + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * self_size; + ScatterValueGradGPUKernel + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size); } Instantiate_Template_Function(gpu_gather_kernel) Instantiate_Template_Function(gpu_scatter_assign_kernel) Instantiate_Template_Function(gpu_scatter_add_kernel) Instantiate_Template_Function(gpu_scatter_mul_kernel) Instantiate_Template_Function(gpu_scatter_input_grad_kernel) + Instantiate_Template_Function(gpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.h b/paddle/phi/kernels/funcs/gather_scatter_functor.h index 56068f9459ebd5..65930a1a11e1ed 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.h +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.h @@ -78,6 +78,13 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self, phi::DenseTensor result, const phi::DeviceContext& ctx); +template +void cpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor output, + const phi::DeviceContext& ctx); + template void gpu_gather_kernel(phi::DenseTensor self, int dim, @@ -113,5 +120,12 @@ void gpu_scatter_input_grad_kernel(phi::DenseTensor self, phi::DenseTensor result, const phi::DeviceContext& ctx); +template +void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx); + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index 8321bcd1aa7acf..ab380351610adf 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -52,14 +52,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, value_grad->Resize(index.dims()); dev_ctx.template Alloc(value_grad); if (index_type == DataType::INT32) { - phi::funcs::gpu_gather_kernel( - out_grad, - axis, - index, - *value_grad, - dev_ctx); // the gradient of scatter is gather + phi::funcs::gpu_scatter_value_grad_kernel( + out_grad, axis, index, *value_grad, dev_ctx); } else if (index_type == DataType::INT64) { - phi::funcs::gpu_gather_kernel( + phi::funcs::gpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } } diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 685b10276c476f..ddd906e661b0ab 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5150,24 +5150,30 @@ def take_along_axis(arr, indices, axis): >>> axis = 0 >>> result = paddle.take_along_axis(x, index, axis) >>> print(result) - Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[1, 2, 3]]) + Tensor(shape=[1, 1], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[1]]) """ if len(arr.shape) != len(indices.shape): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" ) axis = non_negative_axis(arr, axis) - broadcast_shape = infer_broadcast_shape(arr, indices, axis) - if not broadcast_shape: - # if indices matrix have larger size than arr, arr should broadcast into indices shape. - broadcast_shape = indices.shape + for i in range(len(arr.shape)): + if i != axis and arr.shape[i] < indices.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {}".format( + i, indices.shape, arr.shape, axis + ) + ) + + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size + ) + ) if in_dynamic_or_pir_mode(): - indices = paddle.broadcast_to(indices, broadcast_shape) - broadcast_shape_list = list(broadcast_shape) - broadcast_shape_list[axis] = list(arr.shape)[axis] - broadcast_shape = tuple(broadcast_shape_list) - arr = paddle.broadcast_to(arr, broadcast_shape) return _C_ops.take_along_axis(arr, indices, axis) else: check_variable_and_dtype( @@ -5187,11 +5193,6 @@ def take_along_axis(arr, indices, axis): check_variable_and_dtype( indices, 'index', ['int32', 'int64'], 'take_along_axis' ) - indices = paddle.broadcast_to(indices, broadcast_shape) - broadcast_shape_list = list(broadcast_shape) - broadcast_shape_list[axis] = list(arr.shape)[axis] - broadcast_shape = tuple(broadcast_shape_list) - arr = paddle.broadcast_to(arr, broadcast_shape) helper = LayerHelper('take_along_axis', **locals()) dtype = helper.input_dtype() result = helper.create_variable_for_type_inference(dtype) @@ -5229,26 +5230,45 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): >>> axis = 0 >>> result = paddle.put_along_axis(x, index, value, axis) >>> print(result) - Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, - [[99, 99, 99], - [60, 40, 50]]) + Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, + [[99, 30, 20], + [60, 40, 50]]) """ if len(arr.shape) != len(indices.shape): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" ) + if isinstance(values, (paddle.Tensor, paddle.pir.OpResult)): + if len(indices.shape) != len(values.shape): + raise ValueError( + "`indices` and `values` must have the same number of dimensions!" + ) + for i in range(len(arr.shape)): + if (i != axis and arr.shape[i] < indices.shape[i]) or indices.shape[ + i + ] > values.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {} and to be smaller size than values {}".format( + i, indices.shape, arr.shape, axis, values.shape + ) + ) + else: + values = paddle.to_tensor(values).astype(arr.dtype) + elements = 1 + for num in values.shape: + elements *= num + if elements == 1: # paddle.pir.OpResult has no attribute 'size' + values = paddle.broadcast_to(values, indices.shape) axis = non_negative_axis(arr, axis) - broadcast_shape = infer_broadcast_shape(arr, indices, axis) - if in_dynamic_or_pir_mode(): - values = ( - paddle.to_tensor(values) - if not isinstance(values, (paddle.Tensor, paddle.pir.OpResult)) - else values + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size + ) ) - if broadcast_shape: - indices = paddle.broadcast_to(indices, broadcast_shape) - values = paddle.broadcast_to(values, indices.shape) + if in_dynamic_or_pir_mode(): return _C_ops.put_along_axis(arr, indices, values, axis, reduce) else: check_variable_and_dtype( @@ -5268,9 +5288,6 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): check_variable_and_dtype( indices, 'index', ['int32', 'int64'], 'put_along_axis' ) - if broadcast_shape: - indices = paddle.broadcast_to(indices, broadcast_shape) - values = paddle.broadcast_to(values, indices.shape) helper = LayerHelper('put_along_axis', **locals()) dtype = helper.input_dtype() result = helper.create_variable_for_type_inference(dtype) diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index 83194145bb18e7..7b5edd2e5c41e0 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -157,8 +157,7 @@ def run(place): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data('X', self.shape) index = paddle.static.data('Index', self.index_shape, "int64") - value = paddle.static.data('Value', self.value_shape) - out = paddle.put_along_axis(x, index, value, self.axis) + out = paddle.put_along_axis(x, index, self.value_np, self.axis) exe = paddle.static.Executor(self.place[0]) res = exe.run( feed={ @@ -168,12 +167,10 @@ def run(place): }, fetch_list=[out], ) - - np.put_along_axis( - self.x_np, self.index_np, self.value_np, self.axis - ) - # numpy put_along_axis is an inplace opearion. - out_ref = self.x_np + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[self.index_np[i, j], j] = self.value_np for out in res: np.testing.assert_allclose(out, out_ref, rtol=0.001) @@ -186,24 +183,21 @@ def run(place): paddle.disable_static(place) x_tensor = paddle.to_tensor(self.x_np) index_tensor = paddle.to_tensor(self.index_np) - value_tensor = paddle.to_tensor(self.value_np) out = paddle.put_along_axis( - x_tensor, index_tensor, value_tensor, self.axis - ) - np.array( - np.put_along_axis( - self.x_np, self.index_np, self.value_np, self.axis - ) + x_tensor, index_tensor, self.value_np, self.axis ) - out_ref = self.x_np + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[self.index_np[i, j], j] = self.value_np np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor, value_tensor, self.axis, 'mul' + x_tensor, index_tensor, self.value_np, self.axis, 'mul' ) paddle.put_along_axis( - x_tensor, index_tensor, value_tensor, self.axis, 'add' + x_tensor, index_tensor, self.value_np, self.axis, 'add' ) paddle.enable_static() @@ -271,6 +265,122 @@ def test_inplace_dygraph(self): pass +class TestPutAlongAxisAPICase4(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [3, 5] + self.index1_shape = [1, 4] + self.index_np1 = np.array([[0, 1, 2, 0]]).astype('int64') + self.index2_shape = [2, 3] + self.index_np2 = np.array([[0, 1, 2], [0, 1, 4]]).astype('int64') + self.x_np = np.zeros((3, 5)).astype(np.float32) + self.value_shape = [2, 5] + self.value = ( + np.arange(1, 11).reshape(self.value_shape).astype(np.float32) + ) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_api_dygraph(self): + def run(place): + paddle.disable_static(place) + x_tensor = paddle.to_tensor(self.x_np) + index_tensor1 = paddle.to_tensor(self.index_np1) + value_tensor = paddle.to_tensor(self.value) + out = paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0 + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index1_shape[0]): + for j in range(self.index1_shape[1]): + out_ref[self.index_np1[i, j], j] = self.value[i, j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + # for ci coverage, numpy put_along_axis did not support argument of 'reduce' + paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0, 'mul' + ) + paddle.put_along_axis( + x_tensor, index_tensor1, value_tensor, 0, 'add' + ) + + index_tensor2 = paddle.to_tensor(self.index_np2) + out = paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1 + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index2_shape[0]): + for j in range(self.index2_shape[1]): + out_ref[i, self.index_np2[i, j]] = self.value[i, j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + # for ci coverage, numpy put_along_axis did not support argument of 'reduce' + paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1, 'mul' + ) + paddle.put_along_axis( + x_tensor, index_tensor2, value_tensor, 1, 'add' + ) + + paddle.enable_static() + + for place in self.place: + run(place) + + @test_with_pir_api + def test_api_static(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x1 = paddle.static.data('X', self.shape) + index1 = paddle.static.data('Index', self.index1_shape, "int64") + value_tensor = paddle.to_tensor(self.value) + out1 = paddle.put_along_axis(x1, index1, value_tensor, 0) + exe = paddle.static.Executor(place) + res = exe.run( + feed={ + 'X': self.x_np, + 'Value': self.value, + 'Index': self.index_np1, + }, + fetch_list=[out1], + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index1_shape[0]): + for j in range(self.index1_shape[1]): + out_ref[self.index_np1[i, j], j] = self.value[i, j] + + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + with paddle.static.program_guard(paddle.static.Program()): + x2 = paddle.static.data('X', self.shape) + index2 = paddle.static.data('Index', self.index2_shape, "int64") + value_tensor = paddle.to_tensor(self.value) + out2 = paddle.put_along_axis(x2, index2, value_tensor, 1) + exe = paddle.static.Executor(place) + res = exe.run( + feed={ + 'X': self.x_np, + 'Value': self.value, + 'Index': self.index_np2, + }, + fetch_list=[out2], + ) + out_ref = copy.deepcopy(self.x_np) + for i in range(self.index2_shape[0]): + for j in range(self.index2_shape[1]): + out_ref[i, self.index_np2[i, j]] = self.value[i, j] + + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + for place in self.place: + run(place) + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 3e687fcdf2a24e..6d6ed74fe4f224 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -163,9 +163,17 @@ def test_api_static(self): res = exe.run( feed={'X': self.x_np, 'Index': self.index_np}, fetch_list=[out] ) - out_ref = np.array( - np.take_along_axis(self.x_np, self.index_np, self.axis) - ) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + if self.axis == 0: + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + elif self.axis == 1: + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[i, self.index_np[i, j]] + else: + return for out in res: np.testing.assert_allclose(out, out_ref, rtol=0.001) @@ -174,9 +182,17 @@ def test_api_dygraph(self): x_tensor = paddle.to_tensor(self.x_np) self.index = paddle.to_tensor(self.index_np) out = paddle.take_along_axis(x_tensor, self.index, self.axis) - out_ref = np.array( - np.take_along_axis(self.x_np, self.index_np, self.axis) - ) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + if self.axis == 0: + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + elif self.axis == 1: + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[i, self.index_np[i, j]] + else: + return np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) paddle.enable_static() @@ -210,6 +226,19 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) +class TestTakeAlongAxisAPICase2(TestTakeAlongAxisAPI): + def setUp(self): + np.random.seed(0) + self.shape = [2, 2] + self.index_shape = [1, 1] + self.index_np = np.array([[1]]).astype('int64') + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = [paddle.CPUPlace()] + self.axis = 1 + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + if __name__ == "__main__": paddle.enable_static() unittest.main() From 36a2405cb0cfe2f3a5d83a194772c2e79e46a989 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Mon, 20 Nov 2023 13:40:08 +0000 Subject: [PATCH 02/14] fix error --- .../kernels/funcs/gather_scatter_functor.cu | 552 ++++++++++-------- python/paddle/distribution/categorical.py | 25 +- 2 files changed, 325 insertions(+), 252 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index b75a1888d42cec..a5f6f81bbcb5ec 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -58,19 +58,19 @@ template -__global__ void GatherScatterGPUKernel(tensor_t* self_data, - int dim, - const index_t* index_data, - tensor_t* src_data, - int select_dim_size, - int self_select_dim_size, - int src_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_self, - int64_t outer_dim_size_src, - int64_t numel, - int64_t numel_data, - const func_t& reduce_op) { +__global__ void GatherScatterAssignGPUKernel(tensor_t* self_data, + int dim, + const index_t* index_data, + tensor_t* src_data, + int select_dim_size, + int self_select_dim_size, + int src_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, + int64_t numel, + int64_t numel_data, + const func_t& reduce_op) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; extern __shared__ int thread_ids[]; @@ -127,6 +127,67 @@ __global__ void GatherScatterGPUKernel(tensor_t* self_data, } } +template +__global__ void GatherScatterAddOrMulGPUKernel(tensor_t* self_data, + int dim, + const index_t* index_data, + tensor_t* src_data, + int select_dim_size, + int self_select_dim_size, + int src_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, + int64_t numel, + int64_t numel_data, + const func_t& reduce_op) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop + // squeezed from the N layers loop. + /* tid = i * select_dim_size * outer_dim_size + j * outer_dim_size + k */ + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + /* + gather computation formula: + + self[i][j][k] = src[index[i][j][k]][j][k] # if dim == 0 + self[i][j][k] = src[i][index[i][j][k]][k] # if dim == 1 + self[i][j][k] = src[i][j][index[i][j][k]] # if dim == 2 + + scatter computation formula: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + + */ + // index matrix has different shape with self matrix or src matrix. + int64_t replace_index_self, replace_index_src; + if (is_scatter_like) { + replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + replace_index_src = k + j * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } else { + replace_index_self = tid; + + replace_index_src = k + index * outer_dim_size_src + + i * outer_dim_size_src * src_select_dim_size; + } + + reduce_op(static_cast(self_data + replace_index_self), + static_cast(src_data + replace_index_src)); +} +} // namespace funcs + template @@ -174,254 +235,259 @@ struct gpu_gather_scatter_functor { int64_t n = inner_dim_size * select_dim_size * outer_dim_size; int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = - is_scatter_like ? sizeof(int) * self_size : sizeof(int) * index_size; - GatherScatterGPUKernel - <<>>(self_data, - dim, - index_data, - src_data, - select_dim_size, - self_select_dim_size, - src_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_src, - index_size, - self_size, - reduce_op); + if (method_name == "gather_out_gpu" || + method_name == "scatter_assign_gpu") { + int shared_mem_size = + is_scatter_like ? sizeof(int) * self_size : sizeof(int) * index_size; + GatherScatterAssignGPUKernel + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op); + } else { + GatherScatterAddOrMulGPUKernel + <<>>(self_data, + dim, + index_data, + src_data, + select_dim_size, + self_select_dim_size, + src_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_src, + index_size, + self_size, + reduce_op); + } + } // struct gpu_gather_scatter_functor + + template + void gpu_gather_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor result, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + result, dim, index, self, "gather_out_gpu", tensor_assign, ctx); + return; } -}; // struct gpu_gather_scatter_functor - -template -void gpu_gather_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor result, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - result, dim, index, self, "gather_out_gpu", tensor_assign, ctx); - return; -} - -template -void gpu_scatter_assign_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx); -} - -template -void gpu_scatter_add_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_add_gpu", reduce_add, ctx); -} -template -void gpu_scatter_mul_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx); -} + template + void gpu_scatter_assign_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx); + } -template -__global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, - int dim, - const index_t* index_data, - int select_dim_size, - int grad_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_data, - int64_t numel, - int64_t numel_data) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= numel) return; - extern __shared__ int thread_ids[]; + template + void gpu_scatter_add_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_add_gpu", reduce_add, ctx); + } - if (tid == 0) { - for (int i = 0; i < numel_data; i++) { - thread_ids[i] = 0; - } + template + void gpu_scatter_mul_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx); } - __syncthreads(); - int64_t i, j, k; - i = tid / (select_dim_size * outer_dim_size); - int64_t remind = tid % (select_dim_size * outer_dim_size); - j = remind / outer_dim_size; - k = remind % outer_dim_size; - index_t index = index_data[tid]; - int64_t replace_index = k + index * outer_dim_size_data + - i * outer_dim_size_data * grad_select_dim_size; - atomicMax(thread_ids + replace_index, tid); - __syncthreads(); + template + __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, + int dim, + const index_t* index_data, + int select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_data, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; - if (tid == thread_ids[replace_index]) { grad_data[replace_index] = 0; } -} -template -void gpu_scatter_input_grad_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor grad, - const phi::DeviceContext& ctx) { - auto* index_data = index.data(); - auto* grad_data = grad.data(); - - auto index_dims = index.dims(); - auto grad_dims = grad.dims(); - int64_t index_size = index.numel(); - int64_t grad_size = grad.numel(); - - int64_t inner_dim_size = 1; - int64_t outer_dim_size = 1; - int64_t outer_dim_size_data = 1; - int select_dim_size = index_dims[dim]; - int grad_select_dim_size = grad_dims[dim]; - for (int64_t i = 0; i < dim; ++i) { - inner_dim_size *= index_dims[i]; - } - - for (int i = dim + 1; i < index_dims.size(); i++) { - outer_dim_size *= index_dims[i]; - outer_dim_size_data *= grad_dims[i]; - } - - int block = 512; - int64_t n = inner_dim_size * select_dim_size * outer_dim_size; - int64_t grid = (n + block - 1) / block; - auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = sizeof(int) * grad_size; - ScatterInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_data, - index_size, - grad_size); -} + template + void gpu_scatter_input_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* index_data = index.data(); + auto* grad_data = grad.data(); -template -__global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, - int dim, - const tensor_t* self_data, - const index_t* index_data, - int select_dim_size, - int self_select_dim_size, - int grad_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_self, - int64_t outer_dim_size_grad, - int64_t numel, - int64_t numel_data) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= numel) return; - extern __shared__ int thread_ids[]; + auto index_dims = index.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t grad_size = grad.numel(); - if (tid == 0) { - for (int i = 0; i < numel_data; i++) { - thread_ids[i] = 0; + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; + int select_dim_size = index_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; } - } - __syncthreads(); - int64_t i, j, k; - i = tid / (select_dim_size * outer_dim_size); - int64_t remind = tid % (select_dim_size * outer_dim_size); - j = remind / outer_dim_size; - k = remind % outer_dim_size; - index_t index = index_data[tid]; - int64_t replace_index_self = k + index * outer_dim_size_self + - i * outer_dim_size_self * self_select_dim_size; - atomicMax(thread_ids + replace_index_self, tid); - __syncthreads(); + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_data *= grad_dims[i]; + } - if (tid == thread_ids[replace_index_self]) { - int64_t replace_index_grad = k + j * outer_dim_size_grad + - i * outer_dim_size_grad * grad_select_dim_size; - grad_data[replace_index_grad] = self_data[replace_index_self]; - } -} -template -void gpu_scatter_value_grad_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor grad, - const phi::DeviceContext& ctx) { - auto* self_data = self.data(); - auto* index_data = index.data(); - auto* grad_data = grad.data(); - - auto index_dims = index.dims(); - auto self_dims = self.dims(); - auto grad_dims = grad.dims(); - int64_t index_size = index.numel(); - int64_t self_size = self.numel(); - - int64_t inner_dim_size = 1; - int64_t outer_dim_size = 1; - int64_t outer_dim_size_self = 1; - int64_t outer_dim_size_grad = 1; - int select_dim_size = index_dims[dim]; - int self_select_dim_size = self_dims[dim]; - int grad_select_dim_size = grad_dims[dim]; - for (int64_t i = 0; i < dim; ++i) { - inner_dim_size *= index_dims[i]; + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * grad_size; + ScatterInputGradGPUKernel + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_data, + index_size, + grad_size); } - for (int i = dim + 1; i < index_dims.size(); i++) { - outer_dim_size *= index_dims[i]; - outer_dim_size_self *= self_dims[i]; - outer_dim_size_grad *= grad_dims[i]; + template + __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, + int dim, + const tensor_t* self_data, + const index_t* index_data, + int select_dim_size, + int self_select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_grad, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + extern __shared__ int thread_ids[]; + + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; + } + } + __syncthreads(); + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; + + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); + + if (tid == thread_ids[replace_index_self]) { + int64_t replace_index_grad = + k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; + grad_data[replace_index_grad] = self_data[replace_index_self]; + } } + template + void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* grad_data = grad.data(); - int block = 512; - int64_t n = inner_dim_size * select_dim_size * outer_dim_size; - int64_t grid = (n + block - 1) / block; - auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = sizeof(int) * self_size; - ScatterValueGradGPUKernel - <<>>(grad_data, - dim, - self_data, - index_data, - select_dim_size, - self_select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_grad, - index_size, - self_size); -} -Instantiate_Template_Function(gpu_gather_kernel) - Instantiate_Template_Function(gpu_scatter_assign_kernel) - Instantiate_Template_Function(gpu_scatter_add_kernel) - Instantiate_Template_Function(gpu_scatter_mul_kernel) - Instantiate_Template_Function(gpu_scatter_input_grad_kernel) - Instantiate_Template_Function(gpu_scatter_value_grad_kernel) + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t self_size = self.numel(); + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_grad = 1; + int select_dim_size = index_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } + + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_grad *= grad_dims[i]; + } + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * self_size; + ScatterValueGradGPUKernel + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size); + } + Instantiate_Template_Function(gpu_gather_kernel) + Instantiate_Template_Function(gpu_scatter_assign_kernel) + Instantiate_Template_Function(gpu_scatter_add_kernel) + Instantiate_Template_Function(gpu_scatter_mul_kernel) + Instantiate_Template_Function(gpu_scatter_input_grad_kernel) + Instantiate_Template_Function( + gpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index 9d5664dc28f4d3..2252bf263dcb75 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -20,6 +20,7 @@ from paddle.distribution import distribution from paddle.framework import in_dynamic_mode from paddle.tensor import multinomial +from paddle.tensor.manipulation import infer_broadcast_shape class Categorical(distribution.Distribution): @@ -310,17 +311,23 @@ def probs(self, value): ).reshape(value.shape, name=name) else: if len(value.shape) == 1: - return paddle.take_along_axis( - self._prob, - paddle.reshape( - value, - (len(self._prob.shape) - 1) * [1] + [-1], - name=name, - ), - axis=-1, + indices = paddle.reshape( + value, + (len(self._prob.shape) - 1) * [1] + [-1], + name=name, ) else: - return paddle.take_along_axis(self._prob, value, axis=-1) + indices = value + broadcast_shape = infer_broadcast_shape(self._prob, indices, -1) + if not broadcast_shape: + # if indices matrix have larger size than arr, arr should broadcast into indices shape. + broadcast_shape = indices.shape + indices = paddle.broadcast_to(indices, broadcast_shape) + broadcast_shape_list = list(broadcast_shape) + broadcast_shape_list[-1] = list(self._prob.shape)[-1] + broadcast_shape = tuple(broadcast_shape_list) + arr = paddle.broadcast_to(self._prob, broadcast_shape) + return paddle.take_along_axis(arr, indices, axis=-1) def log_prob(self, value): """Log probabilities of the given category. Refer to ``probs`` method. From 38182988e216a5e0b4280ddd7c15cc29c2dfb572 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Mon, 20 Nov 2023 13:55:31 +0000 Subject: [PATCH 03/14] fix take_along_axis used in stat --- python/paddle/tensor/stat.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index d7bcc48c8fa451..2103265f1b2e2e 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -21,6 +21,7 @@ from ..base.data_feeder import check_type, check_variable_and_dtype from ..common_ops_import import Variable from ..framework import LayerHelper, core +from .manipulation import infer_broadcast_shape from .math import _get_reduce_axis_with_tensor from .search import where @@ -568,13 +569,32 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): # TODO(chenjianye): replace the for-loop to directly take elements. for index in indices: + + def broadcast_shape(arr, indices, axis): + broadcast_shape = infer_broadcast_shape(arr, indices, axis) + if not broadcast_shape: + # if indices matrix have larger size than arr, arr should broadcast into indices shape. + broadcast_shape = indices.shape + indices = paddle.broadcast_to(indices, broadcast_shape) + broadcast_shape_list = list(broadcast_shape) + broadcast_shape_list[axis] = list(arr.shape)[axis] + broadcast_shape = tuple(broadcast_shape_list) + arr = paddle.broadcast_to(arr, broadcast_shape) + return arr, indices + indices_below = paddle.floor(index).astype(paddle.int32) indices_upper = paddle.ceil(index).astype(paddle.int32) + sorted_tensor_below, indices_below = broadcast_shape( + sorted_tensor, indices_below, axis + ) + sorted_tensor_upper, indices_upper = broadcast_shape( + sorted_tensor, indices_upper, axis + ) tensor_upper = paddle.take_along_axis( - sorted_tensor, indices_upper, axis=axis + sorted_tensor_upper, indices_upper, axis=axis ) tensor_below = paddle.take_along_axis( - sorted_tensor, indices_below, axis=axis + sorted_tensor_below, indices_below, axis=axis ) weights = index - indices_below.astype('float64') out = paddle.lerp( From 18dc8c3e3fd655bef10e8b0b90d667cf25efc017 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Mon, 20 Nov 2023 14:59:45 +0000 Subject: [PATCH 04/14] update --- .../kernels/funcs/gather_scatter_functor.cu | 408 +++++++++--------- 1 file changed, 203 insertions(+), 205 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index a5f6f81bbcb5ec..050126e4b9a5c6 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -186,7 +186,6 @@ __global__ void GatherScatterAddOrMulGPUKernel(tensor_t* self_data, reduce_op(static_cast(self_data + replace_index_self), static_cast(src_data + replace_index_src)); } -} // namespace funcs template - void gpu_gather_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor result, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - result, dim, index, self, "gather_out_gpu", tensor_assign, ctx); - return; } +} // struct gpu_gather_scatter_functor + +template +void gpu_gather_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor result, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + result, dim, index, self, "gather_out_gpu", tensor_assign, ctx); + return; +} - template - void gpu_scatter_assign_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx); - } +template +void gpu_scatter_assign_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_assign_gpu", tensor_assign, ctx); +} - template - void gpu_scatter_add_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_add_gpu", reduce_add, ctx); - } +template +void gpu_scatter_add_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_add_gpu", reduce_add, ctx); +} - template - void gpu_scatter_mul_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor src, - const phi::DeviceContext& ctx) { - gpu_gather_scatter_functor()( - self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx); - } +template +void gpu_scatter_mul_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor src, + const phi::DeviceContext& ctx) { + gpu_gather_scatter_functor()( + self, dim, index, src, "scatter_mul_gpu", reduce_mul, ctx); +} - template - __global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, - int dim, - const index_t* index_data, - int select_dim_size, - int grad_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_data, - int64_t numel, - int64_t numel_data) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= numel) return; - int64_t i, j, k; - i = tid / (select_dim_size * outer_dim_size); - int64_t remind = tid % (select_dim_size * outer_dim_size); - j = remind / outer_dim_size; - k = remind % outer_dim_size; - index_t index = index_data[tid]; - int64_t replace_index = k + index * outer_dim_size_data + - i * outer_dim_size_data * grad_select_dim_size; - - grad_data[replace_index] = 0; - } - template - void gpu_scatter_input_grad_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor grad, - const phi::DeviceContext& ctx) { - auto* index_data = index.data(); - auto* grad_data = grad.data(); +template +__global__ void ScatterInputGradGPUKernel(tensor_t* grad_data, + int dim, + const index_t* index_data, + int select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_data, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; - auto index_dims = index.dims(); - auto grad_dims = grad.dims(); - int64_t index_size = index.numel(); - int64_t grad_size = grad.numel(); + grad_data[replace_index] = 0; +} +template +void gpu_scatter_input_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* index_data = index.data(); + auto* grad_data = grad.data(); + + auto index_dims = index.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t grad_size = grad.numel(); + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_data = 1; + int select_dim_size = index_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } - int64_t inner_dim_size = 1; - int64_t outer_dim_size = 1; - int64_t outer_dim_size_data = 1; - int select_dim_size = index_dims[dim]; - int grad_select_dim_size = grad_dims[dim]; - for (int64_t i = 0; i < dim; ++i) { - inner_dim_size *= index_dims[i]; - } + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_data *= grad_dims[i]; + } - for (int i = dim + 1; i < index_dims.size(); i++) { - outer_dim_size *= index_dims[i]; - outer_dim_size_data *= grad_dims[i]; - } + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * grad_size; + ScatterInputGradGPUKernel + <<>>(grad_data, + dim, + index_data, + select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_data, + index_size, + grad_size); +} - int block = 512; - int64_t n = inner_dim_size * select_dim_size * outer_dim_size; - int64_t grid = (n + block - 1) / block; - auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = sizeof(int) * grad_size; - ScatterInputGradGPUKernel - <<>>(grad_data, - dim, - index_data, - select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_data, - index_size, - grad_size); - } +template +__global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, + int dim, + const tensor_t* self_data, + const index_t* index_data, + int select_dim_size, + int self_select_dim_size, + int grad_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_grad, + int64_t numel, + int64_t numel_data) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + extern __shared__ int thread_ids[]; - template - __global__ void ScatterValueGradGPUKernel(tensor_t* grad_data, - int dim, - const tensor_t* self_data, - const index_t* index_data, - int select_dim_size, - int self_select_dim_size, - int grad_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_self, - int64_t outer_dim_size_grad, - int64_t numel, - int64_t numel_data) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= numel) return; - extern __shared__ int thread_ids[]; - - if (tid == 0) { - for (int i = 0; i < numel_data; i++) { - thread_ids[i] = 0; - } - } - __syncthreads(); - int64_t i, j, k; - i = tid / (select_dim_size * outer_dim_size); - int64_t remind = tid % (select_dim_size * outer_dim_size); - j = remind / outer_dim_size; - k = remind % outer_dim_size; - index_t index = index_data[tid]; - int64_t replace_index_self = k + index * outer_dim_size_self + - i * outer_dim_size_self * self_select_dim_size; - - atomicMax(thread_ids + replace_index_self, tid); - __syncthreads(); - - if (tid == thread_ids[replace_index_self]) { - int64_t replace_index_grad = - k + j * outer_dim_size_grad + - i * outer_dim_size_grad * grad_select_dim_size; - grad_data[replace_index_grad] = self_data[replace_index_self]; + if (tid == 0) { + for (int i = 0; i < numel_data; i++) { + thread_ids[i] = 0; } } - template - void gpu_scatter_value_grad_kernel(phi::DenseTensor self, - int dim, - const phi::DenseTensor& index, - phi::DenseTensor grad, - const phi::DeviceContext& ctx) { - auto* self_data = self.data(); - auto* index_data = index.data(); - auto* grad_data = grad.data(); - - auto index_dims = index.dims(); - auto self_dims = self.dims(); - auto grad_dims = grad.dims(); - int64_t index_size = index.numel(); - int64_t self_size = self.numel(); + __syncthreads(); + int64_t i, j, k; + i = tid / (select_dim_size * outer_dim_size); + int64_t remind = tid % (select_dim_size * outer_dim_size); + j = remind / outer_dim_size; + k = remind % outer_dim_size; + index_t index = index_data[tid]; + int64_t replace_index_self = k + index * outer_dim_size_self + + i * outer_dim_size_self * self_select_dim_size; - int64_t inner_dim_size = 1; - int64_t outer_dim_size = 1; - int64_t outer_dim_size_self = 1; - int64_t outer_dim_size_grad = 1; - int select_dim_size = index_dims[dim]; - int self_select_dim_size = self_dims[dim]; - int grad_select_dim_size = grad_dims[dim]; - for (int64_t i = 0; i < dim; ++i) { - inner_dim_size *= index_dims[i]; - } + atomicMax(thread_ids + replace_index_self, tid); + __syncthreads(); - for (int i = dim + 1; i < index_dims.size(); i++) { - outer_dim_size *= index_dims[i]; - outer_dim_size_self *= self_dims[i]; - outer_dim_size_grad *= grad_dims[i]; - } + if (tid == thread_ids[replace_index_self]) { + int64_t replace_index_grad = k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; + grad_data[replace_index_grad] = self_data[replace_index_self]; + } +} +template +void gpu_scatter_value_grad_kernel(phi::DenseTensor self, + int dim, + const phi::DenseTensor& index, + phi::DenseTensor grad, + const phi::DeviceContext& ctx) { + auto* self_data = self.data(); + auto* index_data = index.data(); + auto* grad_data = grad.data(); + + auto index_dims = index.dims(); + auto self_dims = self.dims(); + auto grad_dims = grad.dims(); + int64_t index_size = index.numel(); + int64_t self_size = self.numel(); + + int64_t inner_dim_size = 1; + int64_t outer_dim_size = 1; + int64_t outer_dim_size_self = 1; + int64_t outer_dim_size_grad = 1; + int select_dim_size = index_dims[dim]; + int self_select_dim_size = self_dims[dim]; + int grad_select_dim_size = grad_dims[dim]; + for (int64_t i = 0; i < dim; ++i) { + inner_dim_size *= index_dims[i]; + } - int block = 512; - int64_t n = inner_dim_size * select_dim_size * outer_dim_size; - int64_t grid = (n + block - 1) / block; - auto stream = reinterpret_cast(ctx).stream(); - int shared_mem_size = sizeof(int) * self_size; - ScatterValueGradGPUKernel - <<>>(grad_data, - dim, - self_data, - index_data, - select_dim_size, - self_select_dim_size, - grad_select_dim_size, - outer_dim_size, - outer_dim_size_self, - outer_dim_size_grad, - index_size, - self_size); + for (int i = dim + 1; i < index_dims.size(); i++) { + outer_dim_size *= index_dims[i]; + outer_dim_size_self *= self_dims[i]; + outer_dim_size_grad *= grad_dims[i]; } - Instantiate_Template_Function(gpu_gather_kernel) - Instantiate_Template_Function(gpu_scatter_assign_kernel) - Instantiate_Template_Function(gpu_scatter_add_kernel) - Instantiate_Template_Function(gpu_scatter_mul_kernel) - Instantiate_Template_Function(gpu_scatter_input_grad_kernel) - Instantiate_Template_Function( - gpu_scatter_value_grad_kernel) + + int block = 512; + int64_t n = inner_dim_size * select_dim_size * outer_dim_size; + int64_t grid = (n + block - 1) / block; + auto stream = reinterpret_cast(ctx).stream(); + int shared_mem_size = sizeof(int) * self_size; + ScatterValueGradGPUKernel + <<>>(grad_data, + dim, + self_data, + index_data, + select_dim_size, + self_select_dim_size, + grad_select_dim_size, + outer_dim_size, + outer_dim_size_self, + outer_dim_size_grad, + index_size, + self_size); +} +Instantiate_Template_Function(gpu_gather_kernel) + Instantiate_Template_Function(gpu_scatter_assign_kernel) + Instantiate_Template_Function(gpu_scatter_add_kernel) + Instantiate_Template_Function(gpu_scatter_mul_kernel) + Instantiate_Template_Function(gpu_scatter_input_grad_kernel) + Instantiate_Template_Function(gpu_scatter_value_grad_kernel) } // namespace funcs } // namespace phi From 61461f4d0ee62a4cb7e45090cfb2247758abaf36 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Mon, 20 Nov 2023 15:10:33 +0000 Subject: [PATCH 05/14] fix build error --- paddle/phi/kernels/funcs/gather_scatter_functor.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index 050126e4b9a5c6..e6be85d9125d91 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -269,7 +269,7 @@ struct gpu_gather_scatter_functor { reduce_op); } } -} // struct gpu_gather_scatter_functor +}; // struct gpu_gather_scatter_functor template void gpu_gather_kernel(phi::DenseTensor self, From 013dfb417d0788ef6a41919103ba88a66a797dbd Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Tue, 21 Nov 2023 02:54:49 +0000 Subject: [PATCH 06/14] add test for error --- .../kernels/funcs/gather_scatter_functor.cu | 59 +++++++++---------- test/legacy_test/test_put_along_axis_op.py | 32 +++++++++- 2 files changed, 60 insertions(+), 31 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cu b/paddle/phi/kernels/funcs/gather_scatter_functor.cu index e6be85d9125d91..cbe866d4924d54 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cu +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cu @@ -58,19 +58,19 @@ template -__global__ void GatherScatterAssignGPUKernel(tensor_t* self_data, - int dim, - const index_t* index_data, - tensor_t* src_data, - int select_dim_size, - int self_select_dim_size, - int src_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_self, - int64_t outer_dim_size_src, - int64_t numel, - int64_t numel_data, - const func_t& reduce_op) { +__global__ void ScatterAssignGPUKernel(tensor_t* self_data, + int dim, + const index_t* index_data, + tensor_t* src_data, + int select_dim_size, + int self_select_dim_size, + int src_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, + int64_t numel, + int64_t numel_data, + const func_t& reduce_op) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; extern __shared__ int thread_ids[]; @@ -131,19 +131,19 @@ template -__global__ void GatherScatterAddOrMulGPUKernel(tensor_t* self_data, - int dim, - const index_t* index_data, - tensor_t* src_data, - int select_dim_size, - int self_select_dim_size, - int src_select_dim_size, - int64_t outer_dim_size, - int64_t outer_dim_size_self, - int64_t outer_dim_size_src, - int64_t numel, - int64_t numel_data, - const func_t& reduce_op) { +__global__ void GatherScatterGPUKernel(tensor_t* self_data, + int dim, + const index_t* index_data, + tensor_t* src_data, + int select_dim_size, + int self_select_dim_size, + int src_select_dim_size, + int64_t outer_dim_size, + int64_t outer_dim_size_self, + int64_t outer_dim_size_src, + int64_t numel, + int64_t numel_data, + const func_t& reduce_op) { int tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= numel) return; int64_t i, j, k; // The i, j, k here is the index of the 3 layers loop @@ -234,11 +234,10 @@ struct gpu_gather_scatter_functor { int64_t n = inner_dim_size * select_dim_size * outer_dim_size; int64_t grid = (n + block - 1) / block; auto stream = reinterpret_cast(ctx).stream(); - if (method_name == "gather_out_gpu" || - method_name == "scatter_assign_gpu") { + if (method_name == "scatter_assign_gpu") { int shared_mem_size = is_scatter_like ? sizeof(int) * self_size : sizeof(int) * index_size; - GatherScatterAssignGPUKernel + ScatterAssignGPUKernel <<>>(self_data, dim, index_data, @@ -253,7 +252,7 @@ struct gpu_gather_scatter_functor { self_size, reduce_op); } else { - GatherScatterAddOrMulGPUKernel + GatherScatterGPUKernel <<>>(self_data, dim, index_data, diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index 7b5edd2e5c41e0..dba5d65f991085 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -61,7 +61,7 @@ def init_data(self): self.x_shape = (10, 10, 10) self.value_type = "float64" self.value = np.array([99]).astype(self.value_type) - self.index_type = "int32" + self.index_type = "int64" self.index = np.array([[[0]]]).astype(self.index_type) self.axis = 1 self.axis_type = "int64" @@ -380,6 +380,36 @@ def run(place): for place in self.place: run(place) + def test_error(self): + tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32") + indices = paddle.to_tensor([1]).astype("int32") + values = paddle.to_tensor([2]) + # len(arr.shape) != len(indices.shape) + try: + res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + except Exception as error: + self.assertIsInstance(error, ValueError) + indices = paddle.to_tensor([[1]]).astype("int32") + # len(values.shape) != len(indices.shape) + try: + res = paddle.put_along_axis(tensorx, indices, values, 0) + except Exception as error: + self.assertIsInstance(error, ValueError) + indices = paddle.to_tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] + ).astype("int32") + # indices too large + try: + res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + except Exception as error: + self.assertIsInstance(error, RuntimeError) + indices = paddle.to_tensor([[10]]).astype("int32") + # the element of indices out of range + try: + res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + except Exception as error: + self.assertIsInstance(error, RuntimeError) + if __name__ == "__main__": paddle.enable_static() From 1155d4c949314bcef3f07be357a19cd1e0adce7c Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Wed, 22 Nov 2023 14:07:51 +0000 Subject: [PATCH 07/14] add param broadcast --- python/paddle/distribution/categorical.py | 25 ++--- python/paddle/tensor/manipulation.py | 109 ++++++++++++-------- python/paddle/tensor/stat.py | 24 +---- test/legacy_test/test_put_along_axis_op.py | 68 +++++++----- test/legacy_test/test_take_along_axis_op.py | 68 +++++++----- 5 files changed, 162 insertions(+), 132 deletions(-) diff --git a/python/paddle/distribution/categorical.py b/python/paddle/distribution/categorical.py index 2252bf263dcb75..9d5664dc28f4d3 100644 --- a/python/paddle/distribution/categorical.py +++ b/python/paddle/distribution/categorical.py @@ -20,7 +20,6 @@ from paddle.distribution import distribution from paddle.framework import in_dynamic_mode from paddle.tensor import multinomial -from paddle.tensor.manipulation import infer_broadcast_shape class Categorical(distribution.Distribution): @@ -311,23 +310,17 @@ def probs(self, value): ).reshape(value.shape, name=name) else: if len(value.shape) == 1: - indices = paddle.reshape( - value, - (len(self._prob.shape) - 1) * [1] + [-1], - name=name, + return paddle.take_along_axis( + self._prob, + paddle.reshape( + value, + (len(self._prob.shape) - 1) * [1] + [-1], + name=name, + ), + axis=-1, ) else: - indices = value - broadcast_shape = infer_broadcast_shape(self._prob, indices, -1) - if not broadcast_shape: - # if indices matrix have larger size than arr, arr should broadcast into indices shape. - broadcast_shape = indices.shape - indices = paddle.broadcast_to(indices, broadcast_shape) - broadcast_shape_list = list(broadcast_shape) - broadcast_shape_list[-1] = list(self._prob.shape)[-1] - broadcast_shape = tuple(broadcast_shape_list) - arr = paddle.broadcast_to(self._prob, broadcast_shape) - return paddle.take_along_axis(arr, indices, axis=-1) + return paddle.take_along_axis(self._prob, value, axis=-1) def log_prob(self, value): """Log probabilities of the given category. Refer to ``probs`` method. diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ddd906e661b0ab..ddd5ed7745a42f 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5127,7 +5127,7 @@ def infer_broadcast_shape(arr, indices, axis): return broadcast_shape -def take_along_axis(arr, indices, axis): +def take_along_axis(arr, indices, axis, broadcast=True): """ Take values from the input array by given indices matrix along the designated axis. @@ -5136,6 +5136,7 @@ def take_along_axis(arr, indices, axis): indices (Tensor) : Indices to take along each 1d slice of arr. This must match the dimension of arr, and need to broadcast against arr. Supported data type are int and int64. axis (int) : The axis to take 1d slices along. + broadcast (bool, optional): whether the indices broadcast. Returns: Tensor, The indexed element, same dtype with arr @@ -5150,29 +5151,40 @@ def take_along_axis(arr, indices, axis): >>> axis = 0 >>> result = paddle.take_along_axis(x, index, axis) >>> print(result) - Tensor(shape=[1, 1], dtype=int64, place=Place(gpu:0), stop_gradient=True, - [[1]]) + Tensor(shape=[1, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[1, 2, 3]]) """ if len(arr.shape) != len(indices.shape): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" ) axis = non_negative_axis(arr, axis) - for i in range(len(arr.shape)): - if i != axis and arr.shape[i] < indices.shape[i]: - raise RuntimeError( - "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {}".format( - i, indices.shape, arr.shape, axis + if broadcast: + broadcast_shape = infer_broadcast_shape(arr, indices, axis) + if not broadcast_shape: + # if indices matrix have larger size than arr, arr should broadcast into indices shape. + broadcast_shape = indices.shape + indices = paddle.broadcast_to(indices, broadcast_shape) + broadcast_shape_list = list(broadcast_shape) + broadcast_shape_list[axis] = list(arr.shape)[axis] + broadcast_shape = tuple(broadcast_shape_list) + arr = paddle.broadcast_to(arr, broadcast_shape) + else: + for i in range(len(arr.shape)): + if i != axis and arr.shape[i] < indices.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {}".format( + i, indices.shape, arr.shape, axis + ) ) - ) - axis_max_size = arr.shape[axis] - if not (indices < axis_max_size).all(): - raise RuntimeError( - "one of element of indices is out of bounds for dimension {} with size {}".format( - axis, axis_max_size + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size + ) ) - ) if in_dynamic_or_pir_mode(): return _C_ops.take_along_axis(arr, indices, axis) else: @@ -5205,7 +5217,7 @@ def take_along_axis(arr, indices, axis): return result -def put_along_axis(arr, indices, values, axis, reduce='assign'): +def put_along_axis(arr, indices, values, axis, reduce='assign', broadcast=True): """ Put values into the destination array by given indices matrix along the designated axis. @@ -5215,6 +5227,7 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): and need to broadcast against arr. Supported data type are int and int64. axis (int) : The axis to put 1d slices along. reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul' and 'multiply'. + broadcast (bool, optional): whether the indices broadcast. Returns: Tensor, The indexed element, same dtype with arr @@ -5231,7 +5244,7 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): >>> result = paddle.put_along_axis(x, index, value, axis) >>> print(result) Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - [[99, 30, 20], + [[99, 99, 99], [60, 40, 50]]) """ @@ -5239,35 +5252,47 @@ def put_along_axis(arr, indices, values, axis, reduce='assign'): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" ) - if isinstance(values, (paddle.Tensor, paddle.pir.OpResult)): - if len(indices.shape) != len(values.shape): - raise ValueError( - "`indices` and `values` must have the same number of dimensions!" + axis = non_negative_axis(arr, axis) + if broadcast: + broadcast_shape = infer_broadcast_shape(arr, indices, axis) + if in_dynamic_or_pir_mode(): + values = ( + paddle.to_tensor(values) + if not isinstance(values, (paddle.Tensor, paddle.pir.OpResult)) + else values ) - for i in range(len(arr.shape)): - if (i != axis and arr.shape[i] < indices.shape[i]) or indices.shape[ - i - ] > values.shape[i]: - raise RuntimeError( - "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {} and to be smaller size than values {}".format( - i, indices.shape, arr.shape, axis, values.shape + if broadcast_shape: + indices = paddle.broadcast_to(indices, broadcast_shape) + values = paddle.broadcast_to(values, indices.shape) + else: + if isinstance(values, (paddle.Tensor, paddle.pir.OpResult)): + if len(indices.shape) != len(values.shape): + raise ValueError( + "`indices` and `values` must have the same number of dimensions!" + ) + for i in range(len(arr.shape)): + if ( + i != axis and arr.shape[i] < indices.shape[i] + ) or indices.shape[i] > values.shape[i]: + raise RuntimeError( + "Size does not match at dimension {} expected index {} to be smaller than self {} apart from dimension {} and to be smaller size than values {}".format( + i, indices.shape, arr.shape, axis, values.shape + ) ) + else: + values = paddle.to_tensor(values).astype(arr.dtype) + elements = 1 + for num in values.shape: + elements *= num + if elements == 1: # paddle.pir.OpResult has no attribute 'size' + values = paddle.broadcast_to(values, indices.shape) + axis_max_size = arr.shape[axis] + if not (indices < axis_max_size).all(): + raise RuntimeError( + "one of element of indices is out of bounds for dimension {} with size {}".format( + axis, axis_max_size ) - else: - values = paddle.to_tensor(values).astype(arr.dtype) - elements = 1 - for num in values.shape: - elements *= num - if elements == 1: # paddle.pir.OpResult has no attribute 'size' - values = paddle.broadcast_to(values, indices.shape) - axis = non_negative_axis(arr, axis) - axis_max_size = arr.shape[axis] - if not (indices < axis_max_size).all(): - raise RuntimeError( - "one of element of indices is out of bounds for dimension {} with size {}".format( - axis, axis_max_size ) - ) if in_dynamic_or_pir_mode(): return _C_ops.put_along_axis(arr, indices, values, axis, reduce) else: diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 2103265f1b2e2e..d7bcc48c8fa451 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -21,7 +21,6 @@ from ..base.data_feeder import check_type, check_variable_and_dtype from ..common_ops_import import Variable from ..framework import LayerHelper, core -from .manipulation import infer_broadcast_shape from .math import _get_reduce_axis_with_tensor from .search import where @@ -569,32 +568,13 @@ def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False): # TODO(chenjianye): replace the for-loop to directly take elements. for index in indices: - - def broadcast_shape(arr, indices, axis): - broadcast_shape = infer_broadcast_shape(arr, indices, axis) - if not broadcast_shape: - # if indices matrix have larger size than arr, arr should broadcast into indices shape. - broadcast_shape = indices.shape - indices = paddle.broadcast_to(indices, broadcast_shape) - broadcast_shape_list = list(broadcast_shape) - broadcast_shape_list[axis] = list(arr.shape)[axis] - broadcast_shape = tuple(broadcast_shape_list) - arr = paddle.broadcast_to(arr, broadcast_shape) - return arr, indices - indices_below = paddle.floor(index).astype(paddle.int32) indices_upper = paddle.ceil(index).astype(paddle.int32) - sorted_tensor_below, indices_below = broadcast_shape( - sorted_tensor, indices_below, axis - ) - sorted_tensor_upper, indices_upper = broadcast_shape( - sorted_tensor, indices_upper, axis - ) tensor_upper = paddle.take_along_axis( - sorted_tensor_upper, indices_upper, axis=axis + sorted_tensor, indices_upper, axis=axis ) tensor_below = paddle.take_along_axis( - sorted_tensor_below, indices_below, axis=axis + sorted_tensor, indices_below, axis=axis ) weights = index - indices_below.astype('float64') out = paddle.lerp( diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index dba5d65f991085..eafaf9701e55cf 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -61,7 +61,7 @@ def init_data(self): self.x_shape = (10, 10, 10) self.value_type = "float64" self.value = np.array([99]).astype(self.value_type) - self.index_type = "int64" + self.index_type = "int32" self.index = np.array([[[0]]]).astype(self.index_type) self.axis = 1 self.axis_type = "int64" @@ -157,7 +157,8 @@ def run(place): with paddle.static.program_guard(paddle.static.Program()): x = paddle.static.data('X', self.shape) index = paddle.static.data('Index', self.index_shape, "int64") - out = paddle.put_along_axis(x, index, self.value_np, self.axis) + value = paddle.static.data('Value', self.value_shape) + out = paddle.put_along_axis(x, index, value, self.axis) exe = paddle.static.Executor(self.place[0]) res = exe.run( feed={ @@ -167,10 +168,12 @@ def run(place): }, fetch_list=[out], ) - out_ref = copy.deepcopy(self.x_np) - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[self.index_np[i, j], j] = self.value_np + + np.put_along_axis( + self.x_np, self.index_np, self.value_np, self.axis + ) + # numpy put_along_axis is an inplace opearion. + out_ref = self.x_np for out in res: np.testing.assert_allclose(out, out_ref, rtol=0.001) @@ -183,21 +186,24 @@ def run(place): paddle.disable_static(place) x_tensor = paddle.to_tensor(self.x_np) index_tensor = paddle.to_tensor(self.index_np) + value_tensor = paddle.to_tensor(self.value_np) out = paddle.put_along_axis( - x_tensor, index_tensor, self.value_np, self.axis + x_tensor, index_tensor, value_tensor, self.axis ) - out_ref = copy.deepcopy(self.x_np) - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[self.index_np[i, j], j] = self.value_np + np.array( + np.put_along_axis( + self.x_np, self.index_np, self.value_np, self.axis + ) + ) + out_ref = self.x_np np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor, self.value_np, self.axis, 'mul' + x_tensor, index_tensor, value_tensor, self.axis, 'mul' ) paddle.put_along_axis( - x_tensor, index_tensor, self.value_np, self.axis, 'add' + x_tensor, index_tensor, value_tensor, self.axis, 'add' ) paddle.enable_static() @@ -289,7 +295,7 @@ def run(place): index_tensor1 = paddle.to_tensor(self.index_np1) value_tensor = paddle.to_tensor(self.value) out = paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0 + x_tensor, index_tensor1, value_tensor, 0, 'assign', False ) out_ref = copy.deepcopy(self.x_np) for i in range(self.index1_shape[0]): @@ -299,15 +305,15 @@ def run(place): # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0, 'mul' + x_tensor, index_tensor1, value_tensor, 0, 'mul', False ) paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0, 'add' + x_tensor, index_tensor1, value_tensor, 0, 'add', False ) index_tensor2 = paddle.to_tensor(self.index_np2) out = paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1 + x_tensor, index_tensor2, value_tensor, 1, 'assign', False ) out_ref = copy.deepcopy(self.x_np) for i in range(self.index2_shape[0]): @@ -317,10 +323,10 @@ def run(place): # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1, 'mul' + x_tensor, index_tensor2, value_tensor, 1, 'mul', False ) paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1, 'add' + x_tensor, index_tensor2, value_tensor, 1, 'add', False ) paddle.enable_static() @@ -337,7 +343,9 @@ def run(place): x1 = paddle.static.data('X', self.shape) index1 = paddle.static.data('Index', self.index1_shape, "int64") value_tensor = paddle.to_tensor(self.value) - out1 = paddle.put_along_axis(x1, index1, value_tensor, 0) + out1 = paddle.put_along_axis( + x1, index1, value_tensor, 0, 'assign', False + ) exe = paddle.static.Executor(place) res = exe.run( feed={ @@ -359,7 +367,9 @@ def run(place): x2 = paddle.static.data('X', self.shape) index2 = paddle.static.data('Index', self.index2_shape, "int64") value_tensor = paddle.to_tensor(self.value) - out2 = paddle.put_along_axis(x2, index2, value_tensor, 1) + out2 = paddle.put_along_axis( + x2, index2, value_tensor, 1, 'assign', False + ) exe = paddle.static.Executor(place) res = exe.run( feed={ @@ -386,13 +396,17 @@ def test_error(self): values = paddle.to_tensor([2]) # len(arr.shape) != len(indices.shape) try: - res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', False + ) except Exception as error: self.assertIsInstance(error, ValueError) indices = paddle.to_tensor([[1]]).astype("int32") # len(values.shape) != len(indices.shape) try: - res = paddle.put_along_axis(tensorx, indices, values, 0) + res = paddle.put_along_axis( + tensorx, indices, values, 0, 'assign', False + ) except Exception as error: self.assertIsInstance(error, ValueError) indices = paddle.to_tensor( @@ -400,13 +414,17 @@ def test_error(self): ).astype("int32") # indices too large try: - res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', False + ) except Exception as error: self.assertIsInstance(error, RuntimeError) indices = paddle.to_tensor([[10]]).astype("int32") # the element of indices out of range try: - res = paddle.put_along_axis(tensorx, indices, 1.0, 0) + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', False + ) except Exception as error: self.assertIsInstance(error, RuntimeError) diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 6d6ed74fe4f224..1e0e173e1e3ff8 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -163,17 +163,9 @@ def test_api_static(self): res = exe.run( feed={'X': self.x_np, 'Index': self.index_np}, fetch_list=[out] ) - out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) - if self.axis == 0: - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[i, j] = self.x_np[self.index_np[i, j], j] - elif self.axis == 1: - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[i, j] = self.x_np[i, self.index_np[i, j]] - else: - return + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis) + ) for out in res: np.testing.assert_allclose(out, out_ref, rtol=0.001) @@ -182,17 +174,9 @@ def test_api_dygraph(self): x_tensor = paddle.to_tensor(self.x_np) self.index = paddle.to_tensor(self.index_np) out = paddle.take_along_axis(x_tensor, self.index, self.axis) - out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) - if self.axis == 0: - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[i, j] = self.x_np[self.index_np[i, j], j] - elif self.axis == 1: - for i in range(self.index_shape[0]): - for j in range(self.index_shape[1]): - out_ref[i, j] = self.x_np[i, self.index_np[i, j]] - else: - return + out_ref = np.array( + np.take_along_axis(self.x_np, self.index_np, self.axis) + ) np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) paddle.enable_static() @@ -226,18 +210,48 @@ def setUp(self): self.place.append(paddle.CUDAPlace(0)) -class TestTakeAlongAxisAPICase2(TestTakeAlongAxisAPI): +class TestTakeAlongAxisAPICase2(unittest.TestCase): def setUp(self): np.random.seed(0) - self.shape = [2, 2] - self.index_shape = [1, 1] - self.index_np = np.array([[1]]).astype('int64') + self.shape = [3, 3] + self.index_shape = [1, 3] + self.index_np = np.array([[0, 1, 2]]).astype('int64') self.x_np = np.random.random(self.shape).astype(np.float32) self.place = [paddle.CPUPlace()] - self.axis = 1 + self.axis = 0 if core.is_compiled_with_cuda(): self.place.append(paddle.CUDAPlace(0)) + @test_with_pir_api + def test_api_static(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape) + index = paddle.static.data('Index', self.index_shape, "int64") + out = paddle.take_along_axis(x, index, self.axis, False) + exe = paddle.static.Executor(self.place[0]) + res = exe.run( + feed={'X': self.x_np, 'Index': self.index_np}, fetch_list=[out] + ) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + def test_api_dygraph(self): + paddle.disable_static(self.place[0]) + x_tensor = paddle.to_tensor(self.x_np) + self.index = paddle.to_tensor(self.index_np) + out = paddle.take_along_axis(x_tensor, self.index, self.axis, False) + out_ref = np.zeros_like(self.index_np, dtype=self.x_np.dtype) + for i in range(self.index_shape[0]): + for j in range(self.index_shape[1]): + out_ref[i, j] = self.x_np[self.index_np[i, j], j] + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + paddle.enable_static() + if __name__ == "__main__": paddle.enable_static() From efc488c3e626a5617db665eecf5feb71fa2fc8d8 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Wed, 22 Nov 2023 14:16:30 +0000 Subject: [PATCH 08/14] use origin example --- python/paddle/tensor/manipulation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index ddd5ed7745a42f..430d59dad01429 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5243,9 +5243,9 @@ def put_along_axis(arr, indices, values, axis, reduce='assign', broadcast=True): >>> axis = 0 >>> result = paddle.put_along_axis(x, index, value, axis) >>> print(result) - Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, - [[99, 99, 99], - [60, 40, 50]]) + Tensor(shape=[2, 3], dtype=int64, place=Place(cpu), stop_gradient=True, + [[99, 99, 99], + [60, 40, 50]]) """ if len(arr.shape) != len(indices.shape): From 675b641fac4dec5ce8f486ebdbdb7773205e3858 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Sat, 25 Nov 2023 03:08:34 +0000 Subject: [PATCH 09/14] add param include_self --- python/paddle/tensor/manipulation.py | 15 ++++++++-- test/legacy_test/test_put_along_axis_op.py | 32 ++++++++++++++-------- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 430d59dad01429..81137ebb01fc58 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -5217,7 +5217,15 @@ def take_along_axis(arr, indices, axis, broadcast=True): return result -def put_along_axis(arr, indices, values, axis, reduce='assign', broadcast=True): +def put_along_axis( + arr, + indices, + values, + axis, + reduce='assign', + include_self=True, + broadcast=True, +): """ Put values into the destination array by given indices matrix along the designated axis. @@ -5227,7 +5235,8 @@ def put_along_axis(arr, indices, values, axis, reduce='assign', broadcast=True): and need to broadcast against arr. Supported data type are int and int64. axis (int) : The axis to put 1d slices along. reduce (str, optional): The reduce operation, default is 'assign', support 'add', 'assign', 'mul' and 'multiply'. - broadcast (bool, optional): whether the indices broadcast. + include_self (bool, optional): whether to reduce with the elements of arr. (Only support True now) + broadcast (bool, optional): whether to broadcast indices. Returns: Tensor, The indexed element, same dtype with arr @@ -5248,6 +5257,8 @@ def put_along_axis(arr, indices, values, axis, reduce='assign', broadcast=True): [60, 40, 50]]) """ + if not include_self: + raise ValueError("`include_self` is only support True now.") if len(arr.shape) != len(indices.shape): raise ValueError( "`indices` and `arr` must have the same number of dimensions!" diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index eafaf9701e55cf..bc6f60296831c0 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -295,7 +295,7 @@ def run(place): index_tensor1 = paddle.to_tensor(self.index_np1) value_tensor = paddle.to_tensor(self.value) out = paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0, 'assign', False + x_tensor, index_tensor1, value_tensor, 0, 'assign', True, False ) out_ref = copy.deepcopy(self.x_np) for i in range(self.index1_shape[0]): @@ -305,15 +305,15 @@ def run(place): # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0, 'mul', False + x_tensor, index_tensor1, value_tensor, 0, 'mul', True, False ) paddle.put_along_axis( - x_tensor, index_tensor1, value_tensor, 0, 'add', False + x_tensor, index_tensor1, value_tensor, 0, 'add', True, False ) index_tensor2 = paddle.to_tensor(self.index_np2) out = paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1, 'assign', False + x_tensor, index_tensor2, value_tensor, 1, 'assign', True, False ) out_ref = copy.deepcopy(self.x_np) for i in range(self.index2_shape[0]): @@ -323,10 +323,10 @@ def run(place): # for ci coverage, numpy put_along_axis did not support argument of 'reduce' paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1, 'mul', False + x_tensor, index_tensor2, value_tensor, 1, 'mul', True, False ) paddle.put_along_axis( - x_tensor, index_tensor2, value_tensor, 1, 'add', False + x_tensor, index_tensor2, value_tensor, 1, 'add', True, False ) paddle.enable_static() @@ -344,7 +344,7 @@ def run(place): index1 = paddle.static.data('Index', self.index1_shape, "int64") value_tensor = paddle.to_tensor(self.value) out1 = paddle.put_along_axis( - x1, index1, value_tensor, 0, 'assign', False + x1, index1, value_tensor, 0, 'assign', True, False ) exe = paddle.static.Executor(place) res = exe.run( @@ -368,7 +368,7 @@ def run(place): index2 = paddle.static.data('Index', self.index2_shape, "int64") value_tensor = paddle.to_tensor(self.value) out2 = paddle.put_along_axis( - x2, index2, value_tensor, 1, 'assign', False + x2, index2, value_tensor, 1, 'assign', True, False ) exe = paddle.static.Executor(place) res = exe.run( @@ -397,7 +397,7 @@ def test_error(self): # len(arr.shape) != len(indices.shape) try: res = paddle.put_along_axis( - tensorx, indices, 1.0, 0, 'assign', False + tensorx, indices, 1.0, 0, 'assign', True, False ) except Exception as error: self.assertIsInstance(error, ValueError) @@ -405,7 +405,7 @@ def test_error(self): # len(values.shape) != len(indices.shape) try: res = paddle.put_along_axis( - tensorx, indices, values, 0, 'assign', False + tensorx, indices, values, 0, 'assign', True, False ) except Exception as error: self.assertIsInstance(error, ValueError) @@ -415,7 +415,7 @@ def test_error(self): # indices too large try: res = paddle.put_along_axis( - tensorx, indices, 1.0, 0, 'assign', False + tensorx, indices, 1.0, 0, 'assign', True, False ) except Exception as error: self.assertIsInstance(error, RuntimeError) @@ -423,11 +423,19 @@ def test_error(self): # the element of indices out of range try: res = paddle.put_along_axis( - tensorx, indices, 1.0, 0, 'assign', False + tensorx, indices, 1.0, 0, 'assign', True, False ) except Exception as error: self.assertIsInstance(error, RuntimeError) + # use includ_self=False + try: + res = paddle.put_along_axis( + tensorx, indices, 1.0, 0, 'assign', False + ) + except Exception as error: + self.assertIsInstance(error, ValueError) + if __name__ == "__main__": paddle.enable_static() From 2c968e381c66540216b85d90ee6b30e84bed166f Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Wed, 29 Nov 2023 16:03:25 +0000 Subject: [PATCH 10/14] update param name --- .../kernels/funcs/gather_scatter_functor.cc | 39 +++++++++---------- .../kernels/funcs/gather_scatter_functor.h | 6 +-- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc index e4f7864cb34e1f..2d4a4e0043ba37 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -193,26 +193,26 @@ template void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, int dim, const phi::DenseTensor& index, - phi::DenseTensor output, + phi::DenseTensor grad, const phi::DeviceContext& ctx UNUSED) { auto* index_data = index.data(); - auto* output_data = output.data(); + auto* grad_data = grad.data(); auto index_dims = index.dims(); - auto output_dims = output.dims(); + auto grad_dims = grad.dims(); int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; int64_t outer_dim_size_data = 1; int64_t select_dim_size = index_dims[dim]; - int64_t output_select_dim_size = output_dims[dim]; + int64_t grad_select_dim_size = grad_dims[dim]; for (int i = 0; i < dim; ++i) { inner_dim_size *= index_dims[i]; } for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; - outer_dim_size_data *= output_dims[i]; + outer_dim_size_data *= grad_dims[i]; } int64_t index_idx = 0; @@ -220,10 +220,9 @@ void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED, for (int64_t j = 0; j < select_dim_size; j++) { for (int64_t k = 0; k < outer_dim_size; k++) { int64_t index = index_data[index_idx]; - int64_t replace_index = - k + index * outer_dim_size_data + - i * outer_dim_size_data * output_select_dim_size; - output_data[replace_index] = 0; + int64_t replace_index = k + index * outer_dim_size_data + + i * outer_dim_size_data * grad_select_dim_size; + grad_data[replace_index] = 0; index_idx++; } } @@ -234,15 +233,15 @@ template void cpu_scatter_value_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor output, + phi::DenseTensor grad, const phi::DeviceContext& ctx UNUSED) { auto* self_data = self.data(); auto* index_data = index.data(); - auto* output_data = output.data(); + auto* grad_data = grad.data(); auto index_dims = index.dims(); auto self_dims = self.dims(); - auto output_dims = output.dims(); + auto grad_dims = grad.dims(); int64_t self_size = self.numel(); bool* is_self_grad_used = new bool[self_size]; @@ -254,10 +253,10 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, int64_t inner_dim_size = 1; int64_t outer_dim_size = 1; int64_t outer_dim_size_self = 1; - int64_t outer_dim_size_output = 1; + int64_t outer_dim_size_grad = 1; int64_t select_dim_size = index_dims[dim]; int64_t self_select_dim_size = self_dims[dim]; - int64_t output_select_dim_size = output_dims[dim]; + int64_t grad_select_dim_size = grad_dims[dim]; for (int i = 0; i < dim; ++i) { inner_dim_size *= index_dims[i]; } @@ -265,7 +264,7 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, for (int i = dim + 1; i < index_dims.size(); i++) { outer_dim_size *= index_dims[i]; outer_dim_size_self *= self_dims[i]; - outer_dim_size_output *= output_dims[i]; + outer_dim_size_grad *= grad_dims[i]; } int64_t index_idx = index.numel() - 1; @@ -276,14 +275,14 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, int64_t replace_index_self = k + index * outer_dim_size_self + i * outer_dim_size_self * self_select_dim_size; - int64_t replace_index_output = - k + j * outer_dim_size_output + - i * outer_dim_size_output * output_select_dim_size; + int64_t replace_index_grad = + k + j * outer_dim_size_grad + + i * outer_dim_size_grad * grad_select_dim_size; if (!is_self_grad_used[replace_index_self]) { - output_data[replace_index_output] = self_data[replace_index_self]; + grad_data[replace_index_grad] = self_data[replace_index_self]; is_self_grad_used[replace_index_self] = true; } else { - output_data[replace_index_output] = 0; + grad_data[replace_index_grad] = 0; } index_idx--; } diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.h b/paddle/phi/kernels/funcs/gather_scatter_functor.h index 65930a1a11e1ed..054ccc196fcd00 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.h +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.h @@ -75,14 +75,14 @@ template void cpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor result, + phi::DenseTensor grad, const phi::DeviceContext& ctx); template void cpu_scatter_value_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor output, + phi::DenseTensor grad, const phi::DeviceContext& ctx); template @@ -117,7 +117,7 @@ template void gpu_scatter_input_grad_kernel(phi::DenseTensor self, int dim, const phi::DenseTensor& index, - phi::DenseTensor result, + phi::DenseTensor grad, const phi::DeviceContext& ctx); template From d32db6f3d32accf32b52804472d9dd99b8d74962 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Thu, 30 Nov 2023 10:54:47 +0800 Subject: [PATCH 11/14] modify ut --- test/legacy_test/test_put_along_axis_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index bc6f60296831c0..390c5a6ab9377f 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -74,7 +74,7 @@ def init_data(self): self.x_shape = (10, 10, 10) self.value_type = "float16" self.value = np.array([99]).astype(self.value_type) - self.index_type = "int32" + self.index_type = "int64" self.index = np.array([[[0]]]).astype(self.index_type) self.axis = 1 self.axis_type = "int64" From 564de930d41c82829d6ff0cbb25dbce734be4aa2 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Thu, 30 Nov 2023 12:52:54 +0000 Subject: [PATCH 12/14] update test case --- .../kernels/cpu/put_along_axis_grad_kernel.cc | 2 +- .../kernels/funcs/gather_scatter_functor.cc | 7 ++-- .../kernels/gpu/put_along_axis_grad_kernel.cu | 2 +- test/legacy_test/prim_op_test.py | 5 ++- test/legacy_test/test_put_along_axis_op.py | 40 +++++++++++++++++++ test/legacy_test/test_take_along_axis_op.py | 28 +++++++++++++ 6 files changed, 78 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc index f0a1118ca92d7d..dd7b762849d16b 100644 --- a/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc @@ -62,7 +62,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, if (index_type == DataType::INT32) { phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); - } else if (index_type == DataType::INT64) { + } else { phi::funcs::cpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } diff --git a/paddle/phi/kernels/funcs/gather_scatter_functor.cc b/paddle/phi/kernels/funcs/gather_scatter_functor.cc index 2d4a4e0043ba37..be07c68b0fd338 100644 --- a/paddle/phi/kernels/funcs/gather_scatter_functor.cc +++ b/paddle/phi/kernels/funcs/gather_scatter_functor.cc @@ -244,6 +244,7 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, auto grad_dims = grad.dims(); int64_t self_size = self.numel(); + int64_t grad_size = grad.numel(); bool* is_self_grad_used = new bool[self_size]; for (int i = 0; i < self_size; i++) { @@ -266,8 +267,10 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, outer_dim_size_self *= self_dims[i]; outer_dim_size_grad *= grad_dims[i]; } - int64_t index_idx = index.numel() - 1; + for (int i = 0; i < grad_size; i++) { + grad_data[i] = static_cast(0); + } for (int64_t i = inner_dim_size - 1; i >= 0; i--) { for (int64_t j = select_dim_size - 1; j >= 0; j--) { for (int64_t k = outer_dim_size - 1; k >= 0; k--) { @@ -281,8 +284,6 @@ void cpu_scatter_value_grad_kernel(phi::DenseTensor self, if (!is_self_grad_used[replace_index_self]) { grad_data[replace_index_grad] = self_data[replace_index_self]; is_self_grad_used[replace_index_self] = true; - } else { - grad_data[replace_index_grad] = 0; } index_idx--; } diff --git a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu index ab380351610adf..d86e0493786ebd 100644 --- a/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu @@ -54,7 +54,7 @@ void PutAlongAxisGradKernel(const Context& dev_ctx, if (index_type == DataType::INT32) { phi::funcs::gpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); - } else if (index_type == DataType::INT64) { + } else { phi::funcs::gpu_scatter_value_grad_kernel( out_grad, axis, index, *value_grad, dev_ctx); } diff --git a/test/legacy_test/prim_op_test.py b/test/legacy_test/prim_op_test.py index 743a856058a6f0..d1680c33e6c924 100644 --- a/test/legacy_test/prim_op_test.py +++ b/test/legacy_test/prim_op_test.py @@ -192,7 +192,10 @@ def parse_attri_value(name, op_inputs, op_proto_attrs): tmp = input_arguments[idx_of_op_proto_arguments] idx_of_op_proto_arguments += 1 else: - tmp = Empty() # use the default value + # tmp = Empty() # use the default value + tmp = parse_attri_value( + arg_name, op_proto_ins, op_proto_attrs + ) if isinstance(tmp, Empty): results.append(get_default(idx, api_defaults)) diff --git a/test/legacy_test/test_put_along_axis_op.py b/test/legacy_test/test_put_along_axis_op.py index 390c5a6ab9377f..43d2e80c25e24a 100644 --- a/test/legacy_test/test_put_along_axis_op.py +++ b/test/legacy_test/test_put_along_axis_op.py @@ -80,6 +80,46 @@ def init_data(self): self.axis_type = "int64" +class TestPutAlongAxisOpCase2(TestPutAlongAxisOp): + def setUp(self): + self.init_data() + self.reduce_op = "assign" + self.op_type = "put_along_axis" + self.python_api = paddle.tensor.put_along_axis + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + # numpy put_along_axis is an inplace operation. + self.target = copy.deepcopy(self.xnp) + for i in range(5): + for j in range(5): + for k in range(5): + self.target[i, self.index[i, j, k], k] = self.value[i, j, k] + self.inputs = { + 'Input': self.xnp, + 'Index': self.index, + 'Value': self.value, + } + self.attrs = { + 'Axis': self.axis, + 'Reduce': self.reduce_op, + 'include_self': True, + 'broadcast': False, + } + self.outputs = {'Result': self.target} + + def init_data(self): + self.dtype = 'float32' + self.x_type = "float32" + self.x_shape = (10, 10, 10) + self.value_type = "float32" + self.value = ( + np.arange(1, 126).reshape((5, 5, 5)).astype(self.value_type) + ) + self.index_type = "int64" + self.index = np.zeros((5, 5, 5)).astype(self.index_type) + self.axis = 1 + self.axis_type = "int64" + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 1e0e173e1e3ff8..1d5227ad4ab582 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -76,6 +76,34 @@ def init_data(self): self.axis_type = "int64" +class TestTakeAlongAxisOp(OpTest): + def setUp(self): + self.init_data() + self.op_type = "take_along_axis" + self.python_api = paddle.tensor.take_along_axis + self.check_cinn = True + self.xnp = np.random.random(self.x_shape).astype(self.x_type) + self.target = np.zeros((2, 3, 4)).astype(self.x_type) + for i in range(2): + for j in range(3): + for k in range(4): + self.target[i, j, k] = self.xnp[i, j, self.index[i, j, k]] + self.inputs = { + 'Input': self.xnp, + 'Index': self.index, + } + self.attrs = {'Axis': self.axis, 'broadcast': False} + self.outputs = {'Result': self.target} + + def init_data(self): + self.x_type = "float64" + self.x_shape = (10, 10, 10) + self.index_type = "int64" + self.index = np.random.randint(0, 10, (2, 3, 4)).astype(self.index_type) + self.axis = 2 + self.axis_type = "int64" + + @unittest.skipIf( not core.is_compiled_with_cuda() or not core.is_bfloat16_supported(core.CUDAPlace(0)), From d0c14de6dc278f62d9bba152f1de8415dc70e76a Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Fri, 1 Dec 2023 02:28:24 +0000 Subject: [PATCH 13/14] add error UT --- test/legacy_test/test_take_along_axis_op.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 1d5227ad4ab582..90bdf0c4c57b4c 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -280,6 +280,21 @@ def test_api_dygraph(self): np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) paddle.enable_static() + def test_error(self): + tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32") + indices = paddle.to_tensor([1]).astype("int32") + # len(arr.shape) != len(indices.shape) + try: + res = paddle.take_along_axis(tensorx, indices, 0, False) + except Exception as error: + self.assertIsInstance(error, ValueError) + indices = paddle.to_tensor([[10]]).astype("int32") + # the element of indices out of range + try: + res = paddle.take_along_axis(tensorx, indices, 0, False) + except Exception as error: + self.assertIsInstance(error, RuntimeError) + if __name__ == "__main__": paddle.enable_static() From 28aadd6d8e2f6c0b5929a1e59acd5eafb3f67e75 Mon Sep 17 00:00:00 2001 From: YibinLiu666 <2632839426@qq.com> Date: Fri, 1 Dec 2023 05:52:57 +0000 Subject: [PATCH 14/14] update --- test/legacy_test/test_take_along_axis_op.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_take_along_axis_op.py b/test/legacy_test/test_take_along_axis_op.py index 90bdf0c4c57b4c..dc73ae34aeea43 100644 --- a/test/legacy_test/test_take_along_axis_op.py +++ b/test/legacy_test/test_take_along_axis_op.py @@ -281,19 +281,22 @@ def test_api_dygraph(self): paddle.enable_static() def test_error(self): + paddle.disable_static(self.place[0]) tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32") indices = paddle.to_tensor([1]).astype("int32") # len(arr.shape) != len(indices.shape) - try: + with self.assertRaises(ValueError): res = paddle.take_along_axis(tensorx, indices, 0, False) - except Exception as error: - self.assertIsInstance(error, ValueError) - indices = paddle.to_tensor([[10]]).astype("int32") # the element of indices out of range - try: + with self.assertRaises(RuntimeError): + indices = paddle.to_tensor([[100]]).astype("int32") + res = paddle.take_along_axis(tensorx, indices, 0, False) + # the shape of indices doesn't match + with self.assertRaises(RuntimeError): + indices = paddle.to_tensor( + [[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]] + ).astype("int32") res = paddle.take_along_axis(tensorx, indices, 0, False) - except Exception as error: - self.assertIsInstance(error, RuntimeError) if __name__ == "__main__":