
nnx.vmap with nnx.split_rngs does not create multiple Dropout layers #4589

Open
johnnycrab opened this issue Mar 1, 2025 · 2 comments
johnnycrab commented Mar 1, 2025

Hi,
Thanks for your great work on flax nnx!
I am experiencing a ValueError when scanning over a stack of network layers that contain nnx.Dropout. The error occurs when the stack is created with vmap and the nnx rngs are split with @nnx.split_rngs; it does not occur when I split the random keys with jax.random.split instead. To reproduce step by step, please consider the following minimal example.

Let's consider a simple layer that contains Dropout:

import jax
import jax.numpy as jnp
from flax import nnx

class SomeLayer(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 10, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)

    def __call__(self, x):
        return self.linear(self.dropout(x))

And now we build a model with a stack of 3 of these layers via vmap. The documentation shows two ways to split the random keys for this: the "jax" version, described in the NNX basics doc, where we split a jax.random.key, and the "nnx" version, used in the nnx toy examples, where we use nnx.split_rngs. Both versions are shown in the following code:

class SomeModel(nnx.Module):
    def __init__(self, version: str, rngs: nnx.Rngs):

        if version == "jax":
            @nnx.vmap
            def create_layer(key):
                return SomeLayer(nnx.Rngs(key))

            keys = jax.random.split(jax.random.key(0), 3)
            self.stack = create_layer(keys)

        elif version == "nnx":
            @nnx.split_rngs(splits=3)
            @nnx.vmap(axis_size=3)
            def create_layer(r: nnx.Rngs):
                return SomeLayer(r)

            self.stack = create_layer(rngs)

    def __call__(self, x):
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def scan_over_stack(x, layer):
            return layer(x)

        return scan_over_stack(x, self.stack)

And let's simply run it like this:

arr = jnp.zeros(10)
model = SomeModel("jax", nnx.Rngs(0))
print(model.stack)
out = model(arr)

When I use the "jax" version, everything works as expected. Indeed, when I print the layer stack, the Dropout part looks like this, where count and key have shape (3,), as expected:

SomeLayer( # RngState: 6 (36 B), Param: 330 (1.3 KB), Total: 336 (1.4 KB)
  dropout=Dropout( # RngState: 6 (36 B)
    broadcast_dims=(),
    deterministic=False,
    rate=0.1,
    rng_collection='dropout',
    rngs=Rngs( # RngState: 6 (36 B)
      default=RngStream( # RngState: 6 (36 B)
        count=RngCount( # 3 (12 B)
          value=Array(shape=(3,), dtype=dtype('uint32')),     <--- three times!
          tag="'default'"
        ),
        key=RngKey( # 3 (24 B)
          value=Array(shape=(3,), dtype=key<fry>),     <----- three times!
          tag="'default'"
        )
      )
    )
  ),

On the other hand, when I use the "nnx"-version, I get the following error raised for the model call:
ValueError: axis 0 is out of bounds for array of dimension 0.

The dropout part of the printed layer stack looks like this:

SomeLayer( # RngState: 2 (12 B), Param: 330 (1.3 KB), Total: 332 (1.3 KB)
  dropout=Dropout( # RngState: 2 (12 B)
    broadcast_dims=(),
    deterministic=False,
    rate=0.1,
    rng_collection='dropout',
    rngs=Rngs( # RngState: 2 (12 B)
      default=RngStream( # RngState: 2 (12 B)
        count=RngCount( # 1 (4 B)
          value=Array(1, dtype=uint32),     <---- only one!
          tag="'default'"
        ),
        key=RngKey( # 1 (8 B)
          value=Array((), dtype=key<fry>) overlaying:
          [0 0],
          tag="'default'"
        )
      )
    )
  ),

So it seems that the Dropout rng state is not replicated three times (the Linear layers are fine and have the correct shape).
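For reference, this error message is what JAX raises when an axis-0 operation is applied to a 0-dimensional array, which is consistent with the scan trying to iterate over an un-split (scalar) rng leaf. A minimal, purely illustrative reproduction of just the message (assumption: some axis manipulation like jnp.moveaxis hits the scalar leaf; this is not the actual nnx code path):

```python
import jax.numpy as jnp

# A 0-d array stands in for the reverted (un-split) rng key/count leaf.
scalar_leaf = jnp.zeros(())

try:
    # Any axis-0 manipulation of a 0-d array fails with the reported message.
    jnp.moveaxis(scalar_leaf, 0, 0)
    msg = None
except ValueError as e:
    msg = str(e)

print(msg)  # axis 0 is out of bounds for array of dimension 0
```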

Is this expected behavior and I'm just doing something wrong?
Thank you very much for looking into this!

System information

  • OS Platform and Distribution: macOS 15.3.1, Ubuntu 22.04.5
  • Versions: flax 0.10.3, jax 0.5.0, jaxlib 0.5.0
  • Python version: 3.10.11
  • CPU

cgarciae commented Mar 4, 2025

Hey @johnnycrab, the reason is that split_rngs, upon exit, reverts the keys and counts to a (potentially) un-split / scalar representation, and this causes an error during scan since you are iterating over the layers. One solution is to call split_rngs again during __call__:

class SomeModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    @nnx.split_rngs(splits=3)
    @nnx.vmap(axis_size=3)
    def create_layer(r: nnx.Rngs):
      return SomeLayer(r)

    self.stack = create_layer(rngs)

  def __call__(self, x):
    @nnx.split_rngs(splits=3)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def scan_over_stack(x, layer):
      return layer(x)

    return scan_over_stack(x, self.stack)

arr = jnp.zeros(10)
model = SomeModel(nnx.Rngs(0))
print(model.stack)
out = model(arr)

Another interesting alternative is to mark the RngState as a Carry by using StateAxes to control the Module substates:

class SomeModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    @nnx.split_rngs(splits=3)
    @nnx.vmap(axis_size=3)
    def create_layer(r: nnx.Rngs):
      return SomeLayer(r)

    self.stack = create_layer(rngs)

  def __call__(self, x):

    state_axes = nnx.StateAxes({nnx.RngState: nnx.Carry, ...: 0})
    @nnx.scan(in_axes=(nnx.Carry, state_axes), out_axes=nnx.Carry)
    def scan_over_stack(x, layer):
      return layer(x)

    return scan_over_stack(x, self.stack)

Here you don't need to split because each scan step increments all the RngCounts for the next step, so every step gets different random values.
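To see why carrying the count yields distinct randomness per step, here is a hypothetical pure-JAX sketch of the same mechanism (the names `step` and `base_key` are illustrative, not nnx internals): the count travels through the scan carry, and each step folds a different count into the same base key.

```python
import jax
import jax.numpy as jnp

base_key = jax.random.key(0)

def step(carry, _):
    x, count = carry
    # Each step derives its own key from the shared base key and the
    # carried count, mirroring how an RngCount carried through nnx.scan
    # produces different random values at every step.
    step_key = jax.random.fold_in(base_key, count)
    x = x + jax.random.bernoulli(step_key, 0.9, x.shape)  # dropout-like mask
    return (x, count + 1), None

(out, final_count), _ = jax.lax.scan(
    step, (jnp.zeros(3), jnp.uint32(0)), None, length=3
)
```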

johnnycrab (Author) commented

@cgarciae I see, thanks a lot for the explanation! :)
