From 9d1f714fafe44b5c64f1ad406e2b419cc2b758dc Mon Sep 17 00:00:00 2001 From: MaximeBouton Date: Tue, 19 Sep 2017 14:50:39 -0700 Subject: [PATCH 1/2] added requirements --- src/TabularTDLearning.jl | 2 +- src/q_learn.jl | 20 +++++++++++++++++--- src/sarsa.jl | 13 +++++++++++++ src/sarsa_lambda.jl | 13 +++++++++++++ test/runtests.jl | 8 +++++++- 5 files changed, 51 insertions(+), 5 deletions(-) diff --git a/src/TabularTDLearning.jl b/src/TabularTDLearning.jl index 7c88aca..44913e5 100644 --- a/src/TabularTDLearning.jl +++ b/src/TabularTDLearning.jl @@ -1,7 +1,7 @@ module TabularTDLearning using POMDPs -# using GenerativeModels + using POMDPToolbox import POMDPs: Solver, solve, Policy diff --git a/src/q_learn.jl b/src/q_learn.jl index df4a026..67daa11 100644 --- a/src/q_learn.jl +++ b/src/q_learn.jl @@ -1,3 +1,4 @@ +#TODO add doc mutable struct QLearningSolver <: Solver n_episodes::Int64 max_episode_length::Int64 @@ -18,12 +19,12 @@ mutable struct QLearningSolver <: Solver end end - function create_policy(solver::QLearningSolver, mdp::Union{MDP,POMDP}) return solver.exploration_policy.val end -function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp)) +#TODO add verbose +function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp); verbose=true) rng = solver.exploration_policy.uni.rng Q = solver.Q_vals exploration_policy = solver.exploration_policy @@ -45,8 +46,21 @@ function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_pol for traj in 1:solver.n_eval_traj r_tot += simulate(sim, mdp, policy, initial_state(mdp, rng)) end - println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)") + verbose ? println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)") : nothing end end return policy end + +@POMDP_require solve(solver::QLearningSolver, problem::Union{MDP,POMDP}) begin + P = typeof(problem) + S = state_type(P) + A = action_type(P) + @req initial_state(::P, ::AbstractRNG) + @req generate_sr(::P, ::S, ::A, ::AbstractRNG) + @req state_index(::P, ::S) + @req n_states(::P) + @req n_actions(::P) + @req action_index(::P, ::A) + @req discount(::P) +end diff --git a/src/sarsa.jl b/src/sarsa.jl index 6d4b7f6..cc6a991 100644 --- a/src/sarsa.jl +++ b/src/sarsa.jl @@ -52,3 +52,16 @@ function solve(solver::SARSASolver, mdp::Union{MDP,POMDP}, policy=create_policy( end return policy end + +@POMDP_require solve(solver::SARSASolver, problem::Union{MDP,POMDP}) begin + P = typeof(problem) + S = state_type(P) + A = action_type(P) + @req initial_state(::P, ::AbstractRNG) + @req generate_sr(::P, ::S, ::A, ::AbstractRNG) + @req state_index(::P, ::S) + @req n_states(::P) + @req n_actions(::P) + @req action_index(::P, ::A) + @req discount(::P) +end diff --git a/src/sarsa_lambda.jl b/src/sarsa_lambda.jl index 8d9bec3..6247410 100644 --- a/src/sarsa_lambda.jl +++ b/src/sarsa_lambda.jl @@ -66,3 +66,16 @@ function solve(solver::SARSALambdaSolver, mdp::Union{MDP,POMDP}, policy=create_p end return policy end + +@POMDP_require solve(solver::SARSALambdaSolver, problem::Union{MDP,POMDP}) begin + P = typeof(problem) + S = state_type(P) + A = action_type(P) + @req initial_state(::P, ::AbstractRNG) + @req generate_sr(::P, ::S, ::A, ::AbstractRNG) + @req state_index(::P, ::S) + @req n_states(::P) + @req n_actions(::P) + @req action_index(::P, ::A) + @req discount(::P) +end diff --git a/test/runtests.jl b/test/runtests.jl index bdcfa27..50e7b75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,9 +1,16 @@ using 
TabularTDLearning
+using POMDPs
 using POMDPModels
 using Base.Test

 mdp = GridWorld()

+
+
+
 solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
+println("Test QLearning requirements: ")
+@requirements_info solver mdp
+

 policy = solve(solver, mdp)
@@ -13,4 +20,3 @@ policy = solve(solver, mdp)

 solver = SARSALambdaSolver(mdp, learning_rate=0.1, lambda=0.9, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
 policy = solve(solver, mdp)
-
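The `@POMDP_require` blocks introduced in this patch declare the interface a problem must implement before these solvers can run: `initial_state`, `generate_sr`, `state_index`, `n_states`, `n_actions`, `action_index`, and `discount`. As a rough sketch (not part of the patch), a hypothetical two-state chain MDP providing that interface could look like the following, assuming the POMDPs.jl generative API of this era:

    using POMDPs

    # Hypothetical problem type, for illustration only.
    struct ChainMDP <: MDP{Int, Int} end

    POMDPs.n_states(::ChainMDP) = 2
    POMDPs.n_actions(::ChainMDP) = 2
    POMDPs.state_index(::ChainMDP, s::Int) = s
    POMDPs.action_index(::ChainMDP, a::Int) = a
    POMDPs.discount(::ChainMDP) = 0.95
    POMDPs.initial_state(::ChainMDP, rng::AbstractRNG) = 1

    # Generative model: action 2 moves to state 2, action 1 stays put;
    # a reward of 1.0 is collected whenever the next state is state 2.
    function POMDPs.generate_sr(mdp::ChainMDP, s::Int, a::Int, rng::AbstractRNG)
        sp = a == 2 ? 2 : s
        r = sp == 2 ? 1.0 : 0.0
        return sp, r
    end

With such a type defined, `@requirements_info QLearningSolver(ChainMDP()) ChainMDP()` should report whether anything the solver needs is still missing, mirroring the check added to test/runtests.jl above.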
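The first patch also adds a `verbose` keyword to the Q-learning `solve`, so the periodic `On Iteration ...` evaluation printout can be silenced (only the Q-learning solver accepts the keyword in this patch). A minimal usage sketch, reusing the GridWorld setup from test/runtests.jl:

    using TabularTDLearning
    using POMDPModels

    mdp = GridWorld()
    solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=5000,
                             max_episode_length=50, eval_every=50, n_eval_traj=100)
    # verbose defaults to true; pass false to train silently
    policy = solve(solver, mdp, verbose=false)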
From cd28c07bec39de46d909d97bcdc137fd95eb7db3 Mon Sep 17 00:00:00 2001
From: MaximeBouton
Date: Tue, 19 Sep 2017 15:04:20 -0700
Subject: [PATCH 2/2] added doc

---
 src/q_learn.jl      | 28 +++++++++++++++++++++++++++-
 src/sarsa.jl        | 27 +++++++++++++++++++++++++++
 src/sarsa_lambda.jl | 30 ++++++++++++++++++++++++++++++
 3 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/src/q_learn.jl b/src/q_learn.jl
index 67daa11..e221be1 100644
--- a/src/q_learn.jl
+++ b/src/q_learn.jl
@@ -1,4 +1,30 @@
-#TODO add doc
+"""
+    QLearningSolver
+
+Vanilla Q-learning implementation for tabular MDPs.
+
+Parameters:
+- `mdp::Union{MDP, POMDP}`:
+    Your problem framed as an MDP or POMDP
+    (it will use the state and not the observation if the problem is a POMDP)
+- `n_episodes::Int64`:
+    Number of episodes to train the Q table
+    default: `100`
+- `max_episode_length::Int64`:
+    Maximum number of steps before the episode is ended
+    default: `100`
+- `learning_rate::Float64`:
+    Learning rate
+    default: `0.001`
+- `exp_policy::Policy`:
+    Exploration policy to select the actions
+    default: `EpsGreedyPolicy(mdp, 0.5)`
+- `eval_every::Int64`:
+    Frequency at which to evaluate the trained policy
+    default: `10`
+- `n_eval_traj::Int64`:
+    Number of episodes to evaluate the policy
+"""
 mutable struct QLearningSolver <: Solver
     n_episodes::Int64
     max_episode_length::Int64
diff --git a/src/sarsa.jl b/src/sarsa.jl
index cc6a991..82c4e25 100644
--- a/src/sarsa.jl
+++ b/src/sarsa.jl
@@ -1,3 +1,30 @@
+"""
+    SARSASolver
+
+SARSA implementation for tabular MDPs.
+
+Parameters:
+- `mdp::Union{MDP, POMDP}`:
+    Your problem framed as an MDP or POMDP
+    (it will use the state and not the observation if the problem is a POMDP)
+- `n_episodes::Int64`:
+    Number of episodes to train the Q table
+    default: `100`
+- `max_episode_length::Int64`:
+    Maximum number of steps before the episode is ended
+    default: `100`
+- `learning_rate::Float64`:
+    Learning rate
+    default: `0.001`
+- `exp_policy::Policy`:
+    Exploration policy to select the actions
+    default: `EpsGreedyPolicy(mdp, 0.5)`
+- `eval_every::Int64`:
+    Frequency at which to evaluate the trained policy
+    default: `10`
+- `n_eval_traj::Int64`:
+    Number of episodes to evaluate the policy
+"""
 mutable struct SARSASolver <: Solver
     n_episodes::Int64
     max_episode_length::Int64
diff --git a/src/sarsa_lambda.jl b/src/sarsa_lambda.jl
index 6247410..6aced8e 100644
--- a/src/sarsa_lambda.jl
+++ b/src/sarsa_lambda.jl
@@ -1,3 +1,33 @@
+"""
+    SARSALambdaSolver
+
+SARSA-λ implementation for tabular MDPs; credit is assigned using eligibility traces.
+
+Parameters:
+- `mdp::Union{MDP, POMDP}`:
+    Your problem framed as an MDP or POMDP
+    (it will use the state and not the observation if the problem is a POMDP)
+- `n_episodes::Int64`:
+    Number of episodes to train the Q table
+    default: `100`
+- `max_episode_length::Int64`:
+    Maximum number of steps before the episode is ended
+    default: `100`
+- `learning_rate::Float64`:
+    Learning rate
+    default: `0.001`
+- `lambda::Float64`:
+    Exponential decay parameter for the eligibility traces
+    default: `0.5`
+- `exp_policy::Policy`:
+    Exploration policy to select the actions
+    default: `EpsGreedyPolicy(mdp, 0.5)`
+- `eval_every::Int64`:
+    Frequency at which to evaluate the trained policy
+    default: `10`
+- `n_eval_traj::Int64`:
+    Number of episodes to evaluate the policy
+"""
 mutable struct SARSALambdaSolver <: Solver
     n_episodes::Int64
     max_episode_length::Int64