Skip to content

Commit

Permalink
Fix multidiscrete preprocessor (#22)
Browse files Browse the repository at this point in the history
* fix preprocessor for MultiDiscrete spaces

* fix deprecation warnings

* upgrade requirements

* v0.1.11

* add --quiet flag to colab install cell
  • Loading branch information
KristianHolsheimer authored Jul 22, 2022
1 parent 9005f70 commit b6a3a48
Show file tree
Hide file tree
Showing 85 changed files with 2,813 additions and 2,862 deletions.
10 changes: 5 additions & 5 deletions coax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.1.10'
__version__ = '0.1.11'


# expose specific classes and functions
Expand Down Expand Up @@ -73,17 +73,17 @@

import gym

if 'ConnectFour-v0' in gym.envs.registry.env_specs:
del gym.envs.registry.env_specs['ConnectFour-v0']
if 'ConnectFour-v0' in gym.envs.registry:
del gym.envs.registry['ConnectFour-v0']

gym.envs.register(
id='ConnectFour-v0',
entry_point='coax.envs:ConnectFourEnv',
)


if 'FrozenLakeNonSlippery-v0' in gym.envs.registry.env_specs:
del gym.envs.registry.env_specs['FrozenLakeNonSlippery-v0']
if 'FrozenLakeNonSlippery-v0' in gym.envs.registry:
del gym.envs.registry['FrozenLakeNonSlippery-v0']

gym.envs.register(
id='FrozenLakeNonSlippery-v0',
Expand Down
4 changes: 2 additions & 2 deletions coax/_base/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,13 @@ def assertArrayNotEqual(self, x, y, margin=margin):

def assertPytreeAlmostEqual(self, x, y, decimal=None):
decimal = decimal or self.decimal
jax.tree_multimap(
jax.tree_map(
lambda x, y: onp.testing.assert_array_almost_equal(
x, y, decimal=decimal), x, y)

def assertPytreeNotEqual(self, x, y, margin=None):
margin = margin or self.margin
reldiff = jax.tree_multimap(
reldiff = jax.tree_map(
lambda a, b: abs(2 * (a - b) / (a + b + 1e-16)), x, y)
maxdiff = max(jnp.max(d) for d in jax.tree_leaves(reldiff))
assert float(maxdiff) > margin
Expand Down
2 changes: 1 addition & 1 deletion coax/_core/base_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, func, observation_space, action_space=None, random_seed=None)
self._check_output(output, example_data.output)

def soft_update_func(old, new, tau):
return jax.tree_multimap(lambda a, b: (1 - tau) * a + tau * b, old, new)
return jax.tree_map(lambda a, b: (1 - tau) * a + tau * b, old, new)

self._soft_update_func = jit(soft_update_func)

Expand Down
6 changes: 3 additions & 3 deletions coax/_core/base_stochastic_func_type1.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
dist_params_type1 = jax.tree_map(
Expand Down Expand Up @@ -425,7 +425,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
Expand Down
4 changes: 2 additions & 2 deletions coax/_core/base_stochastic_func_type2.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# output
dist_params = jax.tree_map(
Expand Down Expand Up @@ -201,7 +201,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
4 changes: 2 additions & 2 deletions coax/_core/q.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
q1_data = ExampleData(
Expand Down
6 changes: 3 additions & 3 deletions coax/_core/transition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def example_data(
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

# input: actions
A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
A = [action_preprocessor(next(rngs), a) for a in A]
A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)
A = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *A)

# output: type1
S_next_type1 = jax.tree_map(lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])), S)
Expand Down Expand Up @@ -312,7 +312,7 @@ def _check_output(self, actual, expected):
f"found leaves of type: {bad_types}")

if not all(a.shape == b.shape for a, b in zip(actual_leaves, expected_leaves)):
shapes_tree = jax.tree_multimap(
shapes_tree = jax.tree_map(
lambda a, b: f"{a.shape} {'!=' if a.shape != b.shape else '=='} {b.shape}",
actual, expected)
raise TypeError(f"found leaves with unexpected shapes: {shapes_tree}")
2 changes: 1 addition & 1 deletion coax/_core/v.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def example_data(cls, env, observation_preprocessor=None, batch_size=1, random_s
# input: state observations
S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
S = [observation_preprocessor(next(rngs), s) for s in S]
S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)
S = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *S)

return ExampleData(
inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
Expand Down
2 changes: 1 addition & 1 deletion coax/experience_replay/_prioritized.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __iter__(self):


def _concatenate_leaves(pytrees):
return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)


@onp.vectorize
Expand Down
2 changes: 1 addition & 1 deletion coax/experience_replay/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def sample(self, batch_size=32):
random.setstate(self._random_state)
transitions = random.sample(self._storage, batch_size)
self._random_state = random.getstate()
return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)

def clear(self):
r""" Clear the experience replay buffer. """
Expand Down
6 changes: 6 additions & 0 deletions coax/proba_dists/_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def preprocess_variate(self, rng, X):
if self._structure_type == StructureType.LEAF:
return self._structure.preprocess_variate(next(rngs), X)

if isinstance(self.space, (gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
assert self._structure_type == StructureType.LIST
return [
dist.preprocess_variate(next(rngs), X[..., i])
for i, dist in enumerate(self._structure)]

if self._structure_type == StructureType.LIST:
return [
dist.preprocess_variate(next(rngs), X[i])
Expand Down
7 changes: 7 additions & 0 deletions coax/proba_dists/_composite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,10 @@ def test_prepostprocess_variate(self):
self.assertNotIn(X_raw['multidiscrete'][0], space['multidiscrete'])
self.assertIn(X_clean['multidiscrete'][0], space['multidiscrete'])
self.assertIn(x_clean['multidiscrete'], space['multidiscrete'])
# Check if bijective.
X_clean_ = dist.postprocess_variate(
next(self.rngs), dist.preprocess_variate(next(self.rngs), X_clean), batch_mode=True)
x_clean_ = dist.postprocess_variate(
next(self.rngs), dist.preprocess_variate(next(self.rngs), x_clean), batch_mode=False)
self.assertPytreeAlmostEqual(X_clean_, X_clean)
self.assertPytreeAlmostEqual(x_clean_, x_clean)
2 changes: 1 addition & 1 deletion coax/regularizers/_nstep_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def f(s_next):
self.f.observation_preprocessor(
next(rngs), s_next), True)
n_states = transition_batch.extra_info['states']
dist_params, _ = jax.vmap(f)(jax.tree_util.tree_multimap(
dist_params, _ = jax.vmap(f)(jax.tree_util.tree_map(
lambda *t: jnp.stack(t), *n_states))
dist_params = jax.tree_util.tree_map(
lambda t: jnp.take(t, self._n, axis=0), dist_params)
Expand Down
2 changes: 1 addition & 1 deletion coax/reward_tracing/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,4 @@ def flush(self):
while self:
transitions.append(self.pop())

return jax.tree_multimap(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
return jax.tree_map(lambda *leaves: onp.concatenate(leaves, axis=0), *transitions)
11 changes: 7 additions & 4 deletions coax/utils/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_leaves(a, b):
if jax.tree_structure(y) != jax.tree_structure(y0):
return False
try:
jax.tree_multimap(test_leaves, y, y0)
jax.tree_map(test_leaves, y, y0)
except AssertionError:
return False
return True
Expand Down Expand Up @@ -625,7 +625,7 @@ def get_transition_batch(env, batch_size=1, gamma=0.9, random_seed=None):
def batch_sample(space):
max_seed = onp.iinfo('int32').max
X = [safe_sample(space, seed=rnd.randint(max_seed)) for _ in range(batch_size)]
return jax.tree_multimap(lambda *leaves: onp.stack(leaves, axis=0), *X)
return jax.tree_map(lambda *leaves: onp.stack(leaves, axis=0), *X)

return TransitionBatch(
S=batch_sample(env.observation_space),
Expand Down Expand Up @@ -1059,16 +1059,19 @@ def _check_leaf_batch_size(pytree):

def stack_trees(*trees):
"""
Stack
Apply :func:`jnp.stack <jax.numpy.stack>` to the leaves of a pytree.
Parameters
----------
trees : sequence of pytrees with ndarray leaves
A typical example are pytrees containing the parameters and function states of
a model that should be used in a function which is vectorized by `jax.vmap`. The trees
have to have the same pytree structure.
Returns
-------
pytree : pytree with ndarray leaves
A tuple of pytrees.
"""
return jax.tree_util.tree_multimap(lambda *args: jnp.stack(args), *zip(*trees))
return jax.tree_util.tree_map(lambda *args: jnp.stack(args), *zip(*trees))
2 changes: 1 addition & 1 deletion coax/utils/_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_default_preprocessor(self):
self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][0], (1, 3))
self.assertArrayShape(default_preprocessor(dct)(next(rngs), dct.sample())['mds'][1], (1, 5))

mds_batch = jax.tree_multimap(lambda *x: jnp.stack(x), *(mds.sample() for _ in range(7)))
mds_batch = jax.tree_map(lambda *x: jnp.stack(x), *(mds.sample() for _ in range(7)))
self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[0], (7, 3))
self.assertArrayShape(default_preprocessor(mds)(next(rngs), mds_batch)[1], (7, 5))

Expand Down
Binary file modified doc/_intersphinx/haiku.inv
Binary file not shown.
Loading

0 comments on commit b6a3a48

Please sign in to comment.