Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Fix PIR Unittest No.6】Fix some test cases in PIR #66211

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 60 additions & 53 deletions test/deprecated/legacy_test/test_data_feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,61 +44,68 @@ def test_lod_level_0_converter(self):
self.assertTrue(True)

def test_lod_level_1_converter(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 1
# each sentence has a different number of words
sentences = paddle.static.data(
name='sentences', shape=[-1, 1], dtype='int64', lod_level=1
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([sentences, label], base.CPUPlace())

# lod = [[0, 3, 5, 9]]
# data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# label = [1] * len(data)
result = feeder.feed(
[([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])]
)

self.assertEqual(result['sentences'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [3, 1])
self.assertEqual(
result['sentences'].recursive_sequence_lengths(), [[3, 2, 4]]
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])
with paddle.pir_utils.OldIrGuard():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 1
# each sentence has a different number of words
sentences = paddle.static.data(
name='sentences', shape=[-1, 1], dtype='int64', lod_level=1
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([sentences, label], base.CPUPlace())

# lod = [[0, 3, 5, 9]]
# data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
# label = [1] * len(data)
result = feeder.feed(
[([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])]
)

self.assertEqual(result['sentences'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [3, 1])
self.assertEqual(
result['sentences'].recursive_sequence_lengths(),
[[3, 2, 4]],
)
self.assertEqual(
result['label'].recursive_sequence_lengths(), []
)

def test_lod_level_2_converter(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 2
# paragraphs -> sentences -> words
paragraphs = paddle.static.data(
name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([paragraphs, label], base.CPUPlace())

# lod = [[0, 2, 3], [0, 3, 5, 9]]
# data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
# label = [1] * len(data)
result = feeder.feed(
[([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])]
)

self.assertEqual(result['paragraphs'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(
result['paragraphs'].recursive_sequence_lengths(),
[[2, 1], [3, 2, 4]],
)
self.assertEqual(result['label'].recursive_sequence_lengths(), [])
with paddle.pir_utils.OldIrGuard():
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
# lod_level = 2
# paragraphs -> sentences -> words
paragraphs = paddle.static.data(
name='paragraphs', shape=[-1, 1], dtype='int64', lod_level=2
)
label = paddle.static.data(
name='label', shape=[-1, 1], dtype='int64'
)
feeder = base.DataFeeder([paragraphs, label], base.CPUPlace())

# lod = [[0, 2, 3], [0, 3, 5, 9]]
# data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
# label = [1] * len(data)
result = feeder.feed(
[([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])]
)

self.assertEqual(result['paragraphs'].shape(), [9, 1])
self.assertEqual(result['label'].shape(), [2, 1])
self.assertEqual(
result['paragraphs'].recursive_sequence_lengths(),
[[2, 1], [3, 2, 4]],
)
self.assertEqual(
result['label'].recursive_sequence_lengths(), []
)

def test_errors(self):
def pir_mode_not_supported_str_feed():
Expand Down
6 changes: 4 additions & 2 deletions test/deprecated/legacy_test/test_dataset_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,14 @@ def get_all_places(self):
def test_batch_number_with_same_length_files(self):
for p in self.get_all_places():
with base.scope_guard(base.Scope()):
self.check_batch_number(place=p, randomize_batch_num=False)
with paddle.pir_utils.OldIrGuard():
self.check_batch_number(place=p, randomize_batch_num=False)

def test_batch_number_with_different_length_files(self):
for p in self.get_all_places():
with base.scope_guard(base.Scope()):
self.check_batch_number(place=p, randomize_batch_num=True)
with paddle.pir_utils.OldIrGuard():
self.check_batch_number(place=p, randomize_batch_num=True)


class QueueDatasetTestWithoutDropLast(DatasetLoaderTestBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def test_simple_net(self):
if not core.is_compiled_with_rocm():
dtype_list.append("float64")
for dtype in dtype_list:
self.simple_net_float(is_sparse, dtype)
with paddle.pir_utils.OldIrGuard():
self.simple_net_float(is_sparse, dtype)

def simple_net_float(self, is_sparse, dtype):
places = [base.CPUPlace()]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def test_main(self):
with base.program_guard(base.Program(), base.Program()):
with base.unique_name.guard():
with base.scope_guard(base.Scope()):
self.main_impl(p)
with paddle.pir_utils.OldIrGuard():
self.main_impl(p)


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_main(self):
for iterable in [False, True]:
try:
with base.scope_guard(base.Scope()):
self.main_impl(p, iterable)
with paddle.pir_utils.OldIrGuard():
self.main_impl(p, iterable)

self.assertTrue(not self.raise_exception)
except ReaderException:
Expand Down
21 changes: 11 additions & 10 deletions test/deprecated/legacy_test/test_pass_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,18 @@ def test_parallel_testing_with_new_strategy(self):
pass_builder.remove_pass(len(pass_builder.all_passes()) - 1)
self.assertEqual(origin_len + 1, len(pass_builder.all_passes()))
with tempfile.TemporaryDirectory(prefix="dot_path_") as tmpdir:
graph_viz_path = os.path.join(tmpdir, 'test_viz_pass.dot')
viz_pass.set("graph_viz_path", graph_viz_path)
with paddle.pir_utils.OldIrGuard():
graph_viz_path = os.path.join(tmpdir, 'test_viz_pass.dot')
viz_pass.set("graph_viz_path", graph_viz_path)

self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),
build_strategy=build_strategy,
)
try:
os.stat(graph_viz_path)
except OSError:
self.assertFalse(True)
self.check_network_convergence(
use_cuda=core.is_compiled_with_cuda(),
build_strategy=build_strategy,
)
try:
os.stat(graph_viz_path)
except OSError:
self.assertFalse(True)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions test/deprecated/legacy_test/test_psroi_pool_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def test_function_in_static(self):
out = paddle.vision.ops.psroi_pool(
self.x_placeholder,
self.boxes_placeholder,
self.boxes_num,
paddle.to_tensor(self.boxes_num),
output_size,
)
expect_out = calc_psroi_pool(
Expand All @@ -392,7 +392,7 @@ def test_function_in_static(self):
(out_res,) = exe.run(
paddle.static.default_main_program(),
feed={'x': self.x, 'boxes': boxes_lod_data},
fetch_list=[out.name],
fetch_list=[out],
)
np.testing.assert_allclose(out_res, expect_out, rtol=1e-05)

Expand Down
45 changes: 25 additions & 20 deletions test/deprecated/legacy_test/test_slice_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,29 +170,38 @@ def init_dtype(self):

class TestSliceScatterApiInt16(TestSliceScatterApi):
def init_dtype(self):
# old ir `set_value` not support this dtype
if paddle.framework.in_dynamic_or_pir_mode():
self.dtype = 'int16'
else:
self.dtype = 'float64'
self.dtype = 'int16'

@test_with_pir_api
def test_api_static(self):
pass

def test_api_dygraph(self):
super().test_api_dygraph()


class TestSliceScatterApiInt8(TestSliceScatterApi):
def init_dtype(self):
# old ir `set_value` not support this dtype
if paddle.framework.in_dynamic_or_pir_mode():
self.dtype = 'int8'
else:
self.dtype = 'float64'
self.dtype = 'int8'

@test_with_pir_api
def test_api_static(self):
pass

def test_api_dygraph(self):
super().test_api_dygraph()


class TestSliceScatterApiUint8(TestSliceScatterApi):
def init_dtype(self):
# old ir `set_value` not support this dtype
if paddle.framework.in_dynamic_or_pir_mode():
self.dtype = 'uint8'
else:
self.dtype = 'float64'
self.dtype = 'uint8'

@test_with_pir_api
def test_api_static(self):
pass

def test_api_dygraph(self):
super().test_api_dygraph()


class TestSliceScatterApiBool(TestSliceScatterApi):
Expand All @@ -202,11 +211,7 @@ def init_dtype(self):

class TestSliceScatterApiBfloat16(TestSliceScatterApi):
def init_dtype(self):
# old ir `set_value` not support this dtype
if paddle.framework.in_dynamic_or_pir_mode():
self.dtype = 'bfloat16'
else:
self.dtype = 'float64'
self.dtype = 'bfloat16'


class TestSliceScatterApiFloat16(TestSliceScatterApi):
Expand Down
77 changes: 39 additions & 38 deletions test/deprecated/legacy_test/test_tdm_child_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,44 +149,45 @@ def config(self):
class TestTDMChildShape(unittest.TestCase):
def test_shape(self):
with paddle_static_guard():
x = paddle.static.data(
name='x', shape=[-1, 1], dtype='int32', lod_level=1
)
tdm_tree_info = create_tdm_tree()
tree_info_np = np.array(tdm_tree_info).astype('int32')

child, leaf_mask = tdm_child(
x=x,
node_nums=26,
child_nums=2,
param_attr=base.ParamAttr(
initializer=paddle.nn.initializer.Assign(tree_info_np)
),
)

place = base.CPUPlace()
exe = base.Executor(place=place)
exe.run(base.default_startup_program())

feed = {
'x': np.array(
[
[1],
[2],
[3],
[4],
[5],
[6],
[7],
[8],
[9],
[10],
[11],
[12],
]
).astype('int32')
}
exe.run(feed=feed)
with paddle.pir_utils.OldIrGuard():
x = paddle.static.data(
name='x', shape=[-1, 1], dtype='int32', lod_level=1
)
tdm_tree_info = create_tdm_tree()
tree_info_np = np.array(tdm_tree_info).astype('int32')

child, leaf_mask = tdm_child(
x=x,
node_nums=26,
child_nums=2,
param_attr=base.ParamAttr(
initializer=paddle.nn.initializer.Assign(tree_info_np)
),
)

place = base.CPUPlace()
exe = base.Executor(place=place)
exe.run(base.default_startup_program())

feed = {
'x': np.array(
[
[1],
[2],
[3],
[4],
[5],
[6],
[7],
[8],
[9],
[10],
[11],
[12],
]
).astype('int32')
}
exe.run(feed=feed)


if __name__ == "__main__":
Expand Down