【Fix PIR JIT SaveLoad Unittest No.18】modify test_bert.py (PaddlePaddle#64151)

* modify test_bert.py

* ci bug
xiaoguoguo626807 authored and co63oc committed May 11, 2024
1 parent 06cf64e commit 378621a
Showing 3 changed files with 54 additions and 13 deletions.
4 changes: 3 additions & 1 deletion python/paddle/static/io.py
@@ -1206,7 +1206,9 @@ def load_vars(
     """
     if in_pir_mode():
-        return load_vars_pir(dirname, main_program, vars, predicate, filename)
+        return load_vars_pir(
+            executor, dirname, main_program, vars, predicate, filename
+        )

     vars_from_memory = False
     if dirname is not None:
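A minimal usage sketch of the updated entry point (the checkpoint directory and program below are placeholders, not part of this commit): under PIR, `load_vars` now forwards the executor to `load_vars_pir`, which needs it to create the loaded parameters in the global scope.

import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# With the PIR API enabled, this call now reaches
# load_vars_pir(executor, dirname, main_program, vars, predicate, filename).
paddle.static.io.load_vars(
    executor=exe,
    dirname="./bert_ckpt",  # placeholder checkpoint directory
    main_program=paddle.static.default_main_program(),
)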
23 changes: 20 additions & 3 deletions python/paddle/static/pir_io.py
@@ -74,6 +74,14 @@ def get_pir_parameters(program)
     return params, opts


+def get_pir_feed_names(program):
+    feed_name_list = []
+    for op in program.global_block().ops:
+        if op.name() == "pd_op.data" or op.name() == "pd_op.feed":
+            feed_name_list.append(op.attrs()["name"])
+    return feed_name_list
+
+
 def set_var(name, ndarray):
     t = global_scope().find_var(name).get_tensor()
     p = t._place()
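The new `get_pir_feed_names` helper simply scans the global block for `pd_op.data` / `pd_op.feed` ops and collects their `name` attributes. A hedged sketch of calling it (the toy program below is illustrative and assumes the PIR API is enabled):

import paddle
from paddle.static.pir_io import get_pir_feed_names

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[-1, 16], dtype="float32")
    y = paddle.static.data(name="y", shape=[-1, 1], dtype="float32")

# Each static data input lowers to a pd_op.data op under PIR,
# so this is expected to print ['x', 'y'].
print(get_pir_feed_names(main))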
@@ -352,6 +360,7 @@ def save_vars_pir(


 def load_vars_pir(
+    executor,
     dirname,
     main_program=None,
     vars=None,
@@ -374,6 +383,7 @@
         use `filename` to specify it.
     Args:
+        executor(Executor): The executor to create variables in scope.
         dirname(str): The folder where to load the variables.
         main_program(Program, optional): The program whose variables will be loaded.
                                          If it is None, the default main program will
@@ -391,6 +401,7 @@
     Returns:
         None
     """
+    assert executor is None or isinstance(executor, Executor)

     vars_from_memory = False
     if dirname is not None:
@@ -407,6 +418,7 @@
         param, opt = get_pir_parameters(main_program)
         vars_list = param + opt
         load_vars_pir(
+            executor,
             dirname=dirname,
             main_program=main_program,
             vars=vars_list,  # list(filter(predicate, vars_list)),
@@ -418,7 +430,9 @@

         # TODO(chenzhiyang):save origin param shape, check vars
         load_var_map = {}
-
+        paddle.base.libpaddle.pir.create_loaded_parameter(
+            vars, global_scope(), executor._default_executor
+        )
         for v in vars:
             var = global_scope().find_var(v.name)
             assert isinstance(var, paddle.base.libpaddle.Variable)
@@ -768,6 +782,7 @@ def load_pir_inference_model(path_prefix, executor, **kwargs):
         if len(params + opts) > 0:
             load_vars_pir(
                 # load from memory, dirname is None
+                executor,
                 dirname=None,
                 main_program=program,
                 # predicate=persistable,
@@ -815,18 +830,20 @@ def load_pir_inference_model(path_prefix, executor, **kwargs):
         # deserialize bytes to program
         program = paddle.static.Program()
         paddle.base.core.deserialize_pir_program(model_path, program, 1)

         # load parameters
         params, opts = get_pir_parameters(program)
         if len(params + opts) > 0:
             load_dirname = os.path.dirname(params_path)
             params_filename = os.path.basename(params_path)

             load_vars_pir(
+                executor,
                 dirname=load_dirname,
                 main_program=program,
                 # predicate=persistable,
                 filename=params_filename,
             )
-
-    return [program, [], []]
+    feed_names = get_pir_feed_names(program)
+    # pir load program has fetch op, so if use exe.run to execute load program, don't need to set fetch_list
+    return [program, feed_names, []]
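Downstream, the feed names returned here let callers build the feed dict straight from the loaded program, and, as the comment notes, the loaded PIR program keeps its fetch ops, so `fetch_list` can be left out of `exe.run`. A hedged consumer sketch (path prefix, filenames and input shapes are placeholders):

import numpy as np
import paddle
from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./inference",  # placeholder path prefix
    exe,
    model_filename="bert" + PIR_INFER_MODEL_SUFFIX,
    params_filename="bert.pdiparams",
)
feed = {name: np.zeros([1, 128], dtype="int64") for name in feed_names}
# Per the comment above, fetch ops live in the loaded PIR program itself,
# so no explicit fetch_list is required here.
results = exe.run(program, feed=feed)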
40 changes: 31 additions & 9 deletions test/dygraph_to_static/test_bert.py
@@ -33,6 +33,7 @@
 from paddle.base import core
 from paddle.base.framework import unique_name
 from paddle.framework import use_pir_api
+from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 place = (
@@ -91,6 +92,7 @@ def setUp(self):
         self.model_save_dir = os.path.join(self.temp_dir.name, 'inference')
         self.model_save_prefix = os.path.join(self.model_save_dir, 'bert')
         self.model_filename = 'bert' + INFER_MODEL_SUFFIX
+        self.pir_model_filename = 'bert' + PIR_INFER_MODEL_SUFFIX
         self.params_filename = 'bert' + INFER_PARAMS_SUFFIX
         self.dy_state_dict_save_path = os.path.join(
             self.temp_dir.name, 'bert.dygraph'
@@ -162,9 +164,7 @@ def train(self, bert_config, data_reader, to_static):
                 step_idx += 1
                 if step_idx == STEP_NUM:
                     if to_static:
-                        # TODO(pir-save-load): Fix this after we support save/load in PIR
-                        if not use_pir_api():
-                            paddle.jit.save(bert, self.model_save_prefix)
+                        paddle.jit.save(bert, self.model_save_prefix)
                     else:
                         paddle.save(
                             bert.state_dict(),
@@ -183,6 +183,11 @@ def train_static(self, bert_config, data_reader):
     def predict_static(self, data):
         paddle.enable_static()
         exe = base.Executor(place)
+        if use_pir_api():
+            model_filename = self.pir_model_filename
+        else:
+            model_filename = self.model_filename
+
         # load inference model
         [
             inference_program,
@@ -191,7 +196,7 @@ def predict_static(self, data):
         ] = paddle.static.io.load_inference_model(
             self.model_save_dir,
             executor=exe,
-            model_filename=self.model_filename,
+            model_filename=model_filename,
             params_filename=self.params_filename,
         )
         pred_res = exe.run(
@@ -304,10 +309,27 @@ def test_train_composite(self):
     def verify_predict(self):
         for data in self.data_reader.data_generator()():
             dygraph_pred_res = self.predict_dygraph(self.bert_config, data)
-            # TODO(pir-save-load): Fix this after we support save/load in PIR
-            if not use_pir_api():
-                static_pred_res = self.predict_static(data)
-                dygraph_jit_pred_res = self.predict_dygraph_jit(data)
+            static_pred_res = self.predict_static(data)
+            dygraph_jit_pred_res = self.predict_dygraph_jit(data)
+            if use_pir_api():
+                for dy_res, st_res, dy_jit_res in zip(
+                    dygraph_pred_res,
+                    static_pred_res,
+                    dygraph_jit_pred_res,
+                ):
+                    np.testing.assert_allclose(
+                        st_res,
+                        dy_res,
+                        rtol=1e-04,
+                        err_msg=f'dygraph_res: {dy_res[~np.isclose(st_res, dy_res)]},\n static_res: {st_res[~np.isclose(st_res, dy_res)]}',
+                    )
+                    np.testing.assert_allclose(
+                        st_res,
+                        dy_jit_res,
+                        rtol=1e-04,
+                        err_msg=f'dygraph_jit_res: {dy_jit_res[~np.isclose(st_res, dy_jit_res)]},\n static_res: {st_res[~np.isclose(st_res, dy_jit_res)]}',
+                    )
+            else:
                 predictor_pred_res = self.predict_analysis_inference(data)

                 for dy_res, st_res, dy_jit_res, predictor_res in zip(
@@ -332,7 +354,7 @@ def verify_predict(self):
                         st_res,
                         predictor_res,
                         rtol=1e-05,
-                        err_msg=f'dygraph_jit_res: {predictor_res[~np.isclose(st_res, predictor_res)]},\n static_res: {st_res[~np.isclose(st_res, predictor_res)]}',
+                        err_msg=f'dygraph_jit_res_predictor: {predictor_res[~np.isclose(st_res, predictor_res)]},\n static_res: {st_res[~np.isclose(st_res, predictor_res)]}',
                     )
             break

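Taken together, the test now exercises the full PIR round trip: `paddle.jit.save` writes the PIR inference program, and `predict_static` picks the matching model filename before loading. A condensed, hedged sketch of that flow with a stand-in layer (the Linear model and paths are placeholders, not the bert network from the test):

import paddle
from paddle.framework import use_pir_api
from paddle.jit.pir_translated_layer import PIR_INFER_MODEL_SUFFIX
from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

# Stand-in for the trained bert model.
net = paddle.jit.to_static(
    paddle.nn.Linear(16, 4),
    input_spec=[paddle.static.InputSpec([None, 16], "float32")],
)
paddle.jit.save(net, "./inference/demo")

# Under PIR the serialized program uses a different suffix, so the filename
# is selected before calling load_inference_model, as the test now does.
model_filename = "demo" + (
    PIR_INFER_MODEL_SUFFIX if use_pir_api() else INFER_MODEL_SUFFIX
)

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
program, feed_names, fetch_targets = paddle.static.io.load_inference_model(
    "./inference",
    executor=exe,
    model_filename=model_filename,
    params_filename="demo" + INFER_PARAMS_SUFFIX,
)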