dynamic_programming_policy_iteration_frozen_lake.py
import gymnasium as gym
from itertools import count
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.animation as anim
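# Solves the FrozenLake-v1 environment with dynamic-programming policy iteration and
# saves a GIF of the resulting greedy policy being rolled out. The two helpers below
# (update_scene / plot_animation) only handle turning the collected RGB frames into an
# animation via matplotlib's FuncAnimation.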
def update_scene(num, frames, patch):
    patch.set_data(frames[num])
    return patch,
def plot_animation(frames:list, save_path:str, title:Optional[str]=None, repeat=False, interval=500):
    fig = plt.figure()
    patch = plt.imshow(frames[0])
    plt.axis('off')
    if title is None:
        title = save_path
    plt.title(title, fontsize=16)
    animation = anim.FuncAnimation(
        fig, update_scene, fargs=(frames, patch),
        frames=len(frames), repeat=repeat, interval=interval)
    animation.save(save_path, writer="ffmpeg", fps=20)
    return animation
def init_v_pi_vals(num_states:int):
    V = {state_val: 0.0 for state_val in range(num_states)}  # V(s) = v, initialized arbitrarily (here: 0.0)
    pi = {state_val: 0 for state_val in range(num_states)}   # pi(s) = a, initialized arbitrarily (here: action 0)
    return V, pi
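# Policy evaluation: repeatedly sweep over all states and apply the Bellman expectation
# backup V(s) <- sum_{s'} p(s'|s, pi(s)) * (r + gamma * V(s')) for the current policy pi,
# stopping once the largest per-sweep change (delta) falls below the tolerance theta.
# env.unwrapped.P[state][action] yields (prob, next_state, reward, terminated) tuples.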
def policy_evaluation(V:dict[int, float], pi:dict[int, int], env:gym.Env, gamma:float, theta:float):
    while True:
        delta = 0
        for state in range(env.observation_space.n):
            v = V[state]
            V[state] = sum([p*(r + gamma*V[s_]) for p, s_, r, _ in env.unwrapped.P[state][pi[state]]])
            delta = max(delta, abs(v - V[state]))
        if delta < theta:
            break
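# Policy improvement: for each state, make pi greedy with respect to the current value
# estimate by choosing the action with the highest one-step lookahead value. If no
# state's action changes, the policy is stable and policy iteration can stop.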
def policy_improvement(V:dict[int, float], pi:dict[int, int], env:gym.Env, gamma:float):
    policy_stable = True
    for state in range(env.observation_space.n):
        old_action = pi[state]
        # pick the action that maximizes the one-step lookahead value computed by the lambda below
        pi[state] = max(
            range(env.action_space.n),
            key=lambda action: sum([p*(r + gamma*V[s_]) for p, s_, r, _ in env.unwrapped.P[state][action]])
        )
        if old_action != pi[state]:
            policy_stable = False
    return policy_stable
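# Policy iteration alternates full policy evaluation and greedy policy improvement.
# Since FrozenLake has finitely many deterministic policies and each improvement step is
# monotone, the loop terminates with an optimal policy (up to the evaluation tolerance).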
def policy_iteration(env:gym.Env, gamma:float, theta:float):
    V, pi = init_v_pi_vals(env.observation_space.n)
    while True:
        policy_evaluation(V, pi, env, gamma, theta)
        if policy_improvement(V, pi, env, gamma):
            break
    return V, pi
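# Example usage (pass "y" to enable the stochastic, slippery variant of the lake):
#   python dynamic_programming_policy_iteration_frozen_lake.py --is_slippery y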
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--is_slippery", type=str, default="n")
    args = parser.parse_args()
    SLIPPERY = args.is_slippery.lower() == "y"
    env = gym.make("FrozenLake-v1", is_slippery=SLIPPERY)
    print(f"Environment is {'slippery' if SLIPPERY else 'not slippery'}")
    v_vals, pi_vals = policy_iteration(env, gamma=0.9, theta=1e-8)
    env.close()
    del env
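    # A fresh environment is created below because the one used for planning above was
    # built without a render mode; render_mode="rgb_array" makes env.render() return an
    # RGB frame that can be collected for the animation.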
    # see the optimal policy
    env = gym.make("FrozenLake-v1", is_slippery=SLIPPERY, render_mode="rgb_array")
    state, info = env.reset()
    frames = [env.render()]
    for i in count():
        action = pi_vals[state]
        state, reward, done, truncated, info = env.step(action)
        print("|| rewards:", reward, "||")
        frames.append(env.render())
        if done or truncated:
            break
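    # Note: plot_animation saves with writer="ffmpeg", which assumes an ffmpeg binary is
    # available on PATH; the images/ output directory is also assumed to exist already.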
    plot_animation(
        frames, save_path=f"images/frozen{'_slippery' if SLIPPERY else ''}_lake_policy_iteration.gif",
        title=f"Policy Iteration on Frozen {'STOCHASTIC' if SLIPPERY else 'DETERMINISTIC'} Lake Environment", repeat=False, interval=2000
    )
    print("Done!")
    env.close()