Gradient of generalized pipeline is nan #387

Closed
hackertyper opened this issue Aug 21, 2023 · 2 comments

Comments

@hackertyper

Hi,

first of all thanks for your work on brax!

I have noticed an anomaly when trying to get gradients of different pipelines.

When differentiating through the pipeline functions, only the positional and spring backends yield a finite gradient. The generalized backend only yields NaN.

The following code reproduces the issue. It simulates a free-falling ball in an otherwise empty environment for 100 steps. The resulting gradients are 1.0 for both the spring and positional pipelines, but NaN for the generalized pipeline.

```python
import jax
from jax import numpy as jp
from brax.generalized import pipeline as generalized_pipeline
from brax.positional import pipeline as positional_pipeline
from brax.spring import pipeline as spring_pipeline
from brax.io import mjcf

free_fall_xml = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.005" density="1.2" viscosity="0.00002" />
  <worldbody>
    <body pos="0 0 3" name="freeball">
      <joint type="free" name="j"/>
      <geom type="sphere" size=".2" mass="1.0" name="g" />
    </body>
  </worldbody>
</mujoco>
"""

# Simulate physics for 100 steps. The function takes an initial x and returns the resulting x, so it can be
# differentiated with respect to x.
def simulation(pipeline, system, init_x):
    init_q = system.init_q.at[0].set(init_x)
    init_qd = jp.zeros(6)
    state = jax.jit(pipeline.init)(system, init_q, init_qd)
    for _ in range(100):
        state = jax.jit(pipeline.step)(system, state, None)
    return state.q[0]


if __name__ == "__main__":
    system = mjcf.loads(free_fall_xml)
    for pipeline in [generalized_pipeline, positional_pipeline, spring_pipeline]:
        print(f"{pipeline} results:")
        x = simulation(pipeline, system, 0)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, system, 0.0)
        print(f"{x=}, {grad_x=}")
```

Output:

```
<module 'brax.generalized.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(nan, dtype=float32, weak_type=True)
<module 'brax.positional.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(1., dtype=float32, weak_type=True)
<module 'brax.spring.pipeline'> results:
x=Array(0., dtype=float32), grad_x=Array(1., dtype=float32, weak_type=True)
```
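For reference, 1.0 is the value one would expect here. Gravity acts only along z and the ball starts at rest, so its final x equals its initial x, and therefore dx(T)/dx0 = 1. A minimal sketch that checks this in plain JAX, assuming a hand-written semi-implicit Euler integrator as a toy stand-in (it is not any brax pipeline, models gravity only, and ignores the density and viscosity options). The names `free_fall_x`, `GRAVITY`, `DT` and `N_STEPS` are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp

GRAVITY = jnp.array([0.0, 0.0, -9.81])  # same value as the MJCF <option gravity>
DT = 0.005                              # same value as the MJCF timestep
N_STEPS = 100


def free_fall_x(init_x):
    """Toy semi-implicit Euler free fall (not brax); returns the final x position."""
    pos = jnp.array([0.0, 0.0, 3.0]).at[0].set(init_x)
    vel = jnp.zeros(3)

    def step(carry, _):
        pos, vel = carry
        vel = vel + DT * GRAVITY  # gravity only changes the z velocity
        pos = pos + DT * vel      # x velocity stays 0, so x never changes
        return (pos, vel), None

    (pos, _), _ = jax.lax.scan(step, (pos, vel), None, length=N_STEPS)
    return pos[0]


x_final = free_fall_x(0.0)
grad_x = jax.grad(free_fall_x)(0.0)
print(x_final, grad_x)
```

This sketch uses `jax.lax.scan` rather than a Python loop so that the 100 steps are traced as a single loop; that is a choice made for the example, not something the reproduction above depends on.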
@btaba
Collaborator

btaba commented Oct 15, 2023

Hi @hackertyper, thanks for the bug report! If you are able to pinpoint which part of the pipeline produces a NaN using the JAX debugging flags (https://jax.readthedocs.io/en/latest/debugging/flags.html), that would help us find where the gradient is hitting a snag. So far, we haven't paid much attention to gradients in the generalized implementation.
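One of the flags on that page is `jax_debug_nans`, which makes JAX raise a `FloatingPointError` when a computation produces a NaN, including during the backward pass. A minimal sketch of using it, assuming a toy function whose gradient at 0 is NaN because `jnp.where` multiplies the infinite derivative of `sqrt` by a zero cotangent. `unsafe_sqrt` and `find_nan_grad` are hypothetical names, not brax code, and this does not claim that the same pattern causes the NaN in the generalized pipeline. To apply it to the report, one would wrap the `jax.grad(simulation, argnums=2)(...)` call the same way.

```python
import jax
import jax.numpy as jnp


def unsafe_sqrt(x):
    # Toy function (not brax code). The forward pass is fine at x = 0, but the
    # derivative of sqrt there is inf, and the zero cotangent that jnp.where sends
    # to the unselected branch turns 0 * inf into nan in the backward pass.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)


def find_nan_grad(f, x):
    """Return the error message if jax_debug_nans trips while computing grad(f)(x), else None."""
    jax.config.update("jax_debug_nans", True)
    try:
        jax.grad(f)(x)
        return None
    except FloatingPointError as err:
        return str(err)
    finally:
        # Turn the flag back off so later computations are not interrupted.
        jax.config.update("jax_debug_nans", False)


# Without the flag, the NaN gradient is returned silently, as in the report above.
plain_grad = jax.grad(unsafe_sqrt)(0.0)
print("grad without flag:", plain_grad)

# With the flag, JAX raises instead, and the error should point at the failing operation.
print("debug_nans report:", find_nan_grad(unsafe_sqrt, 0.0))
```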

@btaba
Collaborator

btaba commented Oct 25, 2023

This should be fixed in 1630403.
