When training MambaLMHeadModel with the Hugging Face Trainer, getting an unexpected attention_mask #695

srijiths opened this issue Feb 15, 2025 · 0 comments


When training MambaLMHeadModel with the Hugging Face Trainer, evaluation fails with an unexpected `attention_mask` keyword argument (warning and full traceback below). Any pointers on how to resolve this? Thank you!

The following columns in the evaluation set don't have a corresponding argument in `MambaLMHeadModel.forward` and have been ignored: labels, __index_level_0__, conversion, token_type_ids, user_journey. If labels, __index_level_0__, conversion, token_type_ids, user_journey are not expected by `MambaLMHeadModel.forward`,  you can safely ignore this message.

***** Running Evaluation *****
  Num examples = 5645
  Batch size = 32
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-95-3435b262f1ae> in <cell line: 1>()
----> 1 trainer.train()

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2162                 hf_hub_utils.enable_progress_bars()
   2163         else:
-> 2164             return inner_training_loop(
   2165                 args=args,
   2166                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2587                         self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2588                         self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2589                         self._maybe_log_save_evaluate(
   2590                             tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
   2591                         )

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
   3045         metrics = None
   3046         if self.control.should_evaluate:
-> 3047             metrics = self._evaluate(trial, ignore_keys_for_eval)
   3048             is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
   3049 

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2999 
   3000     def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 3001         metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   3002         self._report_to_hp_search(trial, self.state.global_step, metrics)
   3003 

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   4049 
   4050         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4051         output = eval_loop(
   4052             eval_dataloader,
   4053             description="Evaluation",

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   4243 
   4244             # Prediction step
-> 4245             losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   4246             main_input_name = getattr(self.model, "main_input_name", "input_ids")
   4247             inputs_decode = (

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   4469                     loss = None
   4470                     with self.compute_loss_context_manager():
-> 4471                         outputs = model(**inputs)
   4472                     if isinstance(outputs, dict):
   4473                         logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749         result = None

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    191                 return self.module(*inputs[0], **module_kwargs[0])
    192             replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
--> 193             outputs = self.parallel_apply(replicas, inputs, module_kwargs)
    194             return self.gather(outputs, self.output_device)
    195 

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    210         self, replicas: Sequence[T], inputs: Sequence[Any], kwargs: Any
    211     ) -> List[Any]:
--> 212         return parallel_apply(
    213             replicas, inputs, kwargs, self.device_ids[: len(replicas)]
    214         )

/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
    124         output = results[i]
    125         if isinstance(output, ExceptionWrapper):
--> 126             output.reraise()
    127         outputs.append(output)
    128     return outputs

/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    713             # instantiate since we don't know how to
    714             raise RuntimeError(msg) from None
--> 715         raise exception
    716 
    717 

TypeError: Caught TypeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: Mamba.forward() got an unexpected keyword argument 'attention_mask'
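
The traceback shows that the collated batch still carries an `attention_mask` key, which `MambaLMHeadModel.forward` passes down through `**mixer_kwargs` until `Mamba.forward()` rejects it. One workaround is to keep that key from reaching the model at all. Below is a minimal sketch (the wrapper name `MambaForTrainer` and the exact set of dropped keys are assumptions, not part of mamba_ssm) that accepts and discards the offending keyword arguments before delegating to the wrapped model:

```python
import torch.nn as nn


class MambaForTrainer(nn.Module):
    """Hypothetical thin wrapper: swallows batch keys that Mamba's forward
    pass cannot handle before delegating to the wrapped MambaLMHeadModel."""

    def __init__(self, mamba_lm_head_model):
        super().__init__()
        self.model = mamba_lm_head_model

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, **kwargs):
        # attention_mask / token_type_ids / labels are accepted here only so
        # the HF Trainer can pass them without error; they are intentionally
        # discarded because Mamba.forward() does not take them.
        return self.model(input_ids, **kwargs)
```

Alternatively, dropping the `attention_mask` (and `token_type_ids`) columns from the tokenized dataset, or using a data collator that does not emit them, should avoid this particular error without wrapping the model.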