-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqdax_brax_task.py
143 lines (113 loc) · 4.44 KB
/
qdax_brax_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import jax
import brax.v1.envs
from qdax.core.neuroevolution.buffers import buffer
import qdax.environments
import qdax.tasks.brax_envs
from qdax.types import (
Genotype, Observation, Action, Fitness, Descriptor, StateDescriptor, ExtraScores
)
import numpy as np
import gymnasium
import logging
from functools import partial
from collections.abc import Callable
from typing import Optional, TypeAlias, Any, cast
from overrides import override
from .rl_task import RLTask, scoring_function_jnp_envs
from ..config.task import QDaxBraxConfig
from ..utils import RNGKey, assert_cast, jax_jit
# Logger for this module, keyed by the module's import path.
_log = logging.getLogger(name=__name__)

# Name used throughout this module for the state object of a Brax v1
# environment (what the wrapped env's reset/step produce here).
QDaxBraxEnvState: TypeAlias = brax.v1.envs.State
class QDaxBraxTask(
    RLTask[QDaxBraxConfig, qdax.environments.QDEnv, QDaxBraxEnvState]
):
    """RL task backed by a QDax-wrapped Brax (v1) environment.

    The environment is created with ``qdax.environments.create`` from the
    config's ``name`` and ``episode_len``. Most properties and methods below
    simply forward to that environment. ``get_scoring_fn`` packages a
    one-step rollout function for ``scoring_function_jnp_envs``.
    """

    def __init__(self, cfg: QDaxBraxConfig, batch_shape: tuple[int, ...]) -> None:
        super().__init__(cfg, batch_shape)
        # ``self._cfg`` is read here; presumably the base ``__init__`` stores
        # ``cfg`` under that name -- confirm in RLTask.
        # ``cast`` only informs the type checker; no runtime conversion happens.
        self._env = cast(
            qdax.environments.QDEnv,
            qdax.environments.create(
                self._cfg.name,
                episode_length=self._cfg.episode_len,
            ),
        )

    @property
    @override
    def behavior_descriptor_length(self) -> int:
        """Number of behavior-descriptor dimensions reported by the env."""
        return self.env.behavior_descriptor_length

    @property
    @override
    def behavior_descriptor_limits(self) -> tuple[list[float], list[float]]:
        """Descriptor bounds reported by the env (presumably ``(lows, highs)``)."""
        return self.env.behavior_descriptor_limits

    @property
    @override
    def obs_space(self) -> gymnasium.spaces.Box:
        """Unbounded ``Box`` of shape ``(env.observation_size,)``."""
        obs_shape = (self.env.observation_size,)
        return gymnasium.spaces.Box(-np.inf, np.inf, shape=obs_shape)

    @property
    @override
    def action_space(self) -> gymnasium.spaces.Box:
        """Unbounded ``Box`` of shape ``(env.action_size,)``."""
        act_shape = (self.env.action_size,)
        return gymnasium.spaces.Box(-np.inf, np.inf, shape=act_shape)

    @property
    @override
    def state_descriptor_length(self) -> int:
        """Length of the env's state descriptor."""
        return self.env.state_descriptor_length

    @property
    @override
    def qd_offset(self) -> float:
        """QDax ``reward_offset`` entry for this env, scaled by episode length."""
        offset = qdax.environments.reward_offset[self._cfg.name]
        return offset * self._cfg.episode_len

    @override
    def obs(self, state: QDaxBraxEnvState) -> Observation:
        """Return the observation stored on ``state``."""
        observation = state.obs
        return assert_cast(Observation, observation)

    @override
    def reset(
        self, random_key: RNGKey, extra: Optional[jax.Array] = None
    ) -> QDaxBraxEnvState:
        """Reset the env with ``random_key``; ``extra`` is accepted but ignored."""
        return self.env.reset(random_key)

    @override
    def step(self, state: QDaxBraxEnvState, action: Action) -> QDaxBraxEnvState:
        """Advance ``state`` by one env step using ``action``."""
        return self.env.step(state, action)

    @override
    def get_constant(self, random_key: RNGKey) -> QDaxBraxEnvState:
        """Return the initial state produced by ``reset`` run through ``jax_jit``."""
        return jax_jit(self.reset)(random_key)

    @override
    def extract_behavior_descriptor(
        self, transition: buffer.QDTransition, mask: jax.Array
    ) -> Descriptor:
        """Apply QDax's registered descriptor extractor for this env name."""
        extractor = qdax.environments.behavior_descriptor_extractor[self._cfg.name]
        return extractor(transition, mask)

    @override
    def get_scoring_fn(self) -> Callable[
        [Genotype, Genotype, RNGKey, Any],
        tuple[Fitness, Descriptor, ExtraScores, RNGKey],
    ]:
        """Build the scoring function for this task.

        Returns:
            ``scoring_function_jnp_envs`` partially applied with the episode
            length, the one-step ``play_step_fn`` defined below, this task's
            ``extract_behavior_descriptor``, the task itself, and
            ``map_states=False``.
        """

        def play_step_fn(
            state: QDaxBraxEnvState,
            representation_params: Genotype,
            decision_params: Genotype,
            random_key: RNGKey,
            rand: jax.Array,
        ):
            """Run one env step and record it as a ``QDTransition``.

            Returns ``(next_state, representation_params, decision_params,
            random_key, rand, transition)``. ``random_key`` is passed through
            unchanged. ``rand`` and the transition come back from
            ``self.reduce_transitions``.
            """
            select_action = self._select_action_fn
            assert select_action is not None
            # If this runs under a JAX trace (not visible here), these debug
            # lines only fire while tracing, not on every executed step.
            _log.debug('obs.shape = %s', state.obs.shape)
            current_obs = assert_cast(Observation, state.obs)
            action_batch = select_action(
                representation_params, decision_params, current_obs
            )
            _log.debug('actions.shape = %s', action_batch.shape)
            # Read the descriptor before stepping, so the transition pairs the
            # pre-step and post-step descriptors.
            desc_before: StateDescriptor = state.info['state_descriptor']
            stepped = self.step(state, action_batch)
            raw_transition = buffer.QDTransition(
                obs=state.obs,
                next_obs=stepped.obs,
                rewards=stepped.reward,
                dones=stepped.done,
                actions=action_batch,
                truncations=stepped.info['truncation'],
                state_desc=desc_before,
                next_state_desc=stepped.info['state_descriptor'],
            )
            reduced_transition, next_rand = self.reduce_transitions(
                raw_transition, rand
            )
            return (
                stepped,
                representation_params,
                decision_params,
                random_key,
                next_rand,
                reduced_transition,
            )

        return partial(
            scoring_function_jnp_envs,
            episode_length=self._cfg.episode_len,
            play_step_fn=play_step_fn,
            behavior_descriptor_extractor=self.extract_behavior_descriptor,
            task=self,
            map_states=False,
        )