-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathenv.py
81 lines (66 loc) · 2.85 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
import gym
import collections
import cv2
class RepeatActionAndMaxFrame(gym.Wrapper):
    """Repeat an action for several steps and return a max-pooled frame.

    Each call to ``step`` forwards the same action to the wrapped
    environment up to ``repeat`` times, sums the rewards, and returns the
    element-wise maximum of the two frames held in ``frame_buffer``.

    Args:
        env: Environment to wrap. Its ``observation_space`` is read to size
            the frame buffer.
        repeat: Number of times each action is applied per ``step`` call.
    """

    def __init__(self, env=None, repeat=4):
        super(RepeatActionAndMaxFrame, self).__init__(env)
        self.repeat = repeat
        self.shape = env.observation_space.low.shape
        self.frame_buffer = self._new_frame_buffer()

    def _new_frame_buffer(self):
        """Return a zero-filled buffer with room for two raw frames."""
        # Allocate from an explicit shape. The earlier
        # ``np.zeros_like((2, self.shape), dtype=object)`` passed NumPy a
        # ragged tuple, which triggers VisibleDeprecationWarning on older
        # releases and is rejected with ValueError by newer ones.
        return np.zeros((2,) + tuple(self.shape),
                        dtype=self.env.observation_space.dtype)

    def step(self, action):
        """Apply ``action`` up to ``repeat`` times.

        Returns:
            Tuple ``(max_frame, t_reward, done, info)`` where ``max_frame``
            is the element-wise maximum of the two buffered frames,
            ``t_reward`` is the summed reward, and ``done``/``info`` come
            from the last inner step that ran.
        """
        t_reward = 0.0
        done = False
        # Keeps ``info`` defined even if ``repeat`` is zero or negative.
        info = {}
        for i in range(self.repeat):
            obs, reward, done, info = self.env.step(action)
            t_reward += reward
            # Alternate between the two slots so the buffer always holds
            # the two most recently observed frames.
            self.frame_buffer[i % 2] = obs
            if done:
                break
        max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
        return max_frame, t_reward, done, info

    def reset(self):
        """Reset the wrapped environment and restart the frame buffer.

        Returns:
            The first observation from the wrapped environment, unmodified.
        """
        obs = self.env.reset()
        self.frame_buffer = self._new_frame_buffer()
        self.frame_buffer[0] = obs
        return obs
class PreprocessFrame(gym.ObservationWrapper):
    """Turn RGB frames into resized grayscale arrays scaled by 1/255.

    Args:
        shape: Indexable with at least three entries; stored on the wrapper
            reordered as ``(shape[2], shape[0], shape[1])``.
        env: Environment to wrap.
    """

    def __init__(self, shape, env=None):
        super().__init__(env)
        reordered = (shape[2], shape[0], shape[1])
        self.shape = reordered
        self.observation_space = gym.spaces.Box(
            low=np.float32(0),
            high=np.float32(1.0),
            shape=self.shape,
            dtype=np.float32,
        )

    def observation(self, obs):
        """Grayscale, resize, reshape, swap axes 0 and 2, then divide by 255."""
        gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        target_size = self.shape[1:]
        shrunk = cv2.resize(gray, target_size, interpolation=cv2.INTER_AREA)
        frame = np.array(shrunk, dtype=np.uint8).reshape(self.shape)
        # NOTE(review): swapping axes 0 and 2 reverses the axis order of
        # ``self.shape``, so the returned array's shape differs from the
        # shape declared in ``observation_space`` whenever shape[2] !=
        # shape[1] - confirm this is intended.
        swapped = frame.swapaxes(2, 0)
        # NOTE(review): uint8 / 255.0 yields float64, while the declared
        # space dtype is float32 - confirm callers cast as needed.
        return swapped / 255.0
class StackFrames(gym.ObservationWrapper):
    """Keep the most recent ``n_steps`` observations and return them stacked.

    The observation space bounds are the wrapped environment's bounds
    repeated ``n_steps`` times along axis 0.

    Args:
        env: Environment to wrap.
        n_steps: Number of observations kept in the stack.
    """

    def __init__(self, env, n_steps):
        super().__init__(env)
        base_space = env.observation_space
        stacked_low = np.float32(np.repeat(base_space.low, n_steps, axis=0))
        stacked_high = np.float32(np.repeat(base_space.high, n_steps, axis=0))
        self.observation_space = gym.spaces.Box(
            stacked_low, stacked_high, dtype=np.float32)
        # Bounded deque: appending beyond ``n_steps`` drops the oldest entry.
        self.stack = collections.deque(maxlen=n_steps)

    def reset(self):
        """Reset the environment and fill the stack with the first observation."""
        self.stack.clear()
        first_obs = self.env.reset()
        self.stack.extend([first_obs] * self.stack.maxlen)
        stacked_shape = self.observation_space.low.shape
        return np.array(self.stack).reshape(stacked_shape)

    def observation(self, observation):
        """Push ``observation`` onto the stack and return the stacked array."""
        self.stack.append(observation)
        stacked_shape = self.observation_space.low.shape
        return np.array(self.stack).reshape(stacked_shape)
def make_env(env_name, shape=(84, 84, 1), skip=4):
    """Build a Gym environment wrapped for frame skipping, preprocessing and stacking.

    Args:
        env_name: Environment id passed to ``gym.make``.
        shape: Shape argument forwarded to ``PreprocessFrame``.
        skip: Used both as the action-repeat count for
            ``RepeatActionAndMaxFrame`` and as the stack depth for
            ``StackFrames``.

    Returns:
        The environment wrapped, innermost to outermost, by
        ``RepeatActionAndMaxFrame``, ``PreprocessFrame`` and ``StackFrames``.
    """
    # Local import so this function does not rely on ``np.warnings``, an
    # alias that newer NumPy releases no longer provide.
    import warnings

    # VisibleDeprecationWarning lives in ``np.exceptions`` on newer NumPy
    # and at the top level on older releases; look in both places and skip
    # the filter if neither has it.
    category = getattr(getattr(np, "exceptions", None),
                       "VisibleDeprecationWarning", None)
    if category is None:
        category = getattr(np, "VisibleDeprecationWarning", None)
    if category is not None:
        warnings.filterwarnings('ignore', category=category)

    env = gym.make(env_name)
    env = RepeatActionAndMaxFrame(env, skip)
    env = PreprocessFrame(shape, env)
    env = StackFrames(env, skip)
    return env