Hi, first of all thanks for your work on brax!

I have noticed an anomaly when computing gradients through the different pipelines. When differentiating the pipeline functions, only the `positional` and `spring` backends yield a finite gradient; the `generalized` backend yields only NaN.

The following code reproduces the issue. It simulates a ball free-falling in an otherwise empty environment for 100 steps and differentiates the final x position with respect to the initial x position. The resulting gradient is `1.0` for both `spring` and `positional`, but `nan` for `generalized`.
```python
import jax
from jax import numpy as jp
from brax.generalized import pipeline as generalized_pipeline
from brax.positional import pipeline as positional_pipeline
from brax.spring import pipeline as spring_pipeline
from brax.io import mjcf

free_fall_xml = """
<mujoco>
  <option gravity="0 0 -9.81" timestep="0.005" density="1.2" viscosity="0.00002" />
  <worldbody>
    <body pos="0 0 3" name="freeball">
      <joint type="free" name="j"/>
      <geom type="sphere" size=".2" mass="1.0" name="g" />
    </body>
  </worldbody>
</mujoco>
"""


# Simulate physics for 100 steps. The function takes an initial x and returns
# the resulting x, so it can be differentiated with respect to x.
def simulation(pipeline, system, init_x):
    init_q = system.init_q.at[0].set(init_x)
    init_qd = jp.zeros(6)
    state = jax.jit(pipeline.init)(system, init_q, init_qd)
    for i in range(100):
        state = jax.jit(pipeline.step)(system, state, None)
    return state.q[0]


if __name__ == "__main__":
    system = mjcf.loads(free_fall_xml)
    for pipeline in [generalized_pipeline, positional_pipeline, spring_pipeline]:
        print(f"{pipeline} results:")
        x = simulation(pipeline, system, 0)
        grad_x = jax.grad(simulation, argnums=2)(pipeline, system, 0.0)
        print(f"{x=}, {grad_x=}")
```
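For reference, `1.0` is also the value one would expect analytically. A minimal derivation, assuming each step updates $x$ only through the x-velocity $v_x$ (as explicit and semi-implicit Euler both do) and that nothing in this model exerts a force along $x$: gravity points along $-z$, the world contains no other geoms to collide with, and any velocity-dependent drag along $x$ vanishes because $v_x$ starts at zero:

$$
a_x^{k} = 0 \;\Rightarrow\; v_x^{k} = v_x^{0} = 0 \;\Rightarrow\; x^{100} = x^{0} + h\sum_{k=1}^{100} v_x^{k} = x^{0}
\quad\Longrightarrow\quad \frac{\partial x^{100}}{\partial x^{0}} = 1
$$

Here $h = 0.005$ is the timestep from the MJCF. The `spring` and `positional` results agree with this value.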
Hi @hackertyper, thanks for the bug report! If you are able to pinpoint which part of the pipeline returns a NaN using the JAX debugging flags (https://jax.readthedocs.io/en/latest/debugging/flags.html), that would help us find where the gradient hits a snag. So far, we haven't paid much attention to gradients in the generalized implementation.
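As a small illustration of what the linked `jax_debug_nans` flag does (not a diagnosis of the brax issue), the sketch below uses a deliberately NaN-prone function chosen for the example: the gradient of `jnp.linalg.norm` at the zero vector, which works out to $0/0$. With the flag off, the NaN silently ends up in the gradient; with it on, JAX raises a `FloatingPointError` whose traceback points at the operation that produced the NaN. The helper names `norm_loss`, `grad_without_debug` and `grad_with_debug` are hypothetical.

```python
import jax
import jax.numpy as jnp


def norm_loss(v):
    # Illustrative stand-in only: the gradient of a Euclidean norm is
    # v / |v|, which evaluates to 0/0 = NaN at the zero vector.
    return jnp.linalg.norm(v)


def grad_without_debug(v):
    # Default behaviour: the NaN silently propagates into the gradient.
    return jax.grad(norm_loss)(v)


def grad_with_debug(v):
    # With jax_debug_nans enabled, JAX raises FloatingPointError at the
    # first operation that produces a NaN instead of returning it.
    jax.config.update("jax_debug_nans", True)
    try:
        jax.grad(norm_loss)(v)
        return None
    except FloatingPointError as err:
        return err
    finally:
        # Restore the default so later computations are unaffected.
        jax.config.update("jax_debug_nans", False)


if __name__ == "__main__":
    zero = jnp.zeros(3)
    print(grad_without_debug(zero))  # [nan nan nan]
    print(repr(grad_with_debug(zero)))  # the FloatingPointError raised by JAX
```

In the reproduction script, the same check could be enabled by calling `jax.config.update("jax_debug_nans", True)` before the simulation loop, or by setting the `JAX_DEBUG_NANS=True` environment variable; the resulting traceback should then point into the part of the `generalized` step where the NaN first appears.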