Skip to content

Commit

Permalink
[PIR]Migrate reindex_heter_graph to pir (#60197)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Dec 21, 2023
1 parent 95d7605 commit 308cfed
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/geometric/reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.framework import Variable
from paddle.base.layer_helper import LayerHelper
from paddle.framework import in_dynamic_mode
from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode

__all__ = []

Expand Down Expand Up @@ -212,7 +212,7 @@ def reindex_heter_graph(
True if value_buffer is not None and index_buffer is not None else False
)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
neighbors = paddle.concat(neighbors, axis=0)
count = paddle.concat(count, axis=0)
reindex_src, reindex_dst, out_nodes = _C_ops.reindex_graph(
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_graph_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api


class TestGraphReindex(unittest.TestCase):
Expand Down Expand Up @@ -448,6 +449,7 @@ def test_reindex_result_static(self):
)
np.testing.assert_allclose(self.out_nodes, out_nodes_2, rtol=1e-05)

@test_with_pir_api
def test_heter_reindex_result_static(self):
paddle.enable_static()
np_x = np.arange(5).astype("int64")
Expand Down

0 comments on commit 308cfed

Please sign in to comment.