Hi,
Thanks for your great work on flax nnx!
I am experiencing a ValueError when scanning over a stack of network layers that contain nnx.Dropout. The error happens when the stack was created with vmap and we split the nnx rng with @nnx.split_rngs. The error does not happen when I split the random keys with jax.random.split. To reproduce step by step, please consider the following minimal example:
Let's consider a simple layer that contains Dropout:
And now we build a model with a stack of 3 of these layers via vmap. The documentation provides two ways to split the random keys for this: the "jax"-version, as described in the NNX basics doc, where we split a jax.random.key, and the "nnx"-version, as described in the nnx toy examples, where we use nnx.split_rngs. This is exemplified in the following code with the two versions:
When I use the "jax"-version, everything works as expected. Indeed, when I print the layer stack, the Dropout part of the stack looks like this, where the value of count and key is of shape (3,), as expected:
SomeLayer( # RngState: 6 (36 B), Param: 330 (1.3 KB), Total: 336 (1.4 KB)
  dropout=Dropout( # RngState: 6 (36 B)
    broadcast_dims=(),
    deterministic=False,
    rate=0.1,
    rng_collection='dropout',
    rngs=Rngs( # RngState: 6 (36 B)
      default=RngStream( # RngState: 6 (36 B)
        count=RngCount( # 3 (12 B)
          value=Array(shape=(3,), dtype=dtype('uint32')),  <--- three times!
          tag="'default'"
        ),
        key=RngKey( # 3 (24 B)
          value=Array(shape=(3,), dtype=key<fry>),  <----- three times!
          tag="'default'"
        )
      )
    )
  ),
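For reference, the shapes in the printout above can be reproduced with plain JAX, independent of nnx. This is only a minimal sketch and not the code from this issue: splitting a key into 3 yields a key array with a leading axis of size 3, which is the axis that vmap and scan iterate over, whereas an un-split key is a scalar with no leading axis.

```python
import jax

# A single typed PRNG key is a scalar array (no leading axis).
base_key = jax.random.key(0)

# Splitting into 3 gives one key per layer, stacked along axis 0.
stacked_keys = jax.random.split(base_key, 3)

print(base_key.shape)      # ()
print(stacked_keys.shape)  # (3,)
```

Anything that iterates over axis 0 of the layer stack needs the second form; the scalar form has no axis 0 to iterate over.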
On the other hand, when I use the "nnx"-version, I get the following error raised for the model call: ValueError: axis 0 is out of bounds for array of dimension 0.
The dropout part of the printed layer stack looks like this:
So it seems that we don't have Dropout three times (the Linear layers are fine and of the correct shape).
Is this expected behavior and I'm just doing something wrong?
Thank you very much for looking into this!
Hey @johnnycrab, the reason is that split_rngs, upon exit, reverts the keys and counts to a (potentially) un-split / scalar representation, and this causes an error during scan since you are iterating over the layers. One solution is to call split_rngs again during __call__:
Here you don't need to split because each scan step will increase/accumulate all the RngCount values for the next step, so they all get different random values.
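The two options mentioned in the reply in this thread can be illustrated with plain JAX as a rough analogy. This sketch does not use nnx and is not the maintainers' code: `dropout`, `scan_with_split_keys`, and `scan_with_counter` are hypothetical names, and the Linear layers are left out. The first function scans over a stacked key array (one key per layer, similar in spirit to splitting the rngs again before the scan). The second carries a counter through the scan and folds it into a single scalar key at each step, which mimics how an incrementing count can yield different random values per step.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 3
RATE = 0.1

def dropout(key, x, rate=RATE):
    # Inverted dropout: zero entries with probability `rate`, rescale the rest.
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def scan_with_split_keys(x, key):
    # Option 1 (analogy): split the key so it has a leading axis of size
    # NUM_LAYERS, then let scan iterate over that axis, one key per layer.
    keys = jax.random.split(key, NUM_LAYERS)

    def step(carry, layer_key):
        return dropout(layer_key, carry), None

    out, _ = jax.lax.scan(step, x, keys)
    return out

def scan_with_counter(x, key):
    # Option 2 (analogy): keep one scalar key and carry a counter through
    # the scan. Folding the counter into the key changes the stream per step.
    def step(carry, _):
        h, count = carry
        step_key = jax.random.fold_in(key, count)
        return (dropout(step_key, h), count + 1), None

    init = (x, jnp.uint32(0))
    (out, count), _ = jax.lax.scan(step, init, None, length=NUM_LAYERS)
    return out, count

x = jnp.ones((4, 10))
key = jax.random.key(0)
y1 = scan_with_split_keys(x, key)
y2, final_count = scan_with_counter(x, key)
print(y1.shape, y2.shape, int(final_count))
```

In the first variant the random state must have a leading axis matching the number of layers; in the second it stays scalar and only the carried counter changes between steps.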