Skip to content

Commit

Permalink
update the location of if not paddle.framework.use_pir_api()
Browse files Browse the repository at this point in the history
  • Loading branch information
Fripping committed Jul 23, 2024
1 parent efe01cd commit 041fdce
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from op_test import OpTest

import paddle
from paddle import base
from paddle.base import core
from paddle.nn import functional
from paddle.pir_utils import test_with_pir_api


class TestOneHotOp(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
depth_np = np.array(10).astype('int32')
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

out = np.zeros(shape=(np.prod(x.shape), depth)).astype('float32')

for i in range(np.prod(x.shape)):
out[i, x[i]] = 1.0

self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
self.check_output(check_dygraph=False)


class TestOneHotOp_attr(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

out = np.zeros(shape=(np.prod(x.shape[:-1]), 1, depth)).astype(
'float32'
)

for i in range(np.prod(x.shape)):
out[i, 0, x[i]] = 1.0

self.inputs = {'X': (x, x_lod)}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
self.check_output(check_dygraph=False)


class TestOneHotOp_default_dtype(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
depth_np = np.array(10).astype('int32')
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])

out = np.zeros(shape=(np.prod(x.shape), depth)).astype('float32')

for i in range(np.prod(x.shape)):
out[i, x[i]] = 1.0

self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {}
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
self.check_output(check_dygraph=False)


class TestOneHotOp_default_dtype_attr(OpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])

out = np.zeros(shape=(np.prod(x.shape[:-1]), 1, depth)).astype(
'float32'
)

for i in range(np.prod(x.shape)):
out[i, 0, x[i]] = 1.0

self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth}
self.outputs = {'Out': (out, x_lod)}

def test_check_output(self):
self.check_output(check_dygraph=False)


class TestOneHotOpApi(unittest.TestCase):
@test_with_pir_api
def test_api(self):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
num_classes = 10
label = paddle.static.data(
name="label", shape=[-1, 1], dtype="int64"
)
one_hot_label = functional.one_hot(x=label, num_classes=num_classes)

place = base.CPUPlace()
label_data = np.array(
[np.random.randint(0, 10 - 1) for i in range(6)]
).reshape([6, 1])
label_data = label_data.astype('int64')

exe = base.Executor(place)
exe.run(startup)
ret = exe.run(
feed={
'label': label_data,
},
fetch_list=[one_hot_label],
return_numpy=False,
)

@test_with_pir_api
def test_api_with_depthTensor(self):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
num_classes = paddle.assign(np.array([10], dtype=np.int32))
label = paddle.static.data(
name="label", shape=[-1, 1], dtype="int64"
)
one_hot_label = functional.one_hot(x=label, num_classes=num_classes)

place = base.CPUPlace()
label_data = np.array(
[np.random.randint(0, 10 - 1) for i in range(6)]
).reshape([6, 1])
label_data = label_data.astype('int64')

exe = base.Executor(place)
exe.run(startup)
ret = exe.run(
feed={
'label': label_data,
},
fetch_list=[one_hot_label],
return_numpy=False,
)

def test_api_with_dygraph(self):
num_classes = 10
label = np.array(
[np.random.randint(0, num_classes - 1) for i in range(6)]
).reshape([6, 1])
with base.dygraph.guard():
one_hot_label = functional.one_hot(
x=paddle.to_tensor(label), num_classes=num_classes
)


class BadInputTestOnehotV2(unittest.TestCase):
def test_error(self):
with base.program_guard(base.Program()):

def test_bad_x():
label = paddle.static.data(
name="label",
shape=[4],
dtype="float32",
)

if not paddle.framework.use_pir_api():
label.desc.set_need_check_feed(False)
one_hot_label = functional.one_hot(x=label, num_classes=4)

self.assertRaises(TypeError, test_bad_x)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
20 changes: 10 additions & 10 deletions test/legacy_test/test_nn_functional_hot_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,20 @@ def test_api_with_dygraph(self):

class BadInputTestOnehotV2(unittest.TestCase):
def test_error(self):
if not paddle.framework.use_pir_api():
with base.program_guard(base.Program()):
with base.program_guard(base.Program()):

def test_bad_x():
label = paddle.static.data(
name="label",
shape=[4],
dtype="float32",
)
def test_bad_x():
label = paddle.static.data(
name="label",
shape=[4],
dtype="float32",
)

if not paddle.framework.use_pir_api():
label.desc.set_need_check_feed(False)
one_hot_label = functional.one_hot(x=label, num_classes=4)
one_hot_label = functional.one_hot(x=label, num_classes=4)

self.assertRaises(TypeError, test_bad_x)
self.assertRaises(TypeError, test_bad_x)


if __name__ == '__main__':
Expand Down

0 comments on commit 041fdce

Please sign in to comment.