[Cherry-Pick] Support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy (#53199)

* support output 0D for is_empty/as_complex/inner/dot/rank/tensordot/squeeze_/static.accuracy/static.auc/metric.accuracy

* test_dot_py

* test_dot_py
ROckDog22 authored Apr 27, 2023
1 parent 7b4badb commit f84ac44
Showing 15 changed files with 414 additions and 42 deletions.
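For context, the user-visible effect of these changes is that the ops listed in the title return 0D (scalar) tensors instead of shape-[1] tensors. A minimal dygraph sketch of the intended behavior (assuming a Paddle build that includes this commit):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([4.0, 5.0, 6.0])

out = paddle.dot(x, y)        # 1-D inputs -> scalar result
print(out.shape)              # [] (previously [1])
print(paddle.rank(x).shape)   # [] -- rank is now a 0D int tensor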
5 changes: 3 additions & 2 deletions paddle/phi/infermeta/binary.cc
@@ -1152,8 +1152,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
"with input tensor Y: %s",
x_dims.to_str(),
y_dims.to_str()));

- x_dims[x_dims.size() - 1] = 1;
+ std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
+ std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 1);
+ x_dims = phi::make_ddim(x_dims_vec_cut);
out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
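Net effect of the DotInferMeta change above: the output shape is the input shape with the last (reduced) axis dropped, instead of that axis being set to 1. A rough Python sketch of the new shape rule (dot_out_shape is a hypothetical helper, not Paddle API):

def dot_out_shape(x_shape):
    # keep every axis except the last one, which dot reduces over
    return list(x_shape[:-1])

assert dot_out_shape([12]) == []        # 1-D dot -> 0D scalar output
assert dot_out_shape([11, 12]) == [11]  # batched dot: was [11, 1] before this change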
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -479,7 +479,7 @@ void AucInferMeta(const MetaTensor& input,
0,
phi::errors::InvalidArgument("slide_steps must be natural number"));

- auc->set_dims({1});
+ auc->set_dims(phi::make_ddim({}));
auc->set_dtype(DataType::INT64);

if (slide_steps) {
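With AucInferMeta updated above, the AUC output of paddle.static.auc is declared as a 0D tensor rather than shape [1]. A hedged static-graph sketch (program/executor boilerplate omitted; predict and label are placeholder variables of shape [N, 2] and [N, 1]):

auc_out = paddle.static.auc(input=predict, label=label)[0]
print(auc_out.shape)   # () -- 0D after this change, previously (1,)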
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/ternary.cc
@@ -66,11 +66,11 @@ void AccuracyInferMeta(const MetaTensor& out,
label_dim[0]));
}

- accuracy->set_dims({1});
+ accuracy->set_dims(phi::make_ddim({}));
+ correct->set_dims(phi::make_ddim({}));
+ total->set_dims(phi::make_ddim({}));
accuracy->set_dtype(out.dtype());
- correct->set_dims({1});
correct->set_dtype(out.dtype());
- total->set_dims({1});
total->set_dtype(out.dtype());
accuracy->share_lod(out);
}
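The AccuracyInferMeta change declares accuracy, correct, and total as 0D. A dygraph sketch of the corresponding metric.accuracy behavior (toy values, assuming this commit):

import paddle

pred = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2]])   # [N, num_classes] scores
label = paddle.to_tensor([[1], [0]])                 # [N, 1] int64 labels
acc = paddle.metric.accuracy(input=pred, label=label, k=1)
print(acc.shape)    # [] -- 0D, previously [1]
print(float(acc))   # 1.0 for this toy batch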
2 changes: 1 addition & 1 deletion paddle/phi/infermeta/unary.cc
@@ -1839,7 +1839,7 @@ void InverseInferMeta(const MetaTensor& x, MetaTensor* out) {
}

void IsEmptyInferMeta(const MetaTensor& x, MetaTensor* out) {
- out->set_dims(phi::make_ddim({1}));
+ out->set_dims(phi::make_ddim({}));
out->set_dtype(DataType::BOOL);
}
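Correspondingly, paddle.is_empty now yields a 0D bool tensor. A small illustrative sketch:

import paddle

x = paddle.rand([3, 4])
flag = paddle.is_empty(x)
print(flag.shape)   # [] -- 0D bool, previously [1]
print(bool(flag))   # False, since x has elements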

4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/dot_kernel.cu
@@ -32,15 +32,15 @@ void DotKernel(const Context& dev_ctx,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
- if (1 == out->dims().size()) {
+ if (out->dims().size() == 0) {
auto eigen_out = phi::EigenScalar<T>::From(*out);
auto eigen_x = phi::EigenVector<T>::Flatten(x);
auto eigen_y = phi::EigenVector<T>::Flatten(y);

auto& dev = *dev_ctx.eigen_device();
eigen_out.device(dev) = (eigen_x * eigen_y).sum();
} else {
- auto eigen_out = phi::EigenMatrix<T>::From(*out);
+ auto eigen_out = phi::EigenVector<T>::From(*out);
auto eigen_x = phi::EigenMatrix<T>::From(x);
auto eigen_y = phi::EigenMatrix<T>::From(y);
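The kernel above now dispatches on a 0D output: the first branch does a full reduction into an Eigen scalar, while the else branch reduces over the last axis only. A NumPy stand-in for the two branches (illustrative, not the actual CUDA code):

import numpy as np

def dot_ref(x, y):
    if x.ndim == 1:                  # 0D-output branch: reduce everything to a scalar
        return np.sum(x * y)
    return np.sum(x * y, axis=-1)    # batched branch: one value per row, shape [B]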

12 changes: 6 additions & 6 deletions paddle/phi/kernels/impl/dot_grad_kernel_impl.h
@@ -46,7 +46,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
DenseTensor* tensor_dy) {
VLOG(1) << "enable route";
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == tensor_dout->dims().size()) {
+ if (1 >= tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout);

if (tensor_dx) {
@@ -144,7 +144,7 @@ struct DotGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
DenseTensor* tensor_dx,
DenseTensor* tensor_dy) {
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == tensor_dout->dims().size()) {
+ if (1 >= tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = EigenVector<T>::Flatten(*tensor_y);
@@ -236,7 +236,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == tensor_dout->dims().size()) {
+ if (1 >= tensor_dout->dims().size()) {
DenseTensor tensor_dout_help;
auto& dev = *ctx.eigen_device();
if (tensor_dx || tensor_dy) {
@@ -431,7 +431,7 @@ struct DotDoubleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
const DenseTensor* tensor_ddx = tensor_ddx_opt->get_ptr();
const DenseTensor* tensor_ddy = tensor_ddy_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == tensor_dout->dims().size()) {
+ if (1 >= tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device();
auto x = EigenVector<T>::Flatten(*tensor_x);
auto y = EigenVector<T>::Flatten(*tensor_y);
@@ -621,7 +621,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::EnableComplex<T>> {
const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == in_tensor_dout->dims().size()) {
+ if (1 >= in_tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device();
DenseTensor in_tensor_x_help = Conj<T, DeviceContext>(ctx, *in_tensor_x);
DenseTensor in_tensor_y_help = Conj<T, DeviceContext>(ctx, *in_tensor_y);
@@ -1015,7 +1015,7 @@ struct DotTripleGradFunction<DeviceContext, T, phi::funcs::DisableComplex<T>> {
const DenseTensor* in_tensor_d_dy = in_tensor_d_dy_opt->get_ptr();
const DenseTensor* in_tensor_d_ddout = in_tensor_d_ddout_opt->get_ptr();
#if defined(__NVCC__) || defined(__HIPCC__)
- if (1 == in_tensor_dout->dims().size()) {
+ if (1 >= in_tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device();
bool d_dout_flag = false;
bool d_ddx_flag = false;
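The grad functors above are relaxed from dims().size() == 1 to <= 1, so the same flattened-Eigen path also accepts a 0D dout. The backward rule they implement, as a NumPy sketch (dot_grad_ref is a hypothetical helper; the conjugates only matter for complex dtypes):

import numpy as np

def dot_grad_ref(x, y, grad_out):
    # grad_out is 0D for 1-D x/y, or shape [B] for batched [B, K] inputs
    g = np.expand_dims(grad_out, -1) if np.ndim(grad_out) > 0 else grad_out
    return g * np.conj(y), g * np.conj(x)   # grad_x, grad_y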
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/check_nan_inf_base.py
@@ -97,7 +97,7 @@ def check(use_cuda):
step += 1
print(
'iter={:.0f},cost={},acc1={}'.format(
- step, outs[1][0], outs[2][0]
+ step, outs[1][0], outs[2]
)
)

19 changes: 8 additions & 11 deletions python/paddle/fluid/tests/unittests/test_dot_op.py
@@ -106,8 +106,7 @@ def test_2d_input(self):
x = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
y = paddle.to_tensor(np.reshape(data, [0, 0]), dtype='float32')
pd_out = paddle.dot(x, y)

- self.assertEqual(pd_out.shape, (0, 1))
+ self.assertEqual(pd_out.shape, (0,))

def test_3d_input_error(self):
data = np.array([], dtype=np.float32)
@@ -127,7 +126,7 @@ def init_input_output(self):
self.y = (
np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
)
- self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+ self.out = np.sum(self.x * self.y, axis=1)

def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
@@ -180,7 +179,7 @@ def test_dygraph(self):
np.array([[2, 5], [6, 8]]).astype(np.float32)
)
np.testing.assert_array_equal(
- paddle.dot(x1, y1).numpy(), np.array([[17], [58]])
+ paddle.dot(x1, y1).numpy(), np.array([17, 58])
)


@@ -211,7 +210,7 @@ def init_input_output(self):
self.out = np.dot(self.x, self.y)

def init_grad_input_output(self):
- self.grad_out = np.ones(1, self.dtype) + 1j * np.ones(1, self.dtype)
+ self.grad_out = np.ones([], self.dtype) + 1j * np.ones([], self.dtype)
self.grad_x = self.grad_out * np.conj(self.y)
self.grad_y = self.grad_out * np.conj(self.x)

@@ -269,12 +268,10 @@ def init_input_output(self):
self.y = np.random.random((2, 100)).astype(
self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype)
- self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1, 1)
+ self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)

def init_grad_input_output(self):
- self.grad_out = np.ones((2, 1), self.dtype) + 1j * np.ones(
-     (2, 1), self.dtype
- )
+ self.grad_out = np.ones((2), self.dtype) + 1j * np.ones((2), self.dtype)
self.grad_x = self._get_grad(self.grad_out, self.y)
self.grad_y = self._get_grad(self.grad_out, self.x)

@@ -381,7 +378,7 @@ def init_input_output(self):
self.y = (
np.random.uniform(1, 3, [132]).astype(self.dtype).reshape([11, 12])
)
- self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+ self.out = np.sum(self.x * self.y, axis=1)


@unittest.skipIf(
@@ -468,7 +465,7 @@ def init_input_output(self):
self.y = (
np.random.uniform(1, 3, [132]).astype(np.float32).reshape([11, 12])
)
- self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])
+ self.out = np.sum(self.x * self.y, axis=1)

def test_check_grad_normal(self):
if core.is_compiled_with_cuda():
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/test_layers.py
@@ -1426,8 +1426,9 @@ def test_accuracy(self):
exe.run(fluid.default_startup_program())
# x = np.random.rand(3, 32, 32).astype("float32")
# y = np.array([[1], [0], [1]])

static_out = exe.run(
feed={"input": x, "label": y}, fetch_list=result[0]
feed={"input": x, "label": y}, fetch_list=result
)

with self.dynamic_graph(force_to_use_cpu=True):
8 changes: 6 additions & 2 deletions python/paddle/fluid/tests/unittests/test_nan_inf.py
@@ -50,11 +50,15 @@ def check_nan_inf(self):
assert (out + err).find(b'There are NAN or INF') != -1

def test_nan_inf_in_static_mode(self):
self._python_interp += " check_nan_inf_base.py"
self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base.py"
)
self.check_nan_inf()

def test_nan_inf_in_dynamic_mode(self):
self._python_interp += " check_nan_inf_base_dygraph.py"
self._python_interp += (
" " + os.path.dirname(__file__) + "/check_nan_inf_base_dygraph.py"
)
self.check_nan_inf()

