Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rescale observation wrapper #1940

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gym.wrappers.filter_observation import FilterObservation
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.rescale_observation import RescaleObservation
from gym.wrappers.flatten_observation import FlattenObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import LazyFrames
Expand Down
138 changes: 138 additions & 0 deletions gym/wrappers/rescale_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np

import gym
from gym import spaces


def rescale_values(values, old_low, old_high, new_low, new_high):
    """Linearly map ``values`` from [old_low, old_high] onto [new_low, new_high].

    All arguments may be scalars or numpy arrays; they are combined using
    numpy broadcasting rules.

    Args:
        values: Value(s) expressed in the old range.
        old_low: Lower bound(s) of the old range.
        old_high: Upper bound(s) of the old range.
        new_low: Lower bound(s) of the new range.
        new_high: Upper bound(s) of the new range.

    Returns:
        The rescaled value(s), clipped to [new_low, new_high].
    """
    old_span = np.asarray(old_high - old_low)
    # A degenerate dimension (old_low == old_high) would divide by zero and
    # produce NaN/inf that survives the clip below. Dividing by one instead
    # maps a value equal to ``old_low`` onto ``new_low``.
    safe_span = np.where(old_span == 0, np.ones_like(old_span), old_span)

    fraction = (values - old_low) / safe_span
    rescaled_values = new_low + (new_high - new_low) * fraction
    return np.clip(rescaled_values, new_low, new_high)


def verify_observation_space_type(observation_space):
    """Raise ``TypeError`` unless ``observation_space`` is a ``spaces.Box``."""
    if isinstance(observation_space, spaces.Box):
        return

    space_type = type(observation_space)
    raise TypeError(
        "Expected Box observation space. Got: {}".format(space_type))


def verify_observation_space_bounds(observation_space):
    """Raise ``ValueError`` if any ``low``/``high`` bound of the space is not finite.

    Both infinities and NaN count as non-finite.
    """
    bounds = np.asarray((observation_space.low, observation_space.high))
    if np.all(np.isfinite(bounds)):
        return

    raise ValueError(
        "Observation space 'low' and 'high' need to be finite."
        " Got: observation_space.low={}, observation_space.high={}"
        "".format(observation_space.low, observation_space.high))


def rescale_box_space(observation_space, low, high):
    """Build a Box shaped and typed like ``observation_space`` with new bounds.

    ``low`` and ``high`` may be scalars or arrays; they are broadcast against
    the shape of ``observation_space``.
    """
    box_shape, box_dtype = observation_space.shape, observation_space.dtype

    # Adding each bound to a zero array broadcasts it to the full space shape.
    template = np.zeros(box_shape, dtype=box_dtype)

    return spaces.Box(
        low=low + template,
        high=high + template,
        shape=box_shape,
        dtype=box_dtype)


class RescaleObservation(gym.ObservationWrapper):
    """Observation wrapper that linearly rescales observations to new bounds."""

    def __init__(self, env, low, high):
        r"""Rescale observation space to a range [`low`, `high`].

        For `Box` spaces, `low` and `high` can be either scalar or vector, and
        will be broadcasted according to numpy broadcasting rules. For `Tuple`
        and `Dict` spaces, both `low` and `high` are expected to be scalar,
        and every element of the `Tuple` or `Dict` space is expected to be a
        `Box`.

        Example:
            >>> RescaleObservation(env, low, high).observation_space == Box(low, high)
            True

        Raises:
            TypeError: If `not isinstance(environment.observation_space, (Box, Tuple, Dict))`,
                or if any element of a `Tuple` or `Dict` space is not a `Box`.
            ValueError: If either `low` or `high` is not finite.
            ValueError: If any of `observation_space.{low,high}` is not finite.
            ValueError: If `high <= low`.
            ValueError: If observation space is `Tuple` or `Dict` and either
                `low` or `high` is not scalar.
        """
        # Check each bound separately: stacking a scalar and a vector into a
        # single list builds a ragged sequence, which recent numpy versions
        # refuse to convert to an array.
        if not (np.all(np.isfinite(low)) and np.all(np.isfinite(high))):
            raise ValueError(
                "Arguments 'low' and 'high' need to be finite."
                " Got: low={}, high={}".format(low, high))

        if np.any(high <= low):
            raise ValueError(
                "Argument `low` must be smaller than `high`."
                " Got: low={}, high={}".format(low, high))

        source_space = env.observation_space
        is_composite = isinstance(source_space, (spaces.Tuple, spaces.Dict))
        if is_composite and not (np.isscalar(low) and np.isscalar(high)):
            raise ValueError(
                "Arguments 'low' and 'high' need to be scalars for {} spaces."
                " Got: low={}, high={}".format(
                    type(source_space), low, high))

        super(RescaleObservation, self).__init__(env)

        if isinstance(source_space, spaces.Box):
            self.observation_space = self._rescaled_box(
                source_space, low, high)
        elif isinstance(source_space, spaces.Tuple):
            self.observation_space = spaces.Tuple([
                self._rescaled_box(element_space, low, high)
                for element_space in source_space.spaces
            ])
        elif isinstance(source_space, spaces.Dict):
            self.observation_space = spaces.Dict({
                name: self._rescaled_box(element_space, low, high)
                for name, element_space in source_space.spaces.items()
            })
        else:
            raise TypeError("Unsupported observation space type: {}"
                            "".format(type(source_space)))

    @staticmethod
    def _rescaled_box(box_space, low, high):
        """Validate one Box space and return its rescaled counterpart."""
        verify_observation_space_type(box_space)
        verify_observation_space_bounds(box_space)
        return rescale_box_space(box_space, low, high)

    @staticmethod
    def _rescale_between(value, old_box, new_box):
        """Map ``value`` from the bounds of ``old_box`` to those of ``new_box``."""
        return rescale_values(
            value,
            old_low=old_box.low,
            old_high=old_box.high,
            new_low=new_box.low,
            new_high=new_box.high)

    def observation(self, observation):
        """Rescale an observation of the wrapped environment.

        Args:
            observation: Observation from the wrapped environment. For `Tuple`
                spaces it is iterated in order; for `Dict` spaces it must
                provide `.items()`.

        Returns:
            The rescaled observation, clipped to the new bounds. For `Tuple`
            and `Dict` spaces the result is rebuilt with the same container
            type as the incoming `observation`.

        Raises:
            TypeError: If the wrapper's observation space is not a `Box`,
                `Tuple` or `Dict`.
        """
        old_space = self.env.observation_space
        new_space = self.observation_space

        if isinstance(new_space, spaces.Box):
            return self._rescale_between(observation, old_space, new_space)

        if isinstance(new_space, spaces.Tuple):
            return type(observation)(
                self._rescale_between(value, old_space[i], new_space[i])
                for i, value in enumerate(observation))

        if isinstance(new_space, spaces.Dict):
            return type(observation)(
                (key, self._rescale_between(
                    value, old_space[key], new_space[key]))
                for key, value in observation.items())

        raise TypeError("Unsupported observation space type: {}"
                        "".format(type(self.env.observation_space)))
251 changes: 251 additions & 0 deletions gym/wrappers/test_rescale_observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import pytest

import numpy as np

import gym
from gym import spaces
from gym.wrappers import RescaleObservation


# Shared two-dimensional float32 Box with asymmetric, non-unit bounds. The
# tests below use it, alone or nested inside Tuple/Dict spaces, as the
# unscaled observation space handed to the fake environment.
UNSCALED_BOX_SPACE = spaces.Box(
shape=(2, ),
low=np.array((-1.2, -0.07)),
high=np.array((0.6, 0.07)),
dtype=np.float32)


class FakeEnvironment(gym.Env):
    """Minimal environment stub that echoes each action back as the observation.

    The action space is the very same object as the observation space, so any
    observation can be passed to ``step`` and read back unchanged.
    """

    def __init__(self, observation_space):
        """Store ``observation_space`` as both observation and action space."""
        self.observation_space = observation_space
        self.action_space = observation_space

    def reset(self):
        """Return a sample drawn from the observation space."""
        return self.observation_space.sample()

    def step(self, action):
        """Return ``action`` as the observation, with zero reward and no end."""
        return action, 0.0, False, {}


@pytest.mark.parametrize("observation_space", [
    UNSCALED_BOX_SPACE,
    spaces.Tuple((UNSCALED_BOX_SPACE, UNSCALED_BOX_SPACE)),
    spaces.Dict({'box-1': UNSCALED_BOX_SPACE, 'box-2': UNSCALED_BOX_SPACE}),
])
def test_rescale_observation(observation_space):
    """Wrapped bounds equal the requested range and extremes map onto them.

    For each supported space type, the original space's `low` (then `high`)
    is fed through the echoing fake environment. The unwrapped observation
    must equal the original bound and the wrapped one the new bound.
    """
    new_low, new_high = -1.0, 1.0
    env = FakeEnvironment(observation_space)
    wrapped_env = RescaleObservation(env, new_low, new_high)

    source_space = env.observation_space
    wrapped_space = wrapped_env.observation_space

    def verify_space_bounds(box_space):
        np.testing.assert_allclose(box_space.low, new_low)
        np.testing.assert_allclose(box_space.high, new_high)

    if isinstance(wrapped_space, spaces.Box):
        verify_space_bounds(wrapped_space)
    elif isinstance(wrapped_space, spaces.Tuple):
        for box_space in wrapped_space.spaces:
            verify_space_bounds(box_space)
    elif isinstance(wrapped_space, spaces.Dict):
        for box_space in wrapped_space.spaces.values():
            verify_space_bounds(box_space)
    else:
        raise ValueError

    seed = 0
    env.seed(seed)
    wrapped_env.seed(seed)

    env.reset()
    wrapped_env.reset()

    # The low and high checks are identical apart from the bound attribute,
    # so run them through one loop instead of two copy-pasted halves.
    for bound in ("low", "high"):
        if isinstance(wrapped_space, spaces.Box):
            action = getattr(source_space, bound)

            observation = env.step(action)[0]
            wrapped_observation = wrapped_env.step(action)[0]

            assert np.allclose(observation, getattr(source_space, bound))
            assert np.allclose(
                wrapped_observation, getattr(wrapped_space, bound))

        elif isinstance(wrapped_space, spaces.Tuple):
            action = type(source_space.spaces)(
                getattr(box_space, bound)
                for box_space in source_space.spaces)

            observation = env.step(action)[0]
            wrapped_observation = wrapped_env.step(action)[0]

            assert np.allclose(
                observation,
                [getattr(o, bound) for o in source_space.spaces])
            assert np.allclose(
                wrapped_observation,
                [getattr(o, bound) for o in wrapped_space.spaces])

        elif isinstance(wrapped_space, spaces.Dict):
            action = type(source_space.spaces)(
                (key, getattr(box_space, bound))
                for key, box_space in source_space.spaces.items())

            observation = env.step(action)[0]
            wrapped_observation = wrapped_env.step(action)[0]

            assert (set(source_space.spaces.keys())
                    == set(observation.keys()))
            # Compare against the *wrapped* observation's keys so that the
            # wrapper's output is what actually gets checked.
            assert (set(wrapped_space.spaces.keys())
                    == set(wrapped_observation.keys()))
            for key in source_space.spaces.keys():
                np.testing.assert_allclose(
                    observation[key], getattr(source_space[key], bound))
                np.testing.assert_allclose(
                    wrapped_observation[key],
                    getattr(wrapped_space[key], bound))

        else:
            raise ValueError


@pytest.mark.parametrize("observation_space", [
    spaces.Tuple((UNSCALED_BOX_SPACE, UNSCALED_BOX_SPACE)),
    spaces.Dict({'box-1': UNSCALED_BOX_SPACE, 'box-2': UNSCALED_BOX_SPACE}),
])
def test_raises_non_scalar_low_high(observation_space):
    """Vector-valued `low` or `high` is rejected for Tuple and Dict spaces."""
    env = FakeEnvironment(observation_space)
    supported_types = (spaces.Box, spaces.Tuple, spaces.Dict)
    assert isinstance(env.observation_space, supported_types)

    vector_bound_arguments = (
        (-1.0, np.array([1.0, 1.0])),
        (np.array([-1.0, -1.0]), 1.0),
    )
    for low, high in vector_bound_arguments:
        with pytest.raises(ValueError):
            RescaleObservation(env, low, high)


@pytest.mark.parametrize("observation_space", [
    UNSCALED_BOX_SPACE,
    spaces.Tuple((UNSCALED_BOX_SPACE, UNSCALED_BOX_SPACE)),
    spaces.Dict({'box-1': UNSCALED_BOX_SPACE, 'box-2': UNSCALED_BOX_SPACE}),
])
def test_raises_on_non_finite_low(observation_space):
    """Infinite or NaN `low`/`high` arguments are rejected with ValueError."""
    env = FakeEnvironment(observation_space)
    supported_types = (spaces.Box, spaces.Tuple, spaces.Dict)
    assert isinstance(env.observation_space, supported_types)

    non_finite_arguments = (
        (-float('inf'), 1.0),
        (-1.0, float('inf')),
        (-1.0, np.nan),
    )
    for low, high in non_finite_arguments:
        with pytest.raises(ValueError):
            RescaleObservation(env, low, high)


@pytest.mark.parametrize("observation_space", [
    UNSCALED_BOX_SPACE,
    spaces.Tuple((UNSCALED_BOX_SPACE, UNSCALED_BOX_SPACE)),
    spaces.Dict({'box-1': UNSCALED_BOX_SPACE, 'box-2': UNSCALED_BOX_SPACE}),
])
def test_raises_on_high_less_than_low(observation_space):
    """A `high` that is equal to or below `low` is rejected with ValueError."""
    env = FakeEnvironment(observation_space)
    supported_types = (spaces.Box, spaces.Tuple, spaces.Dict)
    assert isinstance(env.observation_space, supported_types)

    for low, high in ((1.0, 1.0), (1.0, -1.0)):
        with pytest.raises(ValueError):
            RescaleObservation(env, low, high)


@pytest.mark.parametrize("observation_space", [
    UNSCALED_BOX_SPACE,
    spaces.Tuple((UNSCALED_BOX_SPACE, UNSCALED_BOX_SPACE)),
    spaces.Dict({'box-1': UNSCALED_BOX_SPACE, 'box-2': UNSCALED_BOX_SPACE}),
])
def test_raises_on_high_equals_low(observation_space):
    """An empty target range (`low == high`) is rejected with ValueError."""
    env = FakeEnvironment(observation_space)
    supported_types = (spaces.Box, spaces.Tuple, spaces.Dict)
    assert isinstance(env.observation_space, supported_types)

    low = high = 1.0
    with pytest.raises(ValueError):
        RescaleObservation(env, low, high)


@pytest.mark.parametrize("observation_space", [
    spaces.Discrete(10),
    spaces.Tuple((spaces.Discrete(5), spaces.Discrete(10))),
    spaces.Tuple((
        spaces.Discrete(5),
        spaces.Box(low=np.array((0.0, 0.0)), high=np.array((1.0, 1.0))))),
    spaces.Dict({
        'discrete-5': spaces.Discrete(5),
        'discrete-10': spaces.Discrete(10),
    }),
    spaces.Dict({
        'discrete': spaces.Discrete(5),
        'box': spaces.Box(low=np.array((0.0, 0.0)), high=np.array((1.0, 1.0))),
    }),
])
def test_raises_on_non_box_space(observation_space):
    """Spaces that are not, or contain something other than, Box are rejected."""
    fake_env = FakeEnvironment(observation_space)
    # Callable form of pytest.raises: invokes RescaleObservation(fake_env, ...).
    pytest.raises(TypeError, RescaleObservation, fake_env, -1.0, 1.0)


@pytest.mark.parametrize("observation_space", [
    spaces.Box(low=np.array((0.0, 0.0)), high=np.array((1.0, float('inf')))),
    spaces.Box(low=np.array((0.0, -float('inf'))), high=np.array((1.0, 1.0))),
    spaces.Tuple((
        spaces.Box(low=np.array((0.0, -1.0)), high=np.array((1.0, 1.0))),
        spaces.Box(low=np.array((0.0, -1.0)), high=np.array((1.0, float('inf')))),
    )),
    spaces.Dict({
        'box-1': spaces.Box(low=np.array((0.0, -1.0)), high=np.array((1.0, 1.0))),
        'box-2': spaces.Box(low=np.array((0.0, -float('inf'))), high=np.array((1.0, 1.0))),
    }),
])
def test_raises_on_non_finite_space(observation_space):
    """Source spaces with an infinite bound are rejected with ValueError."""
    fake_env = FakeEnvironment(observation_space)
    # Callable form of pytest.raises: invokes RescaleObservation(fake_env, ...).
    pytest.raises(ValueError, RescaleObservation, fake_env, -1.0, 1.0)