Describe the bug
Hey Kris, love your framework! I'm working with a custom environment, and your discrete action unit test works perfectly locally. Don't spend much time investigating this yet; I'm just creating this in case something jumps out at you as the problem. I plan on continuing to debug this issue.
During the first PPOClip update with the custom gym environment, the model weights get changed to +/-inf, even though the gradients themselves are finite.
Expected behavior
...
adv = np.random.rand(32)
grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
print("grads", grads)
print(ppo_clip._pi.params)
metrics_pi = ppo_clip.update(transition_batch, Adv=adv) # This is the problem
print(ppo_clip._pi.params)
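For reference, this is the kind of finiteness check I'm wrapping around the update call to see which parameter leaves blow up. It's only a debugging sketch on top of the script below, and count_nonfinite is just a throwaway helper of mine, not a coax API:
import jax
import jax.numpy as jnp

def count_nonfinite(params):
    # count the non-finite entries in each leaf of the parameter pytree
    return jax.tree_util.tree_map(lambda x: int(jnp.sum(~jnp.isfinite(x))), params)

print("non-finite before update:", count_nonfinite(ppo_clip._pi.params))
metrics_pi = ppo_clip.update(transition_batch, Adv=adv)
print("non-finite after update: ", count_nonfinite(ppo_clip._pi.params))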
Here is the full repro script, taken from the Pong PPO example and slightly modified. It won't run as-is because of the custom environment, and the networks are a dummy example, not the actual policy and value networks that would be used:
import os
from luxai2021.env.lux_env import LuxEnvironment, LuxEnvironmentTeam
from luxai2021.game.game import Game
from luxai2021.game.actions import *
from luxai2021.game.constants import LuxMatchConfigs_Default
from luxai2021.env.agent import Agent, AgentWithTeamModel
import numpy as np
from agent import TeamAgent
# set some env vars
os.environ.setdefault('JAX_PLATFORM_NAME', 'cpu')     # tell JAX to use the CPU
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.1' # don't use all gpu mem
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tell XLA to be quiet
import gym
import jax
import coax
import haiku as hk
import jax.numpy as jnp
from optax import adam
# the name of this script
name = 'ppo'
configs = LuxMatchConfigs_Default
player = TeamAgent(mode="train")
opponent = Agent()
env = LuxEnvironment(configs=configs,
                     learning_agent=player,
                     opponent_agent=opponent)
env = coax.wrappers.TrainMonitor(env, name=name, tensorboard_dir=f"./data/tensorboard/{name}")
def func_pi(S, is_training):
    n_actions = 4
    out = {'logits': hk.Linear(n_actions)(hk.Flatten()(S))}
    return out

def func_v(S, is_training):
    h = jnp.ravel(hk.Linear(1)(hk.Flatten()(S)))
    return h
'''
def func_pi(S, is_training):
    #print(env.action_space.shape)
    n_filters = 5
    n_actions = 4
    n_layers = 3
    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(h + hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(h))
    print('h', type(h), h.shape)
    h_head = (h * S[:,:1]).reshape(h.shape[0], h.shape[1], -1).sum(-1)  # torch.Size([1, N_LAYERS])
    h_head_actions = hk.Linear(n_actions)(h_head)
    print('h_head_actions', type(h_head_actions), h_head_actions.shape)
    #print(h_head_actions)
    out = {'logits': h_head_actions}
    return out

def func_v(S, is_training):
    n_filters = 5
    n_layers = 3
    h = hk.Conv2D(n_filters, kernel_shape=3, stride=1, padding="SAME", data_format='NCHW')(S)
    for layer in range(n_layers):
        h = jax.nn.relu(hk.Conv2D(n_filters, kernel_shape=3, stride=2, data_format='NCHW')(h))
    h = hk.Flatten()(h)
    h = jax.nn.relu(hk.Linear(64)(h))
    h = jnp.ravel(hk.Linear(1, w_init=jnp.zeros)(h))
    return h
'''
# function approximators
pi = coax.Policy(func_pi, env)
v = coax.V(func_v, env)
# target networks
pi_behavior = pi.copy()
v_targ = v.copy()
# policy regularizer (avoid premature exploitation)
entropy = coax.regularizers.EntropyRegularizer(pi, beta=0.001)
# updaters
simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=adam(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=adam(3e-4))
# reward tracer and replay buffer
tracer = coax.reward_tracing.NStep(n=5, gamma=0.99)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=256)
# run episodes
max_episode_steps = 400
while env.T < 3000000:
    s = env.reset()

    for t in range(max_episode_steps):
        print(t)
        a, logp = pi_behavior(s, return_logp=True)
        s_next, r, done, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done, logp)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= buffer.capacity:
            num_batches = int(4 * buffer.capacity / 32)  # 4 epochs per round
            for i in range(num_batches):
                transition_batch = buffer.sample(32)
                grads, function_state, metrics, td_error = simpletd.grads_and_metrics(transition_batch)
                metrics_v, td_error = simpletd.update(transition_batch, return_td_error=True)

                adv = np.random.rand(32)
                grads, function_state, metrics = ppo_clip.grads_and_metrics(transition_batch, Adv=adv)
                print("grads", grads)
                print(ppo_clip._pi.params)
                metrics_pi = ppo_clip.update(transition_batch, Adv=adv)  # This is the problem
                print(ppo_clip._pi.params)
                exit()

                env.record_metrics(metrics_pi)
                env.record_metrics(metrics_v)

            buffer.clear()

            # sync target networks
            pi_behavior.soft_update(pi, tau=0.1)
            v_targ.soft_update(v, tau=0.1)

        if done:
            break

        s = s_next

    # generate an animated GIF to see what's going on
    if env.period(name='generate_gif', T_period=10000) and env.T > 50000:
        T = env.T - env.T % 10000  # round to 10000s
        coax.utils.generate_gif(
            env=env, policy=pi, resize_to=(320, 420),
            filepath=f"./data/gifs/{name}/T{T:08d}.gif")
To add to @glmcdona's report: I'm getting the exact same issue, but with a Box action space (if that makes any difference). After the update with the first minibatch, the networks are filled with NaNs.
I will try to replicate this with a classic gym env (by the way, I think Pendulum-v0 from the examples is deprecated).
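In case it saves a step, this is roughly the skeleton I plan to use for the classic-control repro. It's only a sketch: it assumes Pendulum-v1 is available in the installed gym, and that the Normal policy head returns 'mu'/'logvar' outputs as in the coax pendulum examples:
import gym
import coax
import haiku as hk
from optax import adam

env = gym.make('Pendulum-v1')  # Pendulum-v0 is deprecated in recent gym releases
env = coax.wrappers.TrainMonitor(env, name='ppo_box_repro')

def func_pi(S, is_training):
    # tiny Gaussian policy head for the 1-D Box action space
    h = hk.Linear(8)(hk.Flatten()(S))
    mu = hk.Linear(env.action_space.shape[0])(h)
    logvar = hk.Linear(env.action_space.shape[0])(h)
    return {'mu': mu, 'logvar': logvar}

pi = coax.Policy(func_pi, env)
ppo_clip = coax.policy_objectives.PPOClip(pi, optimizer=adam(3e-4))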
This error only occurs with the optax adam optimizer; the workaround is to use the sgd optimizer instead (see the snippet below). The error does not reproduce with TestPPOClip->test_update_discrete() or the example Pong PPO with the adam optimizer. Maybe close this issue unless a reliable repro can be created?
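Concretely, the workaround is just swapping the optimizer passed to the updaters in the script above, with the same learning rate and only adam replaced by sgd:
from optax import sgd

simpletd = coax.td_learning.SimpleTD(v, v_targ, optimizer=sgd(3e-4))
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=sgd(3e-4))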
It's very surprising that replacing optax.adam with optax.sgd seems to help. Perhaps the adam accumulators are contaminated by a non-finite gradient somewhere?
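If you want to test that hypothesis, one option might be to wrap adam so that updates are skipped whenever the incoming gradients are non-finite. This is only a sketch; it assumes the optimizer= argument accepts an arbitrary optax GradientTransformation (which it should, since adam(3e-4) is one):
import optax

# skip the update (and leave adam's moment accumulators untouched)
# whenever the gradients contain NaN or +/-inf
safe_adam = optax.apply_if_finite(optax.adam(3e-4), max_consecutive_errors=3)
ppo_clip = coax.policy_objectives.PPOClip(pi, regularizer=entropy, optimizer=safe_adam)
If the parameters stay finite with this wrapper, that would point at a single bad gradient poisoning the accumulators rather than a problem in the adam update itself.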