env_wrappers.py
import time
import numpy as np
import torch
from gym import spaces
from inspect import getfullargspec  # getargspec was removed in Python 3.11


class GymWrapper(object):
    '''
    Wrapper around a gym environment for multi-agent use: exposes per-agent
    observation/action dimensions and flattens observations into tensors.
    '''
    def __init__(self, env):
        self.env = env

    @property
    def observation_dim(self):
        '''
        for multi-agent, this is the observation dimension per agent
        '''
        # tuple space: sum the flattened size of every per-agent sub-space
        if hasattr(self.env.observation_space, 'spaces'):
            total_obs_dim = 0
            for space in self.env.observation_space.spaces:
                # check the sub-space itself: Box-like spaces contribute
                # prod(shape), Discrete spaces contribute a single dim
                if hasattr(space, 'shape'):
                    total_obs_dim += int(np.prod(space.shape))
                else:  # Discrete
                    total_obs_dim += 1
            return total_obs_dim
        else:
            return int(np.prod(self.env.observation_space.shape))

    @property
    def num_actions(self):
        if hasattr(self.env.action_space, 'nvec'):
            # MultiDiscrete
            return int(self.env.action_space.nvec[0])
        elif hasattr(self.env.action_space, 'n'):
            # Discrete
            return self.env.action_space.n

    @property
    def dim_actions(self):
        # for multi-agent, this is the number of actions per agent
        if hasattr(self.env.action_space, 'nvec'):
            # MultiDiscrete
            return self.env.action_space.shape[0]
        elif hasattr(self.env.action_space, 'n'):
            # Discrete => only one action is taken at a time.
            return 1

    @property
    def action_space(self):
        return self.env.action_space

    def reset(self, epoch):
        # pass the epoch through only if the underlying env's reset() accepts it
        reset_args = getfullargspec(self.env.reset).args
        if 'epoch' in reset_args:
            obs = self.env.reset(epoch)
        else:
            obs = self.env.reset()
        obs = self._flatten_obs(obs)
        return obs

    def display(self):
        self.env.render()
        time.sleep(0.5)

    def end_display(self):
        self.env.exit_render()

    def step(self, action):
        # TODO: modify all environments to take a list of actions
        # instead of unpacking the single action here
        if self.dim_actions == 1:
            action = action[0]
        obs, r, done, info = self.env.step(action)
        obs = self._flatten_obs(obs)
        return (obs, r, done, info)

    def reward_terminal(self):
        if hasattr(self.env, 'reward_terminal'):
            return self.env.reward_terminal()
        else:
            return np.zeros(1)

    def _flatten_obs(self, obs):
        if isinstance(obs, tuple):
            # tuple of per-agent observations: flatten and concatenate each
            # agent's observation parts, then stack the agents
            _obs = []
            for agent in obs:
                ag_obs = []
                for obs_kind in agent:
                    ag_obs.append(np.array(obs_kind).flatten())
                _obs.append(np.concatenate(ag_obs))
            obs = np.stack(_obs)
        obs = obs.reshape(1, -1, self.observation_dim)  # (batch=1, n_agents, obs_dim)
        obs = torch.from_numpy(obs).double()
        return obs

    def get_stat(self):
        if hasattr(self.env, 'stat'):
            self.env.stat.pop('steps_taken', None)
            return self.env.stat
        else:
            return dict()
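

# Minimal usage sketch, assuming the classic gym API (reset() returns an
# observation and step() returns a 4-tuple), which is what this wrapper
# targets. 'CartPole-v1' is purely illustrative; the wrapper is normally
# applied to multi-agent environments.
if __name__ == '__main__':
    import gym

    env = GymWrapper(gym.make('CartPole-v1'))
    obs = env.reset(epoch=0)
    # obs is a torch double tensor of shape (1, n_agents, observation_dim)
    print(obs.shape, env.num_actions, env.dim_actions)

    # dim_actions == 1 here, so step() expects an indexable container and
    # unpacks the single action itself
    obs, r, done, info = env.step(np.array([env.action_space.sample()]))
    print(r, done)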