# min_example_actor.py
import torch
from actor import Actor
from actor import DummyCont
from actor import DummyDiscrete
import gym
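# NOTE: this script targets the legacy gym API (env.seed(), reset() returning only the
# observation, and step() returning a 4-tuple); gym >= 0.26 changed these interfaces.
# Actor is imported but not used directly; it is presumably the base class of the dummies.
# The dummy policies are assumed (inferred from their usage in this script, not verified
# against actor.py) to expose:
#   forward(obs_batch)        -> distribution parameters for a batch of states
#   get_dist(params)          -> a torch.distributions object built from those parameters
#   sample_action(obs)        -> an action sampled from the policy
#   get_log_probs(obs, acts)  -> log-probabilities of the given actions under the policy
#   get_kl(obs)               -> KL divergence of the policy against itself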
seed = 3141
torch.manual_seed(seed)
num_hidden = 10
print('Testing continuous actor on MountainCarContinuous-v0!')
# init env and policy
env = gym.make('MountainCarContinuous-v0')
env.seed(seed)
policy = DummyCont(env, num_hidden)
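# DummyCont presumably maps states to the parameters of a continuous distribution
# (e.g. a Normal) over MountainCarContinuous-v0's 1-D action space.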
obs_list = []
action_list = []
# reset env
done = False
obs = env.reset()
obs_list.append(obs)
obs = torch.from_numpy(obs).reshape(1,-1).float() # make sure obs has batch dimension and is cast to float32 instead of float64
# check if we can batch process states
print('Computing dists for three sets of params')
params = policy.forward(torch.cat((obs,obs,obs), dim=0))
dist = policy.get_dist(params)
print('batch shape of dist is', dist.batch_shape)
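# with three identical observations stacked above, the distribution's batch_shape
# should have a leading dimension of 3 if batch processing works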
# sample trajectory
print('Sampling a trajectory')
while not done:
    # sample action
    action = policy.sample_action(obs)
    # take step
    obs, rew, done, _ = env.step(action)
    action_list.append(action.detach().numpy())
    # don't include last state
    if not done:
        obs_list.append(obs)
    obs = torch.from_numpy(obs).reshape(1, -1).float()
print('\nTesting log_probs')
# convert lists to tensor
obs = torch.Tensor(obs_list)
actions = torch.Tensor(action_list)
# get logprobs
log_probs = policy.get_log_probs(obs, actions)
print('\nLogprobs are tensor of shape {}.'.format(log_probs.shape))
#print(log_probs)
print('\nTesting KL divergence')
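# get_kl is presumably the KL divergence between the current policy and a detached
# copy of itself (as used in TRPO-style updates), so the value here should be 0
# even though its gradient is generally non-zero.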
kl = policy.get_kl(obs)
print('Expecting KL of 0, got {}'.format(kl))
print('\n\n####################\n####################\n\n')
print('Testing discrete control on MountainCar-v0')
#TODO: Implement this
# init env and policy
env = gym.make('MountainCar-v0')
env.seed(seed)
policy = DummyDiscrete(env, num_hidden)
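# DummyDiscrete presumably parameterizes a Categorical distribution over
# MountainCar-v0's three discrete actions (push left, no push, push right).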
obs_list = []
action_list = []
# reset env
done = False
obs = env.reset()
obs_list.append(obs)
obs = torch.from_numpy(obs).reshape(1,-1).float() # make sure obs has batch dimension and is cast to float32 instead of float64
# check if we can batch process states
print('Computing dists for two sets of params')
params = policy.forward(torch.cat((obs,obs), dim=0))
dist = policy.get_dist(params)
print('batch shape of dist is', dist.batch_shape)
# sample trajectory
print('Sampling a trajectory')
while not done:
    # sample action
    action = policy.sample_action(obs).item()
    # take step
    obs, rew, done, _ = env.step(action)
    action_list.append(action)
    # don't include last state
    if not done:
        obs_list.append(obs)
    obs = torch.from_numpy(obs).reshape(1, -1).float()
print('\nTesting log_probs')
# convert lists to tensor
obs = torch.Tensor(obs_list)
actions = torch.Tensor(action_list)
# get logprobs
log_probs = policy.get_log_probs(obs, actions)
print('Logprobs are tensor of shape {}.'.format(log_probs.shape))
#print(log_probs)
print('\nTesting KL divergence')
kl = policy.get_kl(obs)
print('Expecting KL of 0, got {}'.format(kl))
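# tidy up the environment handle once both tests have run
env.close()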