Why does LSTMCell keep rngs in its state? #4509

Open
JoaoAparicio opened this issue Jan 28, 2025 · 3 comments

Comments

@JoaoAparicio

JoaoAparicio commented Jan 28, 2025

It seems that LSTMCell keeps rngs in its state:

self.rngs = rngs

Is this intentional? Why?

I stumbled upon this because my recipe for checkpointing breaks when my model contains an LSTM:

from flax import nnx
import orbax.checkpoint as ocp

def savemodel(model, path):
    # Drop the graph definition and keep only the state pytree.
    _, state = nnx.split(model)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)

Calling savemodel(model, path) throws:

TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.

This was surprising because I've been using that recipe before and never had a problem while using other non-LSTM modules.
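The error message itself hints at one possible workaround: convert typed keys to their raw integer data before saving. Below is a minimal sketch in plain JAX, not a Flax or Orbax recipe. The `state` dict is a hypothetical stand-in for the real model state, and the helper names `is_prng_key` and `keys_to_data` are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_prng_key(x):
    """Return True for JAX arrays that carry a typed PRNG key dtype."""
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jax.dtypes.prng_key)


def keys_to_data(tree):
    """Replace every typed key leaf with its underlying integer array."""
    return jax.tree_util.tree_map(
        lambda x: jax.random.key_data(x) if is_prng_key(x) else x, tree
    )


# Hypothetical stand-in for a model state that holds an RNG key.
state = {"kernel": jnp.ones((2, 2)), "rng_key": jax.random.key(0)}

# After conversion every leaf is an ordinary array that NumPy accepts.
saveable = keys_to_data(state)
as_numpy = np.asarray(saveable["rng_key"])

# After restoring, the raw data can be wrapped back into a typed key.
restored_key = jax.random.wrap_key_data(saveable["rng_key"])
```

Whether this fits a given checkpointing setup depends on how the state is restored afterwards; the sketch only shows the key conversion round trip.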

@cgarciae
Collaborator

cgarciae commented Feb 3, 2025

Hi @JoaoAparicio, great question. It's because, in general, the carry initializer might use the random state. In practice it's almost always zeros, but currently we support the general case.
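To illustrate the general case, here is a small sketch using the standard `jax.nn.initializers` functions. Both initializers share the `(key, shape, dtype)` call signature, but only the random one actually consumes the key. The carry shape is an arbitrary assumption, and this is not the LSTMCell code itself.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
carry_shape = (4, 8)  # assumed (batch, hidden) shape, for illustration only

# The usual case: a zeros initializer accepts a key but ignores it.
zeros_carry = jax.nn.initializers.zeros(key, carry_shape, jnp.float32)

# The general case: a random initializer needs the key to draw values.
random_carry = jax.nn.initializers.normal(stddev=0.1)(key, carry_shape, jnp.float32)
```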

@alezana

alezana commented Feb 21, 2025

Hi @JoaoAparicio, @cgarciae,

I ran into the same "issue". In particular, to cope with this during training I had to do the following:

  • subclass the TrainState to keep track of the non-trainable parameters separately, which requires a more convoluted split:

    graphdef, params, static_params = nnx.split(
        model, nnx.filterlib.OfType(nnx.Param), nnx.filterlib.Not(nnx.Param)
    )

Moreover, I now hit a similar issue when I save checkpoints from the training state: TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.

This seems like a quite convoluted way to work with recurrent modules. Is there a recommended or different way of working with such units that I am missing, in particular for saving/restoring checkpoints and keeping track of the model in the training state?

I see other related issues: #4383

@cgarciae
Collaborator

@alezana thanks for the feedback! After thinking about this a bit I'm very inclined to remove the rngs attribute from the various recurrent cells and have the user explicitly rely on initialize_carry's rngs argument.
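As a rough illustration of that direction, here is a sketch of a cell that stores no RNG state and receives randomness only when the carry is created. `TinyCell` is a hypothetical class written for this example, not the Flax implementation, and it takes a raw JAX key where the real `initialize_carry` takes an rngs argument.

```python
import jax
import jax.numpy as jnp


class TinyCell:
    """Hypothetical recurrent cell that keeps no RNG state of its own."""

    def __init__(self, hidden_features, carry_init=jax.nn.initializers.zeros):
        self.hidden_features = hidden_features
        self.carry_init = carry_init

    def initialize_carry(self, input_shape, key):
        # The caller supplies the key, so the cell holds no key to checkpoint.
        carry_shape = tuple(input_shape[:-1]) + (self.hidden_features,)
        c_key, h_key = jax.random.split(key)
        c = self.carry_init(c_key, carry_shape, jnp.float32)
        h = self.carry_init(h_key, carry_shape, jnp.float32)
        return c, h


cell = TinyCell(hidden_features=8)
c, h = cell.initialize_carry((4, 16), jax.random.key(0))
```

With this shape of API, the only attributes on the cell are plain Python values and parameters, so a state split would contain no PRNG key arrays.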
