Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multidiscrete preprocessor #22

Merged
merged 5 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions coax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.10'
__version__ = '0.1.11'


# expose specific classes and functions
Expand Down Expand Up @@ -73,17 +73,17 @@

import gym

if 'ConnectFour-v0' in gym.envs.registry.env_specs:
del gym.envs.registry.env_specs['ConnectFour-v0']
if 'ConnectFour-v0' in gym.envs.registry:
del gym.envs.registry['ConnectFour-v0']

gym.envs.register(
id='ConnectFour-v0',
entry_point='coax.envs:ConnectFourEnv',
)


if 'FrozenLakeNonSlippery-v0' in gym.envs.registry.env_specs:
del gym.envs.registry.env_specs['FrozenLakeNonSlippery-v0']
if 'FrozenLakeNonSlippery-v0' in gym.envs.registry:
del gym.envs.registry['FrozenLakeNonSlippery-v0']

gym.envs.register(
id='FrozenLakeNonSlippery-v0',
Expand Down
4 changes: 2 additions & 2 deletions coax/_base/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,13 @@ def assertArrayNotEqual(self, x, y, margin=margin):

def assertPytreeAlmostEqual(self, x, y, decimal=None):
decimal = decimal or self.decimal
jax.tree_multimap(
jax.tree_map(
lambda x, y: onp.testing.assert_array_almost_equal(
x, y, decimal=decimal), x, y)

def assertPytreeNotEqual(self, x, y, margin=None):
margin = margin or self.margin
reldiff = jax.tree_multimap(
reldiff = jax.tree_map(
lambda a, b: abs(2 * (a - b) / (a + b + 1e-16)), x, y)
maxdiff = max(jnp.max(d) for d in jax.tree_leaves(reldiff))
assert float(maxdiff) > margin
Expand Down
2 changes: 1 addition & 1 deletion coax/_core/base_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, func, observation_space, action_space=None, random_seed=None)
self._check_output(output, example_data.output)

def soft_update_func(old, new, tau):
return jax.tree_multimap(lambda a, b: (1 - tau) * a + tau * b, old, new)
return jax.tree_map(lambda a, b: (1 - tau) * a + tau * b, old, new)

self._soft_update_func = jit(soft_update_func)

Expand Down
6 changes: 3 additions & 3 deletions coax/_core/base_stochastic_func_type1.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
dist_params_type1 = jax.tree_map(
Expand Down Expand Up @@ -425,7 +425,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
Expand Down
4 changes: 2 additions & 2 deletions coax/_core/base_stochastic_func_type2.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# output
dist_params = jax.tree_map(
Expand Down Expand Up @@ -201,7 +201,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
4 changes: 2 additions & 2 deletions coax/_core/q.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
q1_data = ExampleData(
Expand Down
6 changes: 3 additions & 3 deletions coax/_core/transition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
S_next_type1 = jax.tree_map(lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])), S)
Expand Down Expand Up @@ -312,7 +312,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
2 changes: 1 addition & 1 deletion coax/_core/v.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def example_data(cls, env, observation_preprocessor=None, batch_size=1, random_s
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

return ExampleData(
inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
Expand Down
2 changes: 1 addition & 1 deletion coax/experience_replay/_prioritized.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __iter__(self):


def _concatenate_leaves(pytrees):
return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)


@onp.vectorize
Expand Down
2 changes: 1 addition & 1 deletion coax/experience_replay/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def sample(self, batch_size=32):
random.setstate(self._random_state)
transitions = random.sample(self._storage, batch_size)
self._random_state = random.getstate()
return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)

def clear(self):
r""" Clear the experience replay buffer. """
Expand Down
6 changes: 6 additions & 0 deletions coax/proba_dists/_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def preprocess_variate(self, rng, X):
if self._structure_type == StructureType.LEAF:
return self._structure.preprocess_variate(next(rngs), X)

if isinstance(self.space, (gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
assert self._structure_type == StructureType.LIST
return [
dist.preprocess_variate(next(rngs), X[..., i])
for i, dist in enumerate(self._structure)]

if self._structure_type == StructureType.LIST:
return [
dist.preprocess_variate(next(rngs), X[i])
Expand Down
7 changes: 7 additions & 0 deletions coax/proba_dists/_composite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,10 @@ def test_prepostprocess_variate(self):
self.assertNotIn(X_raw['multidiscrete'][0], space['multidiscrete'])
self.assertIn(X_clean['multidiscrete'][0], space['multidiscrete'])
self.assertIn(x_clean['multidiscrete'], space['multidiscrete'])
# Check if bijective.
X_clean_ = dist.postprocess_variate(
next(self.rngs), dist.preprocess_variate(next(self.rngs), X_clean), batch_mode=True)
x_clean_ = dist.postprocess_variate(
next(self.rngs), dist.preprocess_variate(next(self.rngs), x_clean), batch_mode=False)
self.assertPytreeAlmostEqual(X_clean_, X_clean)
self.assertPytreeAlmostEqual(x_clean_, x_clean)
2 changes: 1 addition & 1 deletion coax/regularizers/_nstep_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def f(s_next):
self.f.observation_preprocessor(
next(rngs), s_next), True)
n_states = transition_batch.extra_info['states']
dist_params, _ = jax.vmap(f)(jax.tree_util.tree_multimap(
dist_params, _ = jax.vmap(f)(jax.tree_util.tree_map(
lambda *t: jnp.stack(t), *n_states))
dist_params = jax.tree_util.tree_map(
lambda t: jnp.take(t, self._n, axis=0), dist_params)
Expand Down
2 changes: 1 addition & 1 deletion coax/reward_tracing/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def flush(self):
while self:
transitions.append(self.pop())

return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
11 changes: 7 additions & 4 deletions coax/utils/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_leaves(a, b):
if jax.tree_structure(y) != jax.tree_structure(y0):
return False
try:
jax.tree_multimap(test_leaves, y, y0)
jax.tree_map(test_leaves, y, y0)
except AssertionError:
return False
return True
Expand Down Expand Up @@ -625,7 +625,7 @@ def get_transition_batch(env, batch_size=1, gamma=0.9, random_seed=None):
def batch_sample(space):
max_seed = onp.iinfo('int32').max
X = [safe_sample(space, seed=rnd.randint(max_seed)) for _ in range(batch_size)]
return jax.tree_multimap(lambda *leaves: onp.stack(leaves, axis=0), *X)
return jax.tree_map(lambda *leaves: onp.stack(leaves, axis=0), *X)

return TransitionBatch(
S=batch_sample(env.observation_space),
Expand Down Expand Up @@ -1059,16 +1059,19 @@ def _check_leaf_batch_size(pytree):

def stack_trees(*trees):
"""
Stack
Apply :func:`jnp.stack <jax.numpy.stack>` to the leaves of a pytree.

Parameters
----------
trees : sequence of pytrees with ndarray leaves
A typical example is a sequence of pytrees containing the parameters and function states of
a model that should be used in a function which is vectorized by `jax.vmap`. The trees
must all have the same pytree structure.

Returns
-------
pytree : pytree with ndarray leaves
A tuple of pytrees.

"""
return jax.tree_util.tree_multimap(lambda *args: jnp.stack(args), *zip(*trees))
return jax.tree_util.tree_map(lambda *args: jnp.stack(args), *zip(*trees))
2 changes: 1 addition & 1 deletion coax/utils/_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_default_preprocessor(self):
self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][0], (1, 3))
self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][1], (1, 5))

mds_batch = jax.tree_multimap(lambda *x: jnp.stack(x), *(mds.sample() for _ in range(7)))
mds_batch = jax.tree_map(lambda *x: jnp.stack(x), *(mds.sample() for _ in range(7)))
self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[0], (7, 3))
self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[1], (7, 5))

Expand Down
Binary file modified doc/_intersphinx/haiku.inv
Binary file not shown.
Loading