Support for stateful policies (HumanCompatibleAI#709)
* Try adding state and mask to policies

* fix test

* fix test

* fix types

* dead comment

* add test

* pytype

* Update tutorial to have longer horizon

* Throw for exploration wrapper for stateful policies

* add test for throwing

* comments

* comments and lint

* comments

* comment
timbauman authored May 11, 2023
1 parent ebdb42b commit 8d1bc3f
Showing 7 changed files with 144 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/6_train_mce.ipynb
@@ -40,7 +40,7 @@
 "from imitation.data import rollout\n",
 "from imitation.rewards import reward_nets\n",
 "\n",
-"env_creator = partial(CliffWorldEnv, height=4, horizon=8, width=7, use_xy_obs=True)\n",
+"env_creator = partial(CliffWorldEnv, height=4, horizon=40, width=7, use_xy_obs=True)\n",
 "env_single = env_creator()\n",
 "\n",
 "state_env_creator = lambda: base_envs.ExposePOMDPStateWrapper(env_creator())\n",
1 change: 1 addition & 0 deletions mypy.ini
@@ -1,2 +1,3 @@
 [mypy]
 ignore_missing_imports = true
+exclude = output
41 changes: 30 additions & 11 deletions src/imitation/data/rollout.py
@@ -12,6 +12,7 @@
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Union,
 )

@@ -262,9 +263,13 @@ def sample_until(trajs: Sequence[types.TrajectoryWithRew]) -> bool:
     return sample_until


-# A PolicyCallable is a function that takes an array of observations
-# and returns an array of corresponding actions.
-PolicyCallable = Callable[[np.ndarray], np.ndarray]
+# A PolicyCallable is a function that takes an array of observations, an optional
+# tuple of state arrays, and an optional array of episode starts, and returns an
+# array of corresponding actions plus the (possibly updated) policy state.
+PolicyCallable = Callable[
+    [np.ndarray, Optional[Tuple[np.ndarray, ...]], Optional[np.ndarray]],
+    Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]],
+]
 AnyPolicy = Union[BaseAlgorithm, BasePolicy, PolicyCallable, None]
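
For orientation, here is a minimal sketch of a callable satisfying the new PolicyCallable signature. It mirrors the stateless test policies further down in this commit; the name zero_policy is illustrative, not part of the change:

    from typing import Optional, Tuple

    import numpy as np

    def zero_policy(
        observations: np.ndarray,
        states: Optional[Tuple[np.ndarray, ...]],
        episode_starts: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        # A stateless policy ignores both optional arguments and returns
        # None in place of an updated state.
        del states, episode_starts
        return np.zeros(len(observations), dtype=int), None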


@@ -274,26 +279,38 @@ def policy_to_callable(
     deterministic_policy: bool = False,
 ) -> PolicyCallable:
     """Converts any policy-like object into a function from observations to actions."""
+    get_actions: PolicyCallable
     if policy is None:

-        def get_actions(states):
-            acts = [venv.action_space.sample() for _ in range(len(states))]
-            return np.stack(acts, axis=0)
+        def get_actions(
+            observations: np.ndarray,
+            states: Optional[Tuple[np.ndarray, ...]],
+            episode_starts: Optional[np.ndarray],
+        ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+            acts = [venv.action_space.sample() for _ in range(len(observations))]
+            return np.stack(acts, axis=0), None

     elif isinstance(policy, (BaseAlgorithm, BasePolicy)):
         # There's an important subtlety here: BaseAlgorithm and BasePolicy
         # are themselves Callable (which we check next). But in their case,
         # we want to use the .predict() method, rather than __call__()
         # (which would call .forward()). So this elif clause must come first!

-        def get_actions(states):
+        def get_actions(
+            observations: np.ndarray,
+            states: Optional[Tuple[np.ndarray, ...]],
+            episode_starts: Optional[np.ndarray],
+        ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+            assert isinstance(policy, (BaseAlgorithm, BasePolicy))
             # pytype doesn't seem to understand that policy is a BaseAlgorithm
             # or BasePolicy here, rather than a Callable
-            acts, _ = policy.predict(  # pytype: disable=attribute-error
-                states,
+            (acts, states) = policy.predict(  # pytype: disable=attribute-error
+                observations,
+                state=states,
+                episode_start=episode_starts,
                 deterministic=deterministic_policy,
             )
-            return acts
+            return acts, states

     elif callable(policy):
         # When a policy callable is passed, by default we will use it directly.
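
A hedged usage sketch of the converted callable: the positional (policy, venv) argument order of policy_to_callable is assumed from context here, and the CartPole setup is illustrative only:

    import numpy as np
    from stable_baselines3 import PPO

    from imitation.data.rollout import policy_to_callable
    from imitation.util import util

    venv = util.make_vec_env("CartPole-v1", n_envs=2, rng=np.random.default_rng(0))
    model = PPO("MlpPolicy", venv)
    get_actions = policy_to_callable(model, venv, deterministic_policy=True)

    obs = venv.reset()
    episode_starts = np.ones(venv.num_envs, dtype=bool)
    acts, state = get_actions(obs, None, episode_starts)  # state is None for PPO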
@@ -402,8 +419,10 @@ def generate_trajectories(
     # To start with, all environments are active.
     active = np.ones(venv.num_envs, dtype=bool)
     assert isinstance(obs, np.ndarray), "Dict/tuple observations are not supported."
+    state = None
+    dones = np.zeros(venv.num_envs, dtype=bool)
     while np.any(active):
-        acts = get_actions(obs)
+        acts, state = get_actions(obs, state, dones)
         obs, rews, dones, infos = venv.step(acts)
         assert isinstance(obs, np.ndarray)
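
To make the state-threading contract concrete, here is a small self-contained sketch of a stateful PolicyCallable; counting_policy and its Discrete(3)-style action rule are invented for illustration and do not appear in the commit:

    from typing import Optional, Tuple

    import numpy as np

    def counting_policy(
        observations: np.ndarray,
        states: Optional[Tuple[np.ndarray, ...]],
        episode_starts: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        # The recurrent "state" is a per-environment step counter.
        n = len(observations)
        counts = states[0] if states is not None else np.zeros(n, dtype=int)
        if episode_starts is not None:
            # Reset the counter wherever a new episode has just begun,
            # which generate_trajectories signals via `dones`.
            counts = np.where(episode_starts, 0, counts)
        acts = counts % 3  # pretend the action space is Discrete(3)
        return acts, (counts + 1,)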
34 changes: 29 additions & 5 deletions src/imitation/policies/exploration_wrapper.py
@@ -1,5 +1,7 @@
 """Wrapper to turn a policy into a more exploratory version."""

+from typing import Optional, Tuple
+
 import numpy as np
 from stable_baselines3.common import vec_env

@@ -53,9 +55,15 @@ def __init__(
         # Choose the initial policy at random
         self._switch()

-    def _random_policy(self, obs: np.ndarray) -> np.ndarray:
+    def _random_policy(
+        self,
+        obs: np.ndarray,
+        state: Optional[Tuple[np.ndarray, ...]],
+        episode_start: Optional[np.ndarray],
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        del state, episode_start  # Unused
         acts = [self.venv.action_space.sample() for _ in range(len(obs))]
-        return np.stack(acts, axis=0)
+        return np.stack(acts, axis=0), None

     def _switch(self) -> None:
         """Pick a new policy at random."""
@@ -64,8 +72,24 @@ def _switch(self) -> None:
         else:
             self.current_policy = self.wrapped_policy

-    def __call__(self, obs: np.ndarray) -> np.ndarray:
-        acts = self.current_policy(obs)
+    def __call__(
+        self,
+        observation: np.ndarray,
+        input_state: Optional[Tuple[np.ndarray, ...]],
+        episode_start: Optional[np.ndarray],
+    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
+        del episode_start  # Unused
+
+        if input_state is not None:
+            # This checks that we aren't passed a state
+            raise ValueError("Exploration wrapper does not support stateful policies.")
+
+        acts, output_state = self.current_policy(observation, None, None)
+
+        if output_state is not None:
+            # This checks that the policy doesn't return a state
+            raise ValueError("Exploration wrapper does not support stateful policies.")
+
         if self.rng.random() < self.switch_prob:
             self._switch()
-        return acts
+        return acts, None
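
A hedged usage sketch of the wrapper after this change: it is itself a PolicyCallable, so callers pass the full (obs, state, episode_start) triple and always get None back as the state. The venv construction and probability values here are illustrative:

    import numpy as np

    from imitation.policies import exploration_wrapper
    from imitation.util import util

    venv = util.make_vec_env("CartPole-v1", n_envs=2, rng=np.random.default_rng(0))

    def constant_policy(obs, state, mask):
        del state, mask  # Unused
        return np.zeros(len(obs), dtype=int), None

    wrapper = exploration_wrapper.ExplorationWrapper(
        policy=constant_policy,
        venv=venv,
        random_prob=0.1,
        switch_prob=0.01,
        rng=np.random.default_rng(0),
    )
    acts, state = wrapper(venv.reset(), None, None)
    assert state is None  # the wrapper never carries state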
40 changes: 40 additions & 0 deletions tests/algorithms/test_mce_irl.py
@@ -314,6 +314,46 @@ def test_tabular_policy(rng):
     np.testing.assert_equal(timesteps[0], 2 - mask.astype(int))


+def test_tabular_policy_rollouts(rng):
+    """Tests that rolling out a tabular policy that varies at each timestep works."""
+    state_space = gym.spaces.Discrete(5)
+    action_space = gym.spaces.Discrete(3)
+    mdp = ReasonablePOMDP()
+    state_env = base_envs.ExposePOMDPStateWrapper(mdp)
+    state_venv = vec_env.DummyVecEnv([lambda: state_env])
+
+    # pi[t, s, a]: a deterministic subpolicy that alternates the action each timestep
+    subpolicy = np.stack([np.eye(action_space.n)] * state_space.n, axis=1)
+
+    # repeat 7 times for a total of 21 timesteps (greater than the horizon of 20)
+    pi = np.repeat(
+        subpolicy,
+        ((mdp.horizon + action_space.n - 1) // action_space.n),
+        axis=0,
+    )
+
+    tabular = TabularPolicy(
+        state_space=state_space,
+        action_space=action_space,
+        pi=pi,
+        rng=rng,
+    )
+
+    trajs = rollout.generate_trajectories(
+        tabular,
+        state_venv,
+        sample_until=rollout.make_min_episodes(1),
+        rng=rng,
+    )
+
+    # pi[t, s, a] is the same for every state, so drop the state dimension
+    exposed_actions_onehot = pi[:, 0, :]
+    exposed_actions = exposed_actions_onehot.nonzero()[1]
+
+    # check that the trajectory chooses the same actions as the policy
+    assert (trajs[0].acts == exposed_actions[: len(trajs[0].acts)]).all()
+
+
 def test_tabular_policy_randomness(rng):
     state_space = gym.spaces.Discrete(2)
     action_space = gym.spaces.Discrete(2)
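
As an aside on the new test above, the array shapes work out as follows (a worked check under the test's constants of 5 states, 3 actions, and a horizon of 20):

    import numpy as np

    subpolicy = np.stack([np.eye(3)] * 5, axis=1)         # (3, 5, 3): pi[t, s, a]
    pi = np.repeat(subpolicy, (20 + 3 - 1) // 3, axis=0)  # ceil(20 / 3) == 7 repeats
    assert pi.shape == (21, 5, 3)                         # 21 timesteps > horizon of 20
    # Each timestep's row is one-hot, so nonzero()[1] recovers the action sequence.
    actions = pi[:, 0, :].nonzero()[1]
    assert actions.shape == (21,)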
4 changes: 2 additions & 2 deletions tests/data/test_rollout.py
@@ -57,8 +57,8 @@ def _sample_fixed_length_trajectories(

         # Simple way to get a valid callable: just use a policy's .predict() method
         # (still tests another code path inside generate_trajectories)
-        def policy(x):
-            return random_policy.predict(x)[0]
+        def policy(x, state, mask):
+            return random_policy.predict(x, state, mask)

     elif policy_type == "random":
         policy = None
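
This shim works because SB3's predict() already follows the same (observations, state, episode_start) -> (actions, state) contract the new PolicyCallable expects; a minimal illustration, assuming stable-baselines3 is installed:

    import numpy as np
    from stable_baselines3 import PPO

    model = PPO("MlpPolicy", "CartPole-v1")
    obs = np.random.rand(4, 4).astype(np.float32)  # a batch of 4 observations
    acts, state = model.predict(obs, state=None, episode_start=None)
    assert acts.shape == (4,) and state is None  # feed-forward policies carry no state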
46 changes: 41 additions & 5 deletions tests/policies/test_exploration_wrapper.py
@@ -1,14 +1,21 @@
 """Tests ExplorationWrapper."""

 import numpy as np
+import pytest
 import seals  # noqa: F401

 from imitation.policies import exploration_wrapper
 from imitation.util import util


-def constant_policy(obs):
-    return np.zeros(len(obs), dtype=int)
+def constant_policy(obs, state, mask):
+    del state, mask  # Unused
+    return np.zeros(len(obs), dtype=int), None
+
+
+def fake_stateful_policy(obs, state, mask):
+    del state, mask  # Unused
+    return np.zeros(len(obs), dtype=int), (np.zeros(1),)


 def make_wrapper(random_prob, switch_prob, rng):
def make_wrapper(random_prob, switch_prob, rng):
Expand Down Expand Up @@ -92,7 +99,7 @@ def test_switch_prob(rng):
policy = wrapper.current_policy

obs = np.random.rand(100, 2)
for action in wrapper(obs):
for action in wrapper(obs, None, None)[0]:
assert venv.action_space.contains(action)
assert wrapper.current_policy == policy

@@ -102,7 +109,7 @@ def _always_switch(random_prob, num_steps):
     num_constant = 0
     for _ in range(num_steps):
         obs = np.random.rand(1, 2)
-        wrapper(obs)
+        wrapper(obs, None, None)
         if wrapper.current_policy == wrapper._random_policy:
             num_random += 1
         elif wrapper.current_policy == constant_policy:
@@ -137,5 +144,34 @@ def test_valid_output(rng):
     wrapper, venv = make_wrapper(random_prob=random_prob, switch_prob=0.5, rng=rng)
     np.random.seed(0)
     obs = np.random.rand(100, 2)
-    for action in wrapper(obs):
+    for action in wrapper(obs, None, None)[0]:
         assert venv.action_space.contains(action)
+
+
+def test_throws_for_stateful_policy(rng):
+    venv = util.make_vec_env(
+        "seals/CartPole-v0",
+        n_envs=1,
+        rng=rng,
+    )
+    wrapper = exploration_wrapper.ExplorationWrapper(
+        policy=fake_stateful_policy,
+        venv=venv,
+        random_prob=0,
+        switch_prob=0,
+        rng=rng,
+    )
+
+    np.random.seed(0)
+    obs = np.random.rand(100, 2)
+    with pytest.raises(
+        ValueError,
+        match="Exploration wrapper does not support stateful policies.",
+    ):
+        wrapper(obs, (np.ones_like(obs),), None)
+
+    with pytest.raises(
+        ValueError,
+        match="Exploration wrapper does not support stateful policies.",
+    ):
+        wrapper(obs, None, None)
