How to unroll a model end to end? #4507
Hi @JoaoAparicio, layers like …

```python
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs):
        self.rnn = nnx.RNN(...)
        self.linear = nnx.Linear(..., rngs=rngs)

    def __call__(self, x):
        # x: [batch, time, features]
        x = self.rnn(x)  # applies recurrently to all timesteps
        x = self.linear(x)  # applies in parallel to all timesteps
        return x
```
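The comment that `self.linear` "applies in parallel to all timesteps" comes down to broadcasting: an affine layer contracts only the last (feature) axis and carries every leading axis along. The sketch below uses plain JAX with a hypothetical `dense` function standing in for `nnx.Linear` (it is not Flax's implementation) and checks that one call over a whole `[batch, time, features]` array matches a loop over timesteps:

```python
import jax
import jax.numpy as jnp

# Hypothetical affine layer written only for the feature axis: `x @ w`
# contracts the last dimension of x and broadcasts over all leading axes.
def dense(params, x):
    w, b = params
    return x @ w + b

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
batch, time, features, out = 2, 5, 3, 4
params = (jax.random.normal(key_w, (features, out)), jnp.zeros(out))
x = jax.random.normal(key_x, (batch, time, features))

# One call over the whole [batch, time, features] array...
y_all = dense(params, x)
# ...versus an explicit Python loop that applies the layer one timestep at a time.
y_loop = jnp.stack([dense(params, x[:, t]) for t in range(time)], axis=1)
```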
Hey, thank you for taking the time :-) A couple of follow-up questions! In the code that you presented you have the comment that … And following from the above, is the intended design for writing modules that they should only handle the minimum number of feature dimensions they require, and assume that any additional outer dimensions are to be parallelized over?
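Whether this is the intended convention for Flax modules is for the maintainers to confirm. As a general JAX pattern, though, a function written only for its minimal dimensions that does *not* broadcast can still be lifted over extra outer axes with `jax.vmap`. A minimal sketch, where `dense_single` is a hypothetical layer that only accepts a 1-D feature vector:

```python
import jax
import jax.numpy as jnp

# Hypothetical layer written for a single feature vector x of shape [features]:
# `w @ x` with w of shape [out, features] only works when x is 1-D, so this
# function does NOT broadcast over leading axes on its own.
def dense_single(w, b, x):
    return w @ x + b

features, out = 3, 4
w = jnp.ones((out, features))
b = jnp.zeros(out)
x = jnp.arange(2 * 5 * features, dtype=jnp.float32).reshape(2, 5, features)

# jax.vmap maps over axis 0 of x while sharing w and b (in_axes=None).
# The inner vmap adds the time axis, the outer one adds the batch axis.
over_time = jax.vmap(dense_single, in_axes=(None, None, 0))
over_batch_and_time = jax.vmap(over_time, in_axes=(None, None, 0))

y = over_batch_and_time(w, b, x)  # shape [batch, time, out]
```

Writing layers so they broadcast (as `x @ w` does) avoids these wrappers entirely; `jax.vmap` is the fallback for code that genuinely assumes unbatched input.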
Oh, and a quick question: in the timesteps dimension, which direction does time go? This isn't in the … Should I infer that for input with dimensions …
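For the time-direction question, the behaviour of the underlying JAX scan primitive can be checked directly: `jax.lax.scan` walks its inputs from index 0 upward along the leading axis, and `reverse=True` walks from the last index down while still returning outputs in the original order. The running-sum "cell" below is a hypothetical stand-in; how `nnx.RNN` exposes this (for example through `time_major`/`reverse` options) should be confirmed in its API docs:

```python
import jax
import jax.numpy as jnp

# Running sum as a minimal "recurrent cell": the carry after each step
# depends on every input that has been processed so far.
def step(carry, x_t):
    carry = carry + x_t
    return carry, carry

# Batch-major input [batch, time]; lax.scan iterates over axis 0,
# so the time axis is moved to the front first.
xs = jnp.array([[1.0, 2.0, 3.0, 4.0]])  # batch=1, time=4
xs_time_major = jnp.swapaxes(xs, 0, 1)  # [time, batch]

# Forward: index 0 is processed first, giving running sums 1, 3, 6, 10.
_, forward = jax.lax.scan(step, jnp.zeros(1), xs_time_major)
# Reverse: index 3 is processed first, giving 10, 9, 7, 4 (stored in original order).
_, backward = jax.lax.scan(step, jnp.zeros(1), xs_time_major, reverse=True)
```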
Hello
What's the correct way to unroll a model that contains an LSTM?
E.g., suppose my model has 3 blocks, from top to bottom:
I know how to unroll the LSTM N times; there's a module for that:
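For reference, unrolling a single LSTM cell N times amounts to a scan over the time axis that threads the `(c, h)` carry from step to step. A hand-rolled pure-JAX sketch, assuming a hypothetical fused-gate parameterization (`wi`, `wh`, `b`) and gate order that are illustrative only, not `nnx.LSTMCell`'s internals:

```python
import jax
import jax.numpy as jnp

def init_lstm(key, in_features, hidden):
    k1, k2 = jax.random.split(key)
    # Hypothetical fused weights: each matrix produces all four gates at once.
    return {
        "wi": jax.random.normal(k1, (in_features, 4 * hidden)) * 0.1,
        "wh": jax.random.normal(k2, (hidden, 4 * hidden)) * 0.1,
        "b": jnp.zeros(4 * hidden),
    }

def lstm_step(params, carry, x_t):
    # One LSTM step for x_t of shape [batch, in_features].
    c, h = carry
    z = x_t @ params["wi"] + h @ params["wh"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h = jax.nn.sigmoid(o) * jnp.tanh(c)
    return (c, h), h

def unroll_lstm(params, x):
    # x: [batch, time, features]; lax.scan iterates over the leading axis,
    # so time is swapped to the front and swapped back afterwards.
    batch, hidden = x.shape[0], params["wh"].shape[0]
    carry = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))
    step = lambda carry, x_t: lstm_step(params, carry, x_t)
    _, hs = jax.lax.scan(step, carry, jnp.swapaxes(x, 0, 1))
    return jnp.swapaxes(hs, 0, 1)  # [batch, time, hidden]

params = init_lstm(jax.random.PRNGKey(0), in_features=3, hidden=8)
x = jnp.ones((2, 6, 3))  # N = 6 timesteps
y = unroll_lstm(params, x)
```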
But then, how do I do the same for the rest of the non-recurrent parts of my model?
It occurred to me that `nnx.RNN` is perhaps generic enough to work with the full model end to end? The fact that it's annotated as taking an `nnx.RNNCellBase` suggests this is probably not the case. Instead, I tried using `nnx.RNN` to unroll the LSTM and then do the rest manually. For example, I managed to make this work (schematically):
This works. But because it has to be done by hand on a model-by-model basis, it scales badly to more complex models, and it's error-prone.
Is there a way of unrolling the full model as easily as the LSTM component?
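One way to unroll the full model end to end is to express the entire model for a single timestep as a `(carry, x_t) -> (carry, y_t)` function and scan that. A minimal pure-JAX sketch under stated assumptions: the three blocks, their parameter names, and the tanh (Elman) cell standing in for the LSTM are all hypothetical. It also checks that the result matches the layered form, where only the recurrent block is scanned and the dense blocks are applied to every timestep at once:

```python
import jax
import jax.numpy as jnp

# Hypothetical 3-block model: dense encoder -> recurrent cell -> dense decoder.
# A plain tanh (Elman) cell stands in for the LSTM to keep the sketch short.
def init(key, features=3, hidden=8, out=2):
    ks = jax.random.split(key, 4)
    return {
        "enc": jax.random.normal(ks[0], (features, hidden)) * 0.1,
        "wx": jax.random.normal(ks[1], (hidden, hidden)) * 0.1,
        "wh": jax.random.normal(ks[2], (hidden, hidden)) * 0.1,
        "dec": jax.random.normal(ks[3], (hidden, out)) * 0.1,
    }

def model_step(params, h, x_t):
    # The whole model for ONE timestep; x_t is [batch, features].
    e = jnp.tanh(x_t @ params["enc"])
    h = jnp.tanh(e @ params["wx"] + h @ params["wh"])
    return h, h @ params["dec"]

def unroll_model(params, x):
    # Scan the full per-step model over time; x is [batch, time, features].
    h0 = jnp.zeros((x.shape[0], params["wh"].shape[0]))
    step = lambda h, x_t: model_step(params, h, x_t)
    _, ys = jax.lax.scan(step, h0, jnp.swapaxes(x, 0, 1))
    return jnp.swapaxes(ys, 0, 1)  # [batch, time, out]

def layered_model(params, x):
    # Same model, but only the recurrent block is scanned; the dense blocks
    # are applied to all timesteps at once by broadcasting.
    e = jnp.tanh(x @ params["enc"])
    h0 = jnp.zeros((x.shape[0], params["wh"].shape[0]))

    def cell(h, e_t):
        h = jnp.tanh(e_t @ params["wx"] + h @ params["wh"])
        return h, h

    _, hs = jax.lax.scan(cell, h0, jnp.swapaxes(e, 0, 1))
    return jnp.swapaxes(hs, 0, 1) @ params["dec"]

params = init(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 3))
y_scan = unroll_model(params, x)
y_layered = layered_model(params, x)
```

Both forms compute the same function. The layered form matches the structure of the reply above (recurrent part unrolled, dense layers broadcast over time) and runs the non-recurrent blocks as single batched matmuls, while the fully scanned form only requires writing the model once as a per-step function.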