Merge pull request #5 from JuliaPOMDP/requirements
Added requirements and docs for the solvers
rejuvyesh authored Dec 7, 2017
2 parents d4a7ec2 + cd28c07 commit aa290b7
Showing 5 changed files with 134 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/TabularTDLearning.jl
@@ -1,7 +1,7 @@
module TabularTDLearning

using POMDPs
# using GenerativeModels

using POMDPToolbox

import POMDPs: Solver, solve, Policy
46 changes: 43 additions & 3 deletions src/q_learn.jl
@@ -1,3 +1,30 @@
"""
QLearningSolver
Vanilla Q-learning implementation for tabular MDPs.
Parameters:
- `mdp::Union{MDP, POMDP}`:
Your problem framed as an MDP or POMDP
(it will use the state and not the observation if the problem is a POMDP)
- `n_episodes::Int64`:
Number of episodes to train the Q table
default: `100`
- `max_episode_length::Int64`:
Maximum number of steps before the episode is ended
default: `100`
- `learning_rate::Float64`:
Learning rate
default: `0.001`
- `exp_policy::Policy`:
Exploration policy to select the actions
default: `EpsGreedyPolicy(mdp, 0.5)`
- `eval_every::Int64`:
Frequency at which to evaluate the trained policy
default: `10`
- `n_eval_traj::Int64`:
Number of episodes to evaluate the policy
"""
mutable struct QLearningSolver <: Solver
n_episodes::Int64
max_episode_length::Int64
@@ -18,12 +45,12 @@ mutable struct QLearningSolver <: Solver
end
end


function create_policy(solver::QLearningSolver, mdp::Union{MDP,POMDP})
return solver.exploration_policy.val
end

function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp))
#TODO add verbose
function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_policy(solver, mdp); verbose=true)
rng = solver.exploration_policy.uni.rng
Q = solver.Q_vals
exploration_policy = solver.exploration_policy
@@ -45,8 +72,21 @@ function solve(solver::QLearningSolver, mdp::Union{MDP,POMDP}, policy=create_pol
for traj in 1:solver.n_eval_traj
r_tot += simulate(sim, mdp, policy, initial_state(mdp, rng))
end
println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)")
verbose ? println("On Iteration $i, Returns: $(r_tot/solver.n_eval_traj)") : nothing
end
end
return policy
end

@POMDP_require solve(solver::QLearningSolver, problem::Union{MDP,POMDP}) begin
P = typeof(problem)
S = state_type(P)
A = action_type(P)
@req initial_state(::P, ::AbstractRNG)
@req generate_sr(::P, ::S, ::A, ::AbstractRNG)
@req state_index(::P, ::S)
@req n_states(::P)
@req n_actions(::P)
@req action_index(::P, ::A)
@req discount(::P)
end
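
A minimal usage sketch for the solver documented above, modeled on the test script in this commit (GridWorld comes from POMDPModels; the parameter values are illustrative, not recommendations):

using POMDPs
using POMDPModels
using TabularTDLearning

mdp = GridWorld()
# keyword arguments mirror the documented solver fields
solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=1000, max_episode_length=50,
                         eval_every=100, n_eval_traj=20)
policy = solve(solver, mdp)                   # prints average returns every `eval_every` episodes
# policy = solve(solver, mdp; verbose=false)  # suppress the progress printout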
40 changes: 40 additions & 0 deletions src/sarsa.jl
@@ -1,3 +1,30 @@
"""
SARSASolver
SARSA implementation for tabular MDPs.
Parameters:
- `mdp::Union{MDP, POMDP}`:
Your problem framed as an MDP or POMDP
(it will use the state and not the observation if the problem is a POMDP)
- `n_episodes::Int64`:
Number of episodes to train the Q table
default: `100`
- `max_episode_length::Int64`:
Maximum number of steps before the episode is ended
default: `100`
- `learning_rate::Float64`:
Learning rate
default: `0.001`
- `exp_policy::Policy`:
Exploration policy to select the actions
default: `EpsGreedyPolicy(mdp, 0.5)`
- `eval_every::Int64`:
Frequency at which to evaluate the trained policy
default: `10`
- `n_eval_traj::Int64`:
Number of episodes to evaluate the policy
"""
mutable struct SARSASolver <: Solver
n_episodes::Int64
max_episode_length::Int64
@@ -52,3 +79,16 @@ function solve(solver::SARSASolver, mdp::Union{MDP,POMDP}, policy=create_policy(
end
return policy
end

@POMDP_require solve(solver::SARSASolver, problem::Union{MDP,POMDP}) begin
P = typeof(problem)
S = state_type(P)
A = action_type(P)
@req initial_state(::P, ::AbstractRNG)
@req generate_sr(::P, ::S, ::A, ::AbstractRNG)
@req state_index(::P, ::S)
@req n_states(::P)
@req n_actions(::P)
@req action_index(::P, ::A)
@req discount(::P)
end
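
The SARSA solver follows the same pattern; a short sketch, assuming its keyword constructor mirrors QLearningSolver's (illustrative values only):

using POMDPs
using POMDPModels
using TabularTDLearning

mdp = GridWorld()
solver = SARSASolver(mdp, learning_rate=0.1, n_episodes=1000, max_episode_length=50,
                     eval_every=100, n_eval_traj=20)
# check that the problem implements everything in the @POMDP_require block above
@requirements_info solver mdp
policy = solve(solver, mdp)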
43 changes: 43 additions & 0 deletions src/sarsa_lambda.jl
@@ -1,3 +1,33 @@
"""
SARSALambdaSolver
SARSA-λ implementation for tabular MDPs; assigns credit using eligibility traces.
Parameters:
- `mdp::Union{MDP, POMDP}`:
Your problem framed as an MDP or POMDP
(it will use the state and not the observation if the problem is a POMDP)
- `n_episodes::Int64`:
Number of episodes to train the Q table
default: `100`
- `max_episode_length::Int64`:
Maximum number of steps before the episode is ended
default: `100`
- `learning_rate::Float64`:
Learning rate
default: `0.001`
- `lambda::Float64`:
Exponential decay parameter for the eligibility traces
default: `0.5`
- `exp_policy::Policy`:
Exploration policy to select the actions
default: `EpsGreedyPolicy(mdp, 0.5)`
- `eval_every::Int64`:
Frequency at which to evaluate the trained policy
default: `10`
- `n_eval_traj::Int64`:
Number of episodes to evaluate the policy
"""
mutable struct SARSALambdaSolver <: Solver
n_episodes::Int64
max_episode_length::Int64
@@ -66,3 +96,16 @@ function solve(solver::SARSALambdaSolver, mdp::Union{MDP,POMDP}, policy=create_p
end
return policy
end

@POMDP_require solve(solver::SARSALambdaSolver, problem::Union{MDP,POMDP}) begin
P = typeof(problem)
S = state_type(P)
A = action_type(P)
@req initial_state(::P, ::AbstractRNG)
@req generate_sr(::P, ::S, ::A, ::AbstractRNG)
@req state_index(::P, ::S)
@req n_states(::P)
@req n_actions(::P)
@req action_index(::P, ::A)
@req discount(::P)
end
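
SARSA-λ adds only the `lambda` keyword on top of the SARSA parameters; the construction below matches the call in test/runtests.jl, with the value chosen purely for illustration (λ = 0 recovers one-step SARSA, larger values propagate credit further back along the trace):

solver = SARSALambdaSolver(mdp, learning_rate=0.1, lambda=0.9, n_episodes=1000,
                           max_episode_length=50, eval_every=100, n_eval_traj=20)
policy = solve(solver, mdp)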
8 changes: 7 additions & 1 deletion test/runtests.jl
@@ -1,9 +1,16 @@
using TabularTDLearning
using POMDPs
using POMDPModels
using Base.Test

mdp = GridWorld()



solver = QLearningSolver(mdp, learning_rate=0.1, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
println("Test QLearning requirements: ")
@requirements_info solver mdp

policy = solve(solver, mdp)


@@ -13,4 +20,3 @@ policy = solve(solver, mdp)

solver = SARSALambdaSolver(mdp, learning_rate=0.1, lambda=0.9, n_episodes=5000, max_episode_length=50, eval_every=50, n_eval_traj=100)
policy = solve(solver, mdp)
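
Once trained, the policy can also be evaluated outside the solver. A hedged sketch: `simulate` and `initial_state` are the same calls the solver uses internally, while `RolloutSimulator(max_steps=...)` and `action(policy, s)` are assumed from POMDPToolbox and POMDPs respectively:

using POMDPToolbox   # RolloutSimulator

rng = MersenneTwister(1)
sim = RolloutSimulator(max_steps=50)     # assumed keyword constructor
r_tot = 0.0
for _ in 1:100
    r_tot += simulate(sim, mdp, policy, initial_state(mdp, rng))
end
println("Estimated mean return: $(r_tot/100)")
# a = action(policy, initial_state(mdp, rng))   # query a single greedy action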
