From e9a45718d3bb73a912ec633e357426d54ae42c91 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Wed, 25 Oct 2023 01:09:44 +0800 Subject: [PATCH 1/2] pir leaky_relu & swish --- python/paddle/nn/functional/activation.py | 4 ++-- test/legacy_test/test_activation_op.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index c74748793a4e9..5c73c9380e2b1 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -484,7 +484,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): [-0.02000000, 0. , 1. ]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.leaky_relu(x, negative_slope) else: check_variable_and_dtype( @@ -1448,7 +1448,7 @@ def swish(x, name=None): Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.23840584, 0. , 0.73105860]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.swish(x) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index d1da7d941a679..e49b6eba5b91e 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -2466,12 +2466,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_pir=True) class TestLeakyReluAlpha1(TestLeakyRelu): @@ -2508,6 +2508,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -2538,6 +2539,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_errors(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4407,6 +4409,7 @@ def test_check_grad(self): self.check_grad( ['X'], 'Out', + check_pir=True, ) @@ -4426,6 +4429,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4449,6 +4453,7 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) + @test_with_pir_api def test_base_api(self): with static_guard(): with base.program_guard(base.Program()): @@ -4459,6 +4464,7 @@ def test_base_api(self): out_ref = ref_swish(self.x_np) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + @test_with_pir_api def test_errors(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): From 894d54d9442403c9cfc4e0fe9ec0b9490ba0a9a2 Mon Sep 17 00:00:00 2001 From: Ryan <44900829+DrRyanHuang@users.noreply.github.com> Date: Wed, 25 Oct 2023 15:08:40 +0000 Subject: [PATCH 2/2] remove test_errors --- test/legacy_test/test_activation_op.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index e49b6eba5b91e..df45d4d651136 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -2539,7 +2539,6 @@ def test_dygraph_api(self): for r in [out1, out2]: np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05) - @test_with_pir_api def test_errors(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4464,7 +4463,6 @@ def test_base_api(self): out_ref = ref_swish(self.x_np) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) - @test_with_pir_api def test_errors(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()):