Why does LSTMCell keep rngs in its state? #4509
Hi @JoaoAparicio, great question. It's because in general the carry initializer might use the random state. In practice it's almost always zeros, but currently we support the general case.
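For illustration, here is a minimal sketch of that general case (the constructor arguments and initializer names are assumptions about the current flax.nnx API, not code from this thread): because carry_init can be a random initializer, the cell has to keep the Rngs it was built with, so they show up in its state.

```python
from flax import nnx

# A carry initializer that needs randomness; with the default zeros
# initializer the stored rngs are never actually drawn from.
cell = nnx.LSTMCell(
    in_features=3,
    hidden_features=4,
    carry_init=nnx.initializers.normal(),
    rngs=nnx.Rngs(0),
)

# The cell's state contains RngState entries alongside its Params,
# which is the part that checkpointing recipes can trip over.
print(nnx.state(cell))
```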
Hi @JoaoAparicio, @cgarciae, I ran into the same "issue". In particular, to cope with this for training I had to do the following:
Moreover, I have similar issues now when I need to save checkpoints from the training state, running into:

This seems quite convoluted to work with for recurrent modules. Is there a recommended/different way of working with such units that I am missing, in particular with regards to saving/restoring checkpoints and keeping track of the model in the training state? I see other related issues: #4383
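One possible way to cope, sketched below, is to split the RNG state out of the model and checkpoint only the rest, recreating the RNG collection from a seed on restore. This is only an illustration of that idea using nnx.split / nnx.merge filters; it is not necessarily the workaround @alezana actually used.

```python
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.cell = nnx.LSTMCell(in_features=3, hidden_features=4, rngs=rngs)

model = Model(nnx.Rngs(0))

# Separate the RNG state from everything else (params, batch stats, ...).
graphdef, rng_state, rest = nnx.split(model, nnx.RngState, ...)

# Checkpoint only `rest`; on restore, rebuild the RNG state from a fresh
# seed and merge the pieces back into a model.
restored = nnx.merge(graphdef, rng_state, rest)
```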
@alezana thanks for the feedback! After thinking about this a bit I'm very inclined to remove the rngs from the cell's state.
It seems that LSTMCell keeps rngs in its state:
See flax/flax/nnx/nn/recurrent.py, line 137 (at commit a8a192f).
Is this intentional? Why?
I stumbled upon this because my recipe for checkpointing breaks when my model contains an LSTM:
Calling savemodel(model, path) throws an error. This was surprising because I've been using that recipe before and never had a problem with other, non-LSTM modules.
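For context, here is a hypothetical shape such a recipe might have (savemodel is the reporter's own helper and its body is not shown in the issue, so this is a guess using the orbax StandardCheckpointer API): saving the full model state works for modules whose state is purely arrays, but for the LSTM the state also contains the RNG keys, which is plausibly where the save falls over.

```python
import orbax.checkpoint as ocp
from flax import nnx

def savemodel(model, path):
    # Hypothetical reconstruction of the recipe: checkpoint the full model state.
    state = nnx.state(model)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(path, state)

# Works for purely Param-carrying modules, but an LSTM's state also holds
# RngState (typed jax.random keys), which the checkpointer may reject.
```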