JAX 0.5.0: unsupported operand type(s) for *: 'tuple' and 'DynamicJaxprTracer' #26809
Hello,

```python
import jax
import jax.numpy as jnp

h0 = 0.01

def body(val):
    t, y, sigma = val
    jax.debug.print("t {}", t)
    h = h0 * (t / (1.0 + (h0 * (t - 1.0))))
    y *= (1.0 + h)
    sigma -= 0.1
    t += 1.0
    return t, y, sigma

def cond(val):
    t, y, sigma = val
    return (sigma > 0.1) & (sigma < 2.0)

val = (1., jnp.ones((2, 2)), 1.0)
jax.lax.while_loop(cond, body, val)
```

Now, if I compute `h` independently of the while loop, in a jitted function:

```python
@jax.jit
def f(t, h0):
    print(t, h0)
    return h0 * (t / (1.0 + (h0 * (t - 1.0))))

f(0.1, 0.1)
```

no crash occurs, and the print output shows `DynamicJaxprTrace`, i.e. the arguments are being traced.
So, it seems that the line computing `h` fails only inside the while loop. This is bizarre because I used to do such things before 0.5.0. Am I wrong? How can I do such a simple computation in the while loop? Thanks
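The `DynamicJaxprTrace` in the print output comes from how `jax.jit` works. Python's built-in `print` runs once, while the function is being traced, so it sees tracers rather than numbers. `jax.debug.print` (used in `body` above) is instead staged into the compiled computation and prints the runtime values.

The following is a minimal sketch that reuses `f` from the question. The `jax.debug.print` line is an illustrative addition, not part of the original code:

```python
import jax

@jax.jit
def f(t, h0):
    # Runs at trace time only: t and h0 are tracers here, so this prints
    # their abstract shape/dtype representations, not numeric values.
    print("trace-time:", t, h0)
    # Staged into the computation: prints concrete values on every call.
    jax.debug.print("run-time: t={t}, h0={h0}", t=t, h0=h0)
    return h0 * (t / (1.0 + (h0 * (t - 1.0))))

out = f(0.1, 0.1)
out.block_until_ready()
```

Calling `f` a second time with new float values reuses the compiled function. In that case only the `jax.debug.print` line produces output.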
I can't reproduce this error with any version of JAX that I've tried (including 0.5.0 and 0.5.1), either on Colab or on my Mac. I think what you've written should work, and it does for me!
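For reference, here is a self-contained version of the loop from the question, with `h0` a plain float. Unpacking the final state is an illustrative addition. The exact number of iterations depends on float32 rounding of `sigma`, so only the stopping condition is certain: the loop ends once `sigma <= 0.1`.

```python
import jax
import jax.numpy as jnp

h0 = 0.01  # a plain Python float (no trailing comma)

def body(val):
    t, y, sigma = val
    jax.debug.print("t {}", t)
    h = h0 * (t / (1.0 + (h0 * (t - 1.0))))
    y *= (1.0 + h)   # rebinds y to a new array; JAX arrays are immutable
    sigma -= 0.1
    t += 1.0
    return t, y, sigma

def cond(val):
    _, _, sigma = val
    return (sigma > 0.1) & (sigma < 2.0)

val = (1.0, jnp.ones((2, 2)), 1.0)
t_final, y_final, sigma_final = jax.lax.while_loop(cond, body, val)
```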
Oh! I found my mistake :) I had stupidly left a trailing comma, `h0=0.01,`, in my code, which I had not included in the snippet... sorry. That said, an answer to the question about the JAX 0.5.x version on Colab would still be helpful.
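The root cause is a plain-Python pitfall: `h0 = 0.01,` binds `h0` to the one-element tuple `(0.01,)`. Inside a traced function, multiplying that tuple by a tracer raises the `TypeError` from the title.

This also explains why the standalone test in the question passed. There, `f` received `h0` as a float argument instead of reading the module-level `h0`.

The following is a minimal sketch reproducing both cases under `jax.jit`. The helper `make_h` is hypothetical; it only mimics how `body` closes over `h0`:

```python
import jax

h0_bad = 0.01,   # trailing comma: this is the tuple (0.01,), not a float
h0_good = 0.01   # what was intended

def make_h(h0):
    # Hypothetical helper: a jitted h(t) that closes over h0,
    # the same way `body` reads the global h0 in the question.
    @jax.jit
    def h_of_t(t):
        return h0 * (t / (1.0 + (h0 * (t - 1.0))))
    return h_of_t

bad_error = None
try:
    make_h(h0_bad)(2.0)
except TypeError as err:
    # Inside jit, t is a tracer; JAX raises TypeError for tuple * tracer.
    bad_error = err
    print("TypeError:", err)

h_good = make_h(h0_good)(2.0)
print(h_good)
```

A quick `print(type(h0))` before building the loop is a cheap way to catch this kind of typo.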