-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FasterGeneration] MBart supports dy2sta (#3356)
- Loading branch information
Showing
6 changed files
with
264 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
147 changes: 147 additions & 0 deletions
147
paddlenlp/ops/faster_transformer/sample/mbart_export_model_sample.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import argparse | ||
import paddle | ||
from pprint import pprint | ||
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer | ||
from paddlenlp.ops import FasterMBART | ||
from paddlenlp.utils.log import logger | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model_name_or_path", | ||
default="mbart-large-50-many-to-many-mmt", | ||
type=str, | ||
help="The model name to specify the bart to use. ") | ||
parser.add_argument("--inference_model_dir", | ||
default="./infer_model/", | ||
type=str, | ||
help="Path to save inference model of bart. ") | ||
parser.add_argument( | ||
"--topk", | ||
default=4, | ||
type=int, | ||
help="The number of candidate to procedure top_k sampling. ") | ||
parser.add_argument( | ||
"--topp", | ||
default=1.0, | ||
type=float, | ||
help="The probability threshold to procedure top_p sampling. ") | ||
parser.add_argument("--max_out_len", | ||
default=64, | ||
type=int, | ||
help="Maximum output length. ") | ||
parser.add_argument("--temperature", | ||
default=1.0, | ||
type=float, | ||
help="The temperature to set. ") | ||
parser.add_argument("--num_return_sequences", | ||
default=1, | ||
type=int, | ||
help="The number of returned sequences. ") | ||
parser.add_argument("--use_fp16_decoding", | ||
action="store_true", | ||
help="Whether to use fp16 decoding to predict. ") | ||
parser.add_argument("--decoding_strategy", | ||
default="beam_search", | ||
choices=["sampling", "beam_search"], | ||
type=str, | ||
help="The main strategy to decode. ") | ||
parser.add_argument( | ||
"--num_beams", | ||
default=5, | ||
type=int, | ||
help="The number of candidate to procedure beam search. ") | ||
parser.add_argument("--diversity_rate", | ||
default=0.0, | ||
type=float, | ||
help="The diversity rate to procedure beam search. ") | ||
parser.add_argument("--repetition_penalty", | ||
default=1.0, | ||
type=float, | ||
help="The repetition_penalty to set. ") | ||
parser.add_argument("--length_penalty", | ||
default=0.0, | ||
type=float, | ||
help="The length penalty to decode. ") | ||
parser.add_argument("--early_stopping", | ||
action="store_true", | ||
help="Whether to do early stopping. ") | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def do_predict(args): | ||
place = "gpu" | ||
place = paddle.set_device(place) | ||
|
||
model = MBartForConditionalGeneration.from_pretrained( | ||
args.model_name_or_path, src_lang="en_XX") | ||
tokenizer = MBartTokenizer.from_pretrained(args.model_name_or_path) | ||
|
||
bos_id = tokenizer.lang_code_to_id["zh_CN"] | ||
eos_id = model.mbart.config["eos_token_id"] | ||
|
||
# For opening faster_encoder | ||
model.eval() | ||
|
||
faster_mbart = FasterMBART(model=model, | ||
use_fp16_decoding=args.use_fp16_decoding) | ||
# Set evaluate mode | ||
faster_mbart.eval() | ||
|
||
# Convert dygraph model to static graph model | ||
faster_mbart = paddle.jit.to_static( | ||
faster_mbart, | ||
input_spec=[ | ||
# input_ids | ||
paddle.static.InputSpec(shape=[None, None], dtype="int32"), | ||
# encoder_output | ||
None, | ||
# seq_len | ||
None, | ||
bos_id, # forced_bos_token_id | ||
args.num_beams, # num_beams. | ||
args.topk, # top_k | ||
args.topp, # top_p | ||
args.decoding_strategy, # decode_strategy | ||
tokenizer.bos_token_id, # bos_token_id | ||
tokenizer.eos_token_id, # eos_token_id | ||
tokenizer.pad_token_id, # pad_token_id | ||
model.mbart. | ||
config["decoder_start_token_id"], # decoder_start_token_id | ||
args.max_out_len, # max_length | ||
args.diversity_rate, # diversity_rate | ||
args.length_penalty, # length_penalty | ||
args.temperature, # temperature | ||
args.num_return_sequences, # num_return_sequences | ||
args.early_stopping, # early_stopping | ||
tokenizer.eos_token_id, #forced_eos_token_id | ||
]) | ||
|
||
# Save converted static graph model | ||
paddle.jit.save(faster_mbart, os.path.join(args.inference_model_dir, | ||
"mbart")) | ||
logger.info("MBART has been saved to {}.".format(args.inference_model_dir)) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
pprint(args) | ||
|
||
do_predict(args) |
97 changes: 97 additions & 0 deletions
97
paddlenlp/ops/faster_transformer/sample/mbart_inference.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import os | ||
import numpy as np | ||
from pprint import pprint | ||
|
||
import paddle | ||
import paddle.inference as paddle_infer | ||
|
||
from paddlenlp.transformers import MBartTokenizer | ||
from paddlenlp.ops.ext_utils import load | ||
|
||
|
||
def setup_args(): | ||
"""Setup arguments.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--inference_model_dir", | ||
default="./infer_model/", | ||
type=str, | ||
help="Path to save inference model of BART. ") | ||
|
||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def postprocess_response(tokenizer, seq, bos_idx, eos_idx): | ||
"""Post-process the decoded sequence.""" | ||
eos_pos = len(seq) - 1 | ||
for i, idx in enumerate(seq): | ||
if idx == eos_idx: | ||
eos_pos = i | ||
break | ||
seq = [ | ||
idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx | ||
] | ||
res = tokenizer.convert_ids_to_string(seq) | ||
return res | ||
|
||
|
||
def infer(args): | ||
model_name = "mbart-large-50-many-to-many-mmt" | ||
tokenizer = MBartTokenizer.from_pretrained(model_name) | ||
|
||
bos_id = tokenizer.lang_code_to_id["zh_CN"] | ||
eos_id = tokenizer.eos_token_id | ||
|
||
inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications." | ||
input_ids = tokenizer(inputs)["input_ids"] | ||
input_ids = np.asarray(input_ids, dtype="int32").reshape(1, -1) | ||
|
||
# Load FasterTransformer lib. | ||
load("FasterTransformer", verbose=True) | ||
|
||
config = paddle_infer.Config( | ||
os.path.join(args.inference_model_dir, "mbart.pdmodel"), | ||
os.path.join(args.inference_model_dir, "mbart.pdiparams")) | ||
|
||
config.enable_use_gpu(100, 0) | ||
config.disable_glog_info() | ||
predictor = paddle_infer.create_predictor(config) | ||
|
||
input_names = predictor.get_input_names() | ||
input_handle = predictor.get_input_handle(input_names[0]) | ||
input_handle.copy_from_cpu(input_ids.astype("int32")) | ||
|
||
predictor.run() | ||
|
||
output_names = predictor.get_output_names() | ||
output_handle = predictor.get_output_handle(output_names[0]) | ||
output_data = output_handle.copy_to_cpu() | ||
|
||
result = postprocess_response( | ||
tokenizer, | ||
output_data.transpose([1, 2, 0]).tolist()[0][0], bos_id, eos_id) | ||
print("Model input:", inputs) | ||
print("Result:", result) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = setup_args() | ||
pprint(args) | ||
|
||
infer(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters