Skip to content

Commit

Permalink
add unit test case
Browse files Browse the repository at this point in the history
  • Loading branch information
Glencsa committed Feb 25, 2025
1 parent eb43f4b commit 88db8e9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
23 changes: 23 additions & 0 deletions test/cpp/auto_parallel/spmd_rule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,29 @@ TEST(ConcatRule, Ctor) {
}
check_dim_mapping(inferred_dist_attrs.second[0], {1, -1, 0});
check_partial_dims(inferred_dist_attrs.second[0], {});

// test 3,special case: concat one dimensional tensor
shapes = {{16}, {32}, {64}};
dim_mappings = {{0}, {1}, {-1}};
partial_status = {{}, {}, {1}};
inputs = build_inputs();
inferred_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 0);
// list of tensor => single tensor
EXPECT_EQ(inferred_dist_attrs.first.size(), static_cast<size_t>(1));
EXPECT_EQ(inferred_dist_attrs.second.size(), static_cast<size_t>(1));
EXPECT_TRUE(
paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
inferred_dist_attrs.first[0]));
EXPECT_TRUE(paddle::holds_alternative<phi::distributed::TensorDistAttr>(
inferred_dist_attrs.second[0]));
auto& inputs_infer3 = PADDLE_GET_CONST(std::vector<TensorDistAttr>,
inferred_dist_attrs.first[0]);
for (auto e : inputs_infer3) {
check_dim_mapping(e, {-1});
check_partial_dims(e, {});
}
check_dim_mapping(inferred_dist_attrs.second[0], {-1});
check_partial_dims(inferred_dist_attrs.second[0], {});
}

TEST(StackRule, Ctor) {
Expand Down
43 changes: 40 additions & 3 deletions test/legacy_test/test_put_along_axis_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
paddle.enable_static()


def put_along_axis_net(arr, axis=-1):
    """Scatter-add the scalar value -4.0 into ``arr`` along ``axis``.

    Thin wrapper around ``paddle.tensor.put_along_axis`` used as the network
    under test; the single index ``2`` targets one position per reduction.

    Args:
        arr: input tensor (presumably 4-D, matching the ``[[[[2]]]]`` indices
            shape — confirm against the callers' InputSpec).
        axis: axis along which values are put. Defaults to -1 (last axis).

    Returns:
        The tensor produced by ``put_along_axis`` with ``reduce='add'``.
    """
    # stop_gradient=False so the backward pass in the tests can flow
    # gradients through the indexed positions.
    indices = paddle.to_tensor([[[[2]]]], dtype='int32', stop_gradient=False)
    return paddle.tensor.put_along_axis(
        arr, indices=indices, values=-4.0, axis=axis, reduce='add'
    )


Expand Down Expand Up @@ -1346,6 +1346,7 @@ def setUp(self):
self.enable_cinn = False
self.tol = 1e-6
self.dtype = "float32"
self.axis = -2
self.input_specs = [
InputSpec(
shape=(-1, -1, -1, -1),
Expand All @@ -1370,7 +1371,7 @@ def train(self, to_static):
else:
net = self.net

res = net(arr)
res = net(arr, self.axis)
res.backward()
arr_grad = arr.gradient()
return res, arr_grad
Expand All @@ -1389,6 +1390,42 @@ def test_dynamic_static(self):
np.testing.assert_allclose(dr, d, rtol=self.tol, atol=self.tol)


class TestPutAlongAxisDynamicShape1(TestPutAlongAxisDynamicShape):
    # Variant of the dynamic-shape test that puts values along axis=0
    # (the leading dimension) instead of the base class's axis.
    def setUp(self):
        # Fixed seed so the random input is reproducible across runs.
        np.random.seed(2024)
        self.net = put_along_axis_net
        self.enable_cinn = False
        self.tol = 1e-6
        self.dtype = "float32"
        self.axis = 0  # axis forwarded to the net by train()
        self.input_specs = [
            InputSpec(
                # All four dimensions marked dynamic (-1) for the
                # to-static conversion.
                shape=(-1, -1, -1, -1),
                dtype=self.dtype,
                stop_gradient=False,
            )
        ]
        self.arr = np.random.random([16, 16, 16, 16]).astype(self.dtype)


class TestPutAlongAxisDynamicShape2(TestPutAlongAxisDynamicShape):
    # Variant of the dynamic-shape test that puts values along axis=-1
    # (the last dimension) and uses a larger 20^4 input.
    def setUp(self):
        # Fixed seed so the random input is reproducible across runs.
        np.random.seed(2024)
        self.net = put_along_axis_net
        self.enable_cinn = False
        self.tol = 1e-6
        self.dtype = "float32"
        self.axis = -1  # axis forwarded to the net by train()
        self.input_specs = [
            InputSpec(
                # All four dimensions marked dynamic (-1) for the
                # to-static conversion.
                shape=(-1, -1, -1, -1),
                dtype=self.dtype,
                stop_gradient=False,
            )
        ]
        self.arr = np.random.random([20, 20, 20, 20]).astype(self.dtype)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()

0 comments on commit 88db8e9

Please sign in to comment.