[FasterGeneration] MBart supports dy2sta (#3356)
FrostML authored Sep 23, 2022
1 parent 85d7ac6 commit 62f55d0
Showing 6 changed files with 264 additions and 11 deletions.
faster_generation/samples/mbart_sample.py (5 additions & 5 deletions)

@@ -14,11 +14,10 @@
 import paddle
 from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer
 
-model_name = "mbart-large-50-one-to-many-mmt"
+model_name = "mbart-large-50-many-to-many-mmt"
 
-tokenizer = MBartTokenizer.from_pretrained(model_name)
-model = MBartForConditionalGeneration.from_pretrained(model_name,
-                                                      src_lang="en_XX")
+tokenizer = MBartTokenizer.from_pretrained(model_name, src_lang="en_XX")
+model = MBartForConditionalGeneration.from_pretrained(model_name)
 model.eval()

@@ -41,7 +40,7 @@ def postprocess_response(seq, bos_idx, eos_idx):
 
 inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
 input_ids = tokenizer(inputs)["input_ids"]
-input_ids = paddle.to_tensor(input_ids, dtype='int64').unsqueeze(0)
+input_ids = paddle.to_tensor(input_ids, dtype='int32').unsqueeze(0)
 
 outputs, _ = model.generate(input_ids=input_ids,
                             forced_bos_token_id=bos_id,

@@ -53,5 +52,6 @@ def postprocess_response(seq, bos_idx, eos_idx):
 result = postprocess_response(outputs[0].numpy().tolist(), bos_id, eos_id)
 
 print("Model input:", inputs)
 
 print("Result:", result)
+# PaddleNLP是一个强大的NLP库,具有超乎寻常的预训练模型和易于使用的接口,支持从研究到工业应用的广泛的NLP任务。
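The collapsed lines between these hunks define `postprocess_response` and the `bos_id`/`eos_id` passed to `model.generate`. Based on the mbart_inference.py sample added in this same commit, they are presumably derived along these lines (a sketch, not the collapsed source):

    # Presumed id setup (mirrors mbart_inference.py in this commit; the
    # collapsed lines themselves are not shown in this hunk).
    bos_id = tokenizer.lang_code_to_id["zh_CN"]  # forced target-language tag
    eos_id = tokenizer.eos_token_id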
paddlenlp/ops/faster_transformer/sample/mbart_export_model_sample.py (147 additions, new file)

@@ -0,0 +1,147 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import paddle
from pprint import pprint
from paddlenlp.transformers import MBartForConditionalGeneration, MBartTokenizer
from paddlenlp.ops import FasterMBART
from paddlenlp.utils.log import logger


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path",
                        default="mbart-large-50-many-to-many-mmt",
                        type=str,
                        help="The model name to specify the MBART model to use. ")
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path to save the inference model of MBART. ")
    parser.add_argument("--topk",
                        default=4,
                        type=int,
                        help="The number of candidates for top-k sampling. ")
    parser.add_argument("--topp",
                        default=1.0,
                        type=float,
                        help="The probability threshold for top-p sampling. ")
    parser.add_argument("--max_out_len",
                        default=64,
                        type=int,
                        help="Maximum output length. ")
    parser.add_argument("--temperature",
                        default=1.0,
                        type=float,
                        help="The temperature to set. ")
    parser.add_argument("--num_return_sequences",
                        default=1,
                        type=int,
                        help="The number of returned sequences. ")
    parser.add_argument("--use_fp16_decoding",
                        action="store_true",
                        help="Whether to use fp16 decoding to predict. ")
    parser.add_argument("--decoding_strategy",
                        default="beam_search",
                        choices=["sampling", "beam_search"],
                        type=str,
                        help="The main strategy to decode. ")
    parser.add_argument("--num_beams",
                        default=5,
                        type=int,
                        help="The number of beams for beam search. ")
    parser.add_argument("--diversity_rate",
                        default=0.0,
                        type=float,
                        help="The diversity rate for beam search. ")
    parser.add_argument("--repetition_penalty",
                        default=1.0,
                        type=float,
                        help="The repetition_penalty to set. ")
    parser.add_argument("--length_penalty",
                        default=0.0,
                        type=float,
                        help="The length penalty to decode. ")
    parser.add_argument("--early_stopping",
                        action="store_true",
                        help="Whether to do early stopping. ")

    args = parser.parse_args()
    return args


def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    model = MBartForConditionalGeneration.from_pretrained(
        args.model_name_or_path, src_lang="en_XX")
    tokenizer = MBartTokenizer.from_pretrained(args.model_name_or_path)

    bos_id = tokenizer.lang_code_to_id["zh_CN"]
    eos_id = model.mbart.config["eos_token_id"]

    # For opening faster_encoder
    model.eval()

    faster_mbart = FasterMBART(model=model,
                               use_fp16_decoding=args.use_fp16_decoding)
    # Set evaluate mode
    faster_mbart.eval()

    # Convert dygraph model to static graph model
    faster_mbart = paddle.jit.to_static(
        faster_mbart,
        input_spec=[
            # input_ids
            paddle.static.InputSpec(shape=[None, None], dtype="int32"),
            # encoder_output
            None,
            # seq_len
            None,
            bos_id,  # forced_bos_token_id
            args.num_beams,  # num_beams
            args.topk,  # top_k
            args.topp,  # top_p
            args.decoding_strategy,  # decode_strategy
            tokenizer.bos_token_id,  # bos_token_id
            tokenizer.eos_token_id,  # eos_token_id
            tokenizer.pad_token_id,  # pad_token_id
            model.mbart.config["decoder_start_token_id"],  # decoder_start_token_id
            args.max_out_len,  # max_length
            args.diversity_rate,  # diversity_rate
            args.length_penalty,  # length_penalty
            args.temperature,  # temperature
            args.num_return_sequences,  # num_return_sequences
            args.early_stopping,  # early_stopping
            tokenizer.eos_token_id,  # forced_eos_token_id
        ])

    # Save converted static graph model
    paddle.jit.save(faster_mbart,
                    os.path.join(args.inference_model_dir, "mbart"))
    logger.info("MBART has been saved to {}.".format(args.inference_model_dir))


if __name__ == "__main__":
    args = parse_args()
    pprint(args)

    do_predict(args)
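A note on the `input_spec` list above, as I read Paddle's dy2sta semantics (the commit itself does not spell this out): `InputSpec` entries become real graph inputs with dynamic dimensions, `None` leaves an optional tensor argument unbound, and plain Python values such as `args.num_beams` or `args.decoding_strategy` are folded into the exported program as constants, so changing a decoding hyperparameter means re-exporting. A minimal sketch of that behavior with a hypothetical layer:

    # Hypothetical layer; illustrates InputSpec vs. constant arguments
    # (assumed paddle.jit.to_static behavior).
    import paddle

    class Scaler(paddle.nn.Layer):

        def forward(self, x, factor):
            return x * factor

    static_scaler = paddle.jit.to_static(
        Scaler(),
        input_spec=[
            # Real input: batch and length stay dynamic.
            paddle.static.InputSpec(shape=[None, None], dtype="float32"),
            # Python scalar: baked into the program as a constant.
            2.0,
        ])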
paddlenlp/ops/faster_transformer/sample/mbart_inference.py (97 additions, new file)

@@ -0,0 +1,97 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import numpy as np
from pprint import pprint

import paddle
import paddle.inference as paddle_infer

from paddlenlp.transformers import MBartTokenizer
from paddlenlp.ops.ext_utils import load


def setup_args():
    """Setup arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--inference_model_dir",
                        default="./infer_model/",
                        type=str,
                        help="Path to load the inference model of MBART. ")

    args = parser.parse_args()

    return args


def postprocess_response(tokenizer, seq, bos_idx, eos_idx):
    """Post-process the decoded sequence."""
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1] if idx != bos_idx and idx != eos_idx
    ]
    res = tokenizer.convert_ids_to_string(seq)
    return res


def infer(args):
    model_name = "mbart-large-50-many-to-many-mmt"
    tokenizer = MBartTokenizer.from_pretrained(model_name)

    bos_id = tokenizer.lang_code_to_id["zh_CN"]
    eos_id = tokenizer.eos_token_id

    inputs = "PaddleNLP is a powerful NLP library with Awesome pre-trained models and easy-to-use interface, supporting wide-range of NLP tasks from research to industrial applications."
    input_ids = tokenizer(inputs)["input_ids"]
    input_ids = np.asarray(input_ids, dtype="int32").reshape(1, -1)

    # Load FasterTransformer lib.
    load("FasterTransformer", verbose=True)

    config = paddle_infer.Config(
        os.path.join(args.inference_model_dir, "mbart.pdmodel"),
        os.path.join(args.inference_model_dir, "mbart.pdiparams"))

    config.enable_use_gpu(100, 0)
    config.disable_glog_info()
    predictor = paddle_infer.create_predictor(config)

    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])
    input_handle.copy_from_cpu(input_ids.astype("int32"))

    predictor.run()

    output_names = predictor.get_output_names()
    output_handle = predictor.get_output_handle(output_names[0])
    output_data = output_handle.copy_to_cpu()

    result = postprocess_response(
        tokenizer,
        output_data.transpose([1, 2, 0]).tolist()[0][0], bos_id, eos_id)
    print("Model input:", inputs)
    print("Result:", result)


if __name__ == "__main__":
    args = setup_args()
    pprint(args)

    infer(args)
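The `transpose([1, 2, 0])` in `infer` suggests the exported decoding op returns ids laid out as `[max_out_len, batch_size, num_return_sequences]`; transposing yields `[batch_size, num_return_sequences, max_out_len]`, and `[0][0]` then selects the first returned sequence of the only sample. That layout is inferred from the indexing, not stated anywhere in the diff:

    # Inferred layout (hedged): output_data.shape == (max_out_len, 1, 1) here.
    ids = output_data.transpose([1, 2, 0]).tolist()[0][0]  # first sample, first sequence
    result = postprocess_response(tokenizer, ids, bos_id, eos_id)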
paddlenlp/ops/faster_transformer/transformer/decoding.py (2 additions & 1 deletion)

@@ -2515,7 +2515,8 @@ def __init__(self,
         self.pos_emb = [model.decoder.decoder_embed_positions.weight]
         self.word_emb = [model.decoder.embed_tokens.weight]
 
-        self.linear_weight = [model.lm_head_weight.t()]
+        setattr(self, "lm_head_weight_", model.lm_head_weight.t())
+        self.linear_weight = [getattr(self, "lm_head_weight_")]
         self.linear_bias = [model.final_logits_bias]
 
     def forward(self,
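The indirection is presumably a dy2sta workaround (the commit does not explain it): `model.lm_head_weight.t()` produces a fresh tensor at construction time, and assigning it to the layer via `setattr` registers it as module state, so the transposed weight is kept as a persistable tensor through `paddle.jit.to_static`/`paddle.jit.save` rather than being a loose temporary held only by a Python list. A sketch of the pattern on a hypothetical layer:

    import paddle

    class TiedLMHead(paddle.nn.Layer):
        """Hypothetical layer showing the register-then-reference pattern."""

        def __init__(self, lm_head_weight):
            super().__init__()
            # Bind the transposed weight to the Layer so static-graph export
            # sees it as module state rather than an anonymous temporary.
            setattr(self, "lm_head_weight_", lm_head_weight.t())
            self.linear_weight = [getattr(self, "lm_head_weight_")]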
paddlenlp/ops/faster_transformer/transformer/faster_transformer.py (11 additions & 3 deletions)

@@ -1379,8 +1379,13 @@ def forward(self,
 
 
 class FasterMBART(MBartPretrainedModel):
+    enable_faster_encoder_func = enable_faster_encoder
 
-    def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
+    def __init__(self,
+                 model,
+                 decoding_lib=None,
+                 use_fp16_decoding=False,
+                 enable_faster_encoder=False):
         super(FasterMBART, self).__init__()
         self.use_fp16_decoding = use_fp16_decoding
         self._model = model

@@ -1393,13 +1398,18 @@ def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
         self.encoder = model.mbart.get_encoder()
         self.decoder = model.mbart.get_decoder()
         self.pad_token_id = model.mbart.config['pad_token_id']
+        self.enable_faster_encoder = enable_faster_encoder
 
         self.decoding = InferMBartDecoding(
             model=self._model,
             decoding_lib=decoding_lib,
             use_fp16_decoding=use_fp16_decoding,
             hidden_act=model.mbart.config['activation_function'])
 
+        if self.enable_faster_encoder:
+            # Must call `enable_faster_encoder` in `__init__` when converting
+            # dygraph to static graph.
+            self.encoder = FasterMBART.enable_faster_encoder_func(self.encoder)
+
     def get_encoder(self):
         return self.encoder
 

@@ -1439,11 +1449,9 @@ def forward(self,
 
         #(gongenlei) Not enable_faster_encoder temporarily
         if encoder_output is None:
-            self.encoder = enable_faster_encoder(self.encoder)
             assert input_ids is not None, "You have to specify either input_ids or encoder_output."
             encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
                 input_ids, model_kwargs)["encoder_output"]
-            self.encoder = disable_faster_encoder(self.encoder)
         batch_size = paddle.shape(encoder_output)[0]
         if seq_len is None:
             assert input_ids is not None, "You have to specify either input_ids when generating seq_len."
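So under dy2sta the encoder swap now happens once, at construction time, rather than being toggled around each `forward` call (the removed `enable_faster_encoder`/`disable_faster_encoder` pair above would not survive tracing). A hedged usage sketch of the new flag; note that the export sample in this commit keeps the default `False`:

    from paddlenlp.ops import FasterMBART

    # model: an MBartForConditionalGeneration in eval mode, as in the samples above.
    faster_mbart = FasterMBART(model=model,
                               use_fp16_decoding=False,
                               enable_faster_encoder=True)  # swap encoder at construction
    faster_mbart.eval()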
paddlenlp/transformers/mbart/modeling.py (2 additions & 2 deletions)

@@ -203,7 +203,7 @@ def forward(self, input_ids_shape, past_key_values_length=0):
         positions = paddle.arange(past_key_values_length,
                                   past_key_values_length + seq_len,
                                   dtype="int64")
-        return super().forward(positions + self.offset)
+        return Embedding.forward(self, positions + self.offset)
 
 
 class MBartEncoder(MBartPretrainedModel):

@@ -270,7 +270,7 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
         if input_ids is None:
             raise ValueError("Input_ids cannot be None.")
         inputs_embeds = self.d_model**0.5 * self.embed_tokens(input_ids)
-        inputs_embed_pos = self.encoder_embed_positions(input_ids.shape)
+        inputs_embed_pos = self.encoder_embed_positions(paddle.shape(input_ids))
         hidden_states = inputs_embeds + inputs_embed_pos
         hidden_states = self.encoder_layernorm_embedding(hidden_states)
         encoder_input = self.encoder_dropout(hidden_states)
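Both edits look like standard dy2sta fixes (my reading; the commit does not explain them): zero-argument `super()` can break under `paddle.jit`'s source-code transform, so the base class is called explicitly; and after conversion `Tensor.shape` is a Python list in which dynamic axes appear as -1, whereas `paddle.shape(x)` is a tensor computed at run time, so the positional embedding sees the real sequence length. A tiny illustration of the shape difference:

    import paddle

    x = paddle.ones([1, 7], dtype="int64")
    print(x.shape)          # [1, 7] in dygraph; dynamic dims become -1 after to_static
    print(paddle.shape(x))  # int32 tensor holding [1, 7]; correct in both modes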
