diff --git a/model_zoo/bert/run_pretrain.py b/model_zoo/bert/run_pretrain.py index b6cbe3a48972..947de25e25e0 100644 --- a/model_zoo/bert/run_pretrain.py +++ b/model_zoo/bert/run_pretrain.py @@ -398,6 +398,14 @@ def do_train(args): next_sentence_labels, masked_lm_scale, ) = batch + input_ids = input_ids.cuda(blocking=False) + segment_ids = segment_ids.cuda(blocking=False) + input_mask = input_mask.cuda(blocking=False) + masked_lm_positions = masked_lm_positions.cuda(blocking=False) + masked_lm_labels = masked_lm_labels.cuda(blocking=False) + next_sentence_labels = next_sentence_labels.cuda(blocking=False) + masked_lm_scale = masked_lm_scale.cuda(blocking=False) + with paddle.amp.auto_cast( args.use_amp, custom_white_list=["layer_norm", "softmax", "gelu", "fused_attention", "fused_feedforward"],