Skip to content

Commit

Permalink
【PIR API adaptor No.193-196】Migrate paddle.geometric.segment_max, pad…
Browse files Browse the repository at this point in the history
…dle.geometric.segment_mean, paddle.geometric.segment_min, paddle.geometric.segment_sum into pir (PaddlePaddle#58579)
  • Loading branch information
enkilee authored and SecretXV committed Nov 28, 2023
1 parent 1c7d20a commit f408c94
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
10 changes: 5 additions & 5 deletions python/paddle/geometric/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode

__all__ = []

Expand Down Expand Up @@ -52,7 +52,7 @@ def segment_sum(data, segment_ids, name=None):
[4. 5. 6.]]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "SUM")
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -111,7 +111,7 @@ def segment_mean(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "MEAN")
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -169,7 +169,7 @@ def segment_min(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "MIN")
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -227,7 +227,7 @@ def segment_max(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "MAX")
else:
check_variable_and_dtype(
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/incubate/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from paddle import _C_ops
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_or_pir_mode
from paddle.utils import deprecated

__all__ = []
Expand Down Expand Up @@ -66,7 +66,7 @@ def segment_sum(data, segment_ids, name=None):
[4., 5., 6.]])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "SUM")
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -135,7 +135,7 @@ def segment_mean(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "MEAN")

check_variable_and_dtype(
Expand Down Expand Up @@ -203,7 +203,7 @@ def segment_min(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.segment_pool(data, segment_ids, "MIN")

check_variable_and_dtype(
Expand Down Expand Up @@ -271,7 +271,7 @@ def segment_max(data, segment_ids, name=None):
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.segment_pool(data, segment_ids, "MAX")
return out

Expand Down
31 changes: 21 additions & 10 deletions test/legacy_test/test_segment_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def compute_segment_sum(x, segment_ids):
Expand Down Expand Up @@ -123,7 +124,7 @@ def setUp(self):
self.convert_bf16()

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(["X"], "Out")
Expand Down Expand Up @@ -165,7 +166,9 @@ def prepare(self):
self.attrs = {'pooltype': "MAX"}

def test_check_grad(self):
self.check_grad(["X"], "Out", user_defined_grads=[self.gradient])
self.check_grad(
["X"], "Out", user_defined_grads=[self.gradient], check_pir=True
)


class TestSegmentMax2(TestSegmentMax):
Expand Down Expand Up @@ -220,11 +223,11 @@ def setUp(self):

def test_check_output(self):
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0))
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
# due to CPU kernel not implement calculate 'SummedIds'
# so cannot check 'SummedIds'
del self.outputs['SummedIds']
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(core.CPUPlace(), check_pir=True)


class TestSegmentMean2(TestSegmentMean):
Expand Down Expand Up @@ -271,7 +274,7 @@ def prepare(self):
self.np_dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
Expand All @@ -289,11 +292,14 @@ def prepare(self):
self.np_dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(
self.place, ["X"], "Out", user_defined_grads=[self.gradient]
self.place,
["X"],
"Out",
user_defined_grads=[self.gradient],
)


Expand All @@ -309,11 +315,14 @@ def prepare(self):
self.np_dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(
self.place, ["X"], "Out", user_defined_grads=[self.gradient]
self.place,
["X"],
"Out",
user_defined_grads=[self.gradient],
)


Expand All @@ -329,13 +338,14 @@ def prepare(self):
self.np_dtype = np.float32

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")


class API_SegmentOpsTest(unittest.TestCase):
@test_with_pir_api
def test_static(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
Expand Down Expand Up @@ -389,6 +399,7 @@ def test_dygraph(self):


class API_GeometricSegmentOpsTest(unittest.TestCase):
@test_with_pir_api
def test_static(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
Expand Down

0 comments on commit f408c94

Please sign in to comment.