Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Codegen output & loss #3465

Merged
merged 17 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/code_generation/codegen/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def evaluate(model, data_loader, loss_fct):
model = model._layers if isinstance(model, paddle.DataParallel) else model
for batch in data_loader:
labels = batch.pop("labels")
logits, _ = model(**batch)
logits = model(**batch)[0]
loss = loss_fct(logits[:, :-1, :], labels[:, 1:])
correct = metric.compute(paddle.max(logits[:, :-1, :], axis=-1),
labels[:, 1:])
Expand Down Expand Up @@ -318,7 +318,7 @@ def do_train(args):
with paddle.amp.auto_cast(
args.use_amp,
custom_white_list=["layer_norm", "softmax", "gelu"]):
logits, _ = model(**batch)
logits = model(**batch)[0]
loss = loss_fct(logits[:, :-1, :], labels[:, 1:])
if args.use_amp:
scaled_loss = scaler.scale(loss)
Expand Down
128 changes: 108 additions & 20 deletions paddlenlp/transformers/codegen/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from ..nezha.modeling import ACT2FN
from .. import PretrainedModel, register_base_model
from ..model_outputs import (BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions)

CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
"Salesforce/codegen-350M-nl",
Expand Down Expand Up @@ -154,6 +156,7 @@ def forward(
attention_mask=None,
use_cache=False,
cache=None,
output_attentions=False,
):
qkv = self.qkv_proj(hidden_states)
mp_num = 4
Expand Down Expand Up @@ -225,6 +228,8 @@ def forward(
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

if output_attentions:
return attn_output, present, attn_weights
return attn_output, present


Expand Down Expand Up @@ -265,13 +270,15 @@ def forward(
attention_mask=None,
use_cache=False,
cache=None,
output_attentions=False,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(hidden_states,
attention_mask=attention_mask,
cache=cache,
use_cache=use_cache)
use_cache=use_cache,
output_attentions=output_attentions)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]

Expand All @@ -283,7 +290,7 @@ def forward(
else:
outputs = (hidden_states, ) + outputs[1:]

return outputs # hidden_states, present, (attentions)
return outputs # hidden_states, (present, attentions) outputs is a tuple


class CodeGenPreTrainedModel(PretrainedModel):
Expand Down Expand Up @@ -427,6 +434,9 @@ def forward(
token_type_ids=None,
use_cache=False,
cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False,
):
r'''
The CodeGenModel forward method, overrides the `__call__()` special method.
Expand Down Expand Up @@ -454,8 +464,22 @@ def forward(
See `TransformerDecoder.gen_cache <https://github.com/PaddlePaddle/Paddle/blob/release/2.1/python/paddle/nn/layer/transformer.py#L1060>`__ for more details.
It is only used for inference and should be None for training.
Default to `None`.
output_attentions (bool, optional):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. Defaults to `False`.
output_hidden_states (bool, optional):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` object.
If `False`, the output will be a tuple of tensors. Defaults to `False`.
Returns:
Tensor: Returns tensor `decoder_output`, which is the output at the last layer of the model.
An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if
`return_dict=True`. Otherwise it returns a tuple of tensors corresponding
to ordered and not None (depending on the input arguments) fields of
:class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`.
Especially, When `return_dict=output_hidden_states=output_attentions=False` and `cache=None`,
returns a tensor representing the output of :class:`UnifiedTransformerModel`.
Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size].
Example:
.. code-block::
Expand Down Expand Up @@ -516,23 +540,48 @@ def forward(
output_shape = input_shape[:] + [hidden_states.shape[-1]]

presents = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, (block, old_cache) in enumerate(zip(self.h, cache)):
if output_hidden_states:
all_hidden_states += (hidden_states, )
outputs = block(hidden_states,
attention_mask=attention_mask,
use_cache=use_cache,
cache=old_cache)
cache=old_cache,
output_attentions=output_attentions)

hidden_states = outputs[0]
if use_cache:
presents = presents + (outputs[1], )
if output_attentions:
all_self_attentions += (outputs[-1], )

hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.reshape(shape=output_shape)

if output_hidden_states:
all_hidden_states += (hidden_states, )

last_hidden_state = hidden_states
new_cache = presents

return last_hidden_state, new_cache
if not return_dict:
temp_list = [
last_hidden_state,
new_cache,
all_hidden_states,
all_self_attentions,
]
return tuple(v for v in temp_list if v is not None)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=last_hidden_state,
past_key_values=new_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=None,
)


class CodeGenForCausalLM(CodeGenPreTrainedModel):
Expand Down Expand Up @@ -615,7 +664,11 @@ def forward(self,
attention_mask=None,
token_type_ids=None,
use_cache=False,
cache=None):
cache=None,
labels=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False):
r"""
The CodeGenForCausalLM forward method, overrides the __call__() special method.
Args:
Expand All @@ -627,14 +680,24 @@ def forward(self,
See :class:`CodeGenModel`.
cache (Tensor, optional):
See :class:`CodeGenModel`.
labels: (Tensor, optional):
Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]`
output_attentions (bool, optional):
See :class: `CodeGenModel`
output_hidden_states (bool, optional):
See :class: `CodeGenModel`
return_dict (bool, optional):
See :class: `CodeGenModel`
Returns:
Tensor or tuple: Returns Tensor `lm_logits` if `use_cache` is `False`, otherwise, returns tuple (`lm_logits`, `cache`).
With the fields:
- `lm_logits` (Tensor):
The generated sentence of the model.
Its data type should be float32 and has a shape of [batch_size, sequence_length, vocab_size].
- `cache` (Tensor):
See :class:`CodeGenModel`.
An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions` if
`return_dict=True`. Otherwise it returns a tuple of tensors corresponding
to ordered and not None (depending on the input arguments) fields of
:class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithPastAndCrossAttentions`.
Especially, When `return_dict=output_hidden_states=output_attentions=False` and `cache=labels=None`,
returns tensor `lm_logits` of shape [batch_size, sequence_length, vocab_size],

Example:
.. code-block::
import paddle
Expand All @@ -646,21 +709,46 @@ def forward(self,
outputs = model(**inputs)
"""

transformer_outputs = self.transformer(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
use_cache=use_cache,
cache=cache)
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
use_cache=use_cache,
cache=cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

hidden_states = transformer_outputs[0]

# make sure sampling in fp16 works correctly and
# compute loss in fp32 to match with mesh-tf version
# https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
lm_logits = paddle.cast(self.lm_head(hidden_states), "float32")
past_key_values = transformer_outputs[1]

return lm_logits, past_key_values
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1, :]
shift_labels = labels[:, 1:]
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.reshape((-1, shift_logits.shape[-1])),
shift_labels.reshape((-1, )))

if not return_dict:
# if isinstance(transformer_outputs, type(input_ids)):
# return (loss, lm_logits) if loss is not None else lm_logits
outputs = (lm_logits, ) + transformer_outputs[1:]
return ((loss, ) + outputs) if loss is not None else outputs

return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)

def __getattr__(self, name):
try:
Expand Down
4 changes: 2 additions & 2 deletions paddlenlp/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,8 @@ def update_model_kwargs_for_generation(outputs,
# method.

# update cache
if isinstance(outputs,
tuple) and not isinstance(outputs[1], paddle.Tensor):
if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(
outputs[1], paddle.Tensor):
model_kwargs["cache"] = outputs[1]

# update token_type_ids with last value
Expand Down
Loading