Skip to content

Commit

Permalink
Updated binning method + format updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dylan-asmar committed Jul 24, 2024
1 parent 4c5b380 commit 3762018
Showing 1 changed file with 108 additions and 91 deletions.
199 changes: 108 additions & 91 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,61 @@ mutable struct PruneData
prune_threshold::Float64
end

"""
    BinData(bin_value, bin_count, bin_error)

Storage for the bins of a single binning level.

The three matrices are addressed with the same index pair
`[ub_interval_idx, entropy_interval_idx]`.
"""
struct BinData
    # Running average of the lower-bound values of the beliefs placed in each bin.
    bin_value::Matrix{Float64}
    # Number of beliefs placed in each bin.
    bin_count::Matrix{Int}
    # Sum of the squared errors recorded for the beliefs in each bin.
    bin_error::Matrix{Float64}
end

"""
    BinNode(key, prev_error)

Record of where one belief sits inside a binning level.
"""
struct BinNode
    # (ub_interval_idx, entropy_interval_idx) of the bin that holds the belief.
    key::Tuple{Int,Int}
    # Squared error this belief last contributed to its bin's `bin_error`.
    prev_error::Float64
end

"""
    BinManager

Bookkeeping for the multi-level binning of belief nodes.

Every level partitions beliefs on two axes, upper-bound value and belief
entropy, into `num_bins_per_level[level] × num_bins_per_level[level]` bins.

# Fields
- `lowest_ub`: smallest upper-bound value; origin of the upper-bound axis.
- `num_levels`: number of binning levels.
- `num_bins_per_level`: bins per axis for each level.
- `bin_levels_intervals`: per level, the width of one bin along the upper-bound
  axis (`ub`) and along the entropy axis (`entropy`).
- `bin_levels_nodes`: per level, `b_idx => BinNode` for every binned belief.
- `bin_levels`: per level, the value/count/error matrices of the bins.
- `previous_lowerbound`: `b_idx =>` lower bound recorded at the last bin update.
"""
struct BinManager
    lowest_ub::Float64
    num_levels::Int
    num_bins_per_level::Vector{Int}
    bin_levels_intervals::Vector{NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}}
    bin_levels_nodes::Vector{Dict{Int,BinNode}}
    bin_levels::Vector{BinData}
    previous_lowerbound::Dict{Int,Float64}
end

function BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])
num_levels = length(num_bins_per_level)
lowest_ub = minimum(Vs_upper)
highest_ub = maximum(Vs_upper)

# [level][:ub|:entropy] => value
bin_levels_intervals = [Dict{Symbol, Float64}() for _ in 1:num_levels]
bin_levels_intervals = Vector{NamedTuple{(:ub, :entropy),Tuple{Float64,Float64}}}(undef, num_levels)

# [level][b_idx][:key|:prev_error] => (ub_interval_idx, entropy_interval_idx)|previous_error
bin_levels_nodes = Vector{Dict{Int, Dict{Symbol, Union{Tuple{Int, Int}, Float64}}}}(undef, num_levels)
bin_levels_nodes = Vector{Dict{Int,BinNode}}(undef, num_levels)

# [level][:bin_value|:bin_count|:bin_error][(ub_interval_idx, entropy_interval_idx)] => Float64|Int|Float64
bin_levels = Vector{Dict{Symbol, Dict{Tuple{Int, Int}, Union{Float64, Int}}}}(undef, num_levels)
bin_levels = Vector{BinData}(undef, num_levels)

num_states = length(Vs_upper)
max_e = max_entropy(num_states)
for level_i in 1:num_levels
num_bins = num_bins_per_level[level_i]

bin_levels_intervals[level_i][:ub] = (highest_ub - lowest_ub) / num_bins
bin_levels_intervals[level_i][:entropy] = max_e / num_bins

bin_levels_nodes[level_i] = Dict{Int, Dict{Symbol, Union{Tuple{Int, Int}, Float64}}}()

level = Dict{Symbol, Dict{Tuple{Int, Int}, Union{Float64, Int}}}()
level[:bin_value] = Dict{Tuple{Int, Int}, Float64}()
level[:bin_count] = Dict{Tuple{Int, Int}, Int}()
level[:bin_error] = Dict{Tuple{Int, Int}, Float64}()
bin_levels[level_i] = level

ub = (highest_ub - lowest_ub) / num_bins
ent = max_e / num_bins
bin_levels_intervals[level_i] = (ub=ub, entropy=ent)

bin_levels_nodes[level_i] = Dict{Int,BinNode}()

bin_levels[level_i] = BinData(
zeros(Float64, num_bins, num_bins), # bin_value
zeros(Int, num_bins, num_bins), # bin_count
zeros(Float64, num_bins, num_bins) # bin_error
)
end
previous_lowerbound = Dict{Int, Float64}() # b_idx => lowerbound

previous_lowerbound = Dict{Int,Float64}() # b_idx => lowerbound

return BinManager(
lowest_ub,
num_levels,
Expand Down Expand Up @@ -74,7 +86,7 @@ struct SARSOPTree

_discount::Float64
is_terminal::BitVector
is_terminal_s::SparseVector{Bool, Int}
is_terminal_s::SparseVector{Bool,Int}

#do we need both b_pruned and ba_pruned? b_pruned might be enough
sampled::Vector{Int} # b_idx
Expand All @@ -86,7 +98,7 @@ struct SARSOPTree
prune_data::PruneData

Γ::Vector{AlphaVec{Int}}

use_binning::Bool
bm::BinManager
end
Expand All @@ -100,11 +112,9 @@ function SARSOPTree(solver, pomdp::POMDP; num_bins_per_level=[5, 10])
corner_values = map(maximum, zip(upper_policy.alphas...))

bin_manager = BinManager(corner_values, num_bins_per_level)

tree = SARSOPTree(
sparse_pomdp,

Vector{Float64}[],
tree = SARSOPTree(
sparse_pomdp, Vector{Float64}[],
Vector{Int}[],
corner_values, #upper_policy.util,
Float64[],
Expand All @@ -122,7 +132,7 @@ function SARSOPTree(solver, pomdp::POMDP; num_bins_per_level=[5, 10])
Vector{Int}(),
BitVector(),
cache,
PruneData(0,0,solver.prunethresh),
PruneData(0, 0, solver.prunethresh),
AlphaVec{Int}[],
solver.use_binning,
bin_manager
Expand Down Expand Up @@ -154,7 +164,7 @@ function insert_root!(solver, tree::SARSOPTree, b)
pomdp = tree.pomdp

Γ_lower = solve(solver.init_lower, pomdp)
for (α,a) alphapairs(Γ_lower)
for (α, a) alphapairs(Γ_lower)
new_val = dot(α, b)
push!(tree.Γ, AlphaVec(α, a))
end
Expand All @@ -179,7 +189,7 @@ function update(tree::SARSOPTree, b_idx::Int, a, o)
ba_idx = tree.b_children[b_idx][a]
bp_idx = tree.ba_children[ba_idx][o]
V̲, V̄ = if tree.is_terminal[bp_idx]
0.,0.
0.0, 0.0
else
lower_value(tree, tree.b[bp_idx]), upper_value(tree, tree.b[bp_idx])
end
Expand All @@ -200,7 +210,7 @@ function add_belief!(tree::SARSOPTree, b, ba_idx::Int, o)
push!(tree.is_terminal, terminal)

V̲, V̄ = if terminal
0., 0.
0.0, 0.0
else
lower_value(tree, b), upper_value(tree, b)
end
Expand Down Expand Up @@ -250,8 +260,8 @@ function fill_populated!(tree::SARSOPTree, b_idx::Int)
bp_idx, V̲, V̄ = update(tree, b_idx, a, o)
b′ = tree.b[bp_idx]
po = tree.poba[ba_idx][o]
+= γ*po*
+= γ*po*
+= γ * po *
+= γ * po *
end

Qa_upper[a] =
Expand Down Expand Up @@ -283,7 +293,7 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
tree.ba_children[ba_idx] = ba_children

n_b += N_OBS
pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a],b))
pred = dropzeros!(mul!(tree.cache.pred, pomdp.T[a], b))
poba = zeros(Float64, N_OBS)
Rba = belief_reward(tree, b, a)

Expand All @@ -294,15 +304,15 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
# belief update
bp = corrector(pomdp, pred, a, o)
po = sum(bp)
if po > 0.
if po > 0.0
bp.nzval ./= po
poba[o] = po
end

bp_idx, V̲, V̄ = add_belief!(tree, bp, ba_idx, o)

+= γ*po*
+= γ*po*
+= γ * po *
+= γ * po *
end
Qa_upper[a] =
Qa_lower[a] =
Expand All @@ -318,100 +328,107 @@ function initialize_bin_node!(tree::SARSOPTree, b_idx::Int)
lb_val = tree.V_lower[b_idx]
ub_val = tree.V_upper[b_idx]
node_entropy = entropy(tree.b[b_idx])

for level_i in 1:tree.bm.num_levels
ub_interval_idx = get_interval_idx(
ub_val, tree.bm.lowest_ub, tree.bm.bin_levels_intervals[level_i][:ub],
tree.bm.num_of_bins_per_level[level_i]
ub_val, tree.bm.lowest_ub, tree.bm.bin_levels_intervals[level_i][:ub],
tree.bm.num_bins_per_level[level_i]
)

entropy_interval_idx = get_interval_idx(
node_entropy, 0.0, tree.bm.bin_levels_intervals[level_i][:entropy],
tree.bm.num_of_bins_per_level[level_i]
node_entropy, 0.0, tree.bm.bin_levels_intervals[level_i][:entropy],
tree.bm.num_bins_per_level[level_i]
)

key = (ub_interval_idx, entropy_interval_idx)

if !haskey(tree.bm.bin_levels_nodes[level_i], b_idx)
tree.bm.bin_levels_nodes[level_i] = Dict(b_idx => Dict(:key => key))
end

bin_count = get(tree.bm.bin_levels[level_i][:bin_count], key, 0)
prev_error = 0.0

bin_count = tree.bm.bin_levels[level_i].bin_count[ub_interval_idx, entropy_interval_idx]
if bin_count > 0
err = tree.bm.bin_levels[level_i][:bin_value][key] - lb_val
tree.bm.bin_levels_nodes[level_i][b_idx][:prev_error] = err * err
tree.bm.bin_levels[level_i][:bin_value][key] += err * err

value = (tree.bm.bin_levels[level_i][:bin_value][key] * bin_count + lb_val) / (bin_count + 1)
tree.bm.bin_levels[level_i][:bin_count][key] += 1
err = tree.bm.bin_levels[level_i].bin_value[ub_interval_idx, entropy_interval_idx] - lb_val
prev_error = err * err
tree.bm.bin_levels[level_i].bin_error[ub_interval_idx, entropy_interval_idx] += prev_error
value = (tree.bm.bin_levels[level_i].bin_value[ub_interval_idx, entropy_interval_idx] * bin_count + lb_val) / (bin_count + 1)
tree.bm.bin_levels[level_i].bin_count[ub_interval_idx, entropy_interval_idx] += 1
else
err = ub_val - lb_val
tree.bm.bin_levels_nodes[level_i][b_idx][:prev_error] = err * err
tree.bm.bin_levels[level_i][:bin_error][key] = err * err

prev_error = err * err
tree.bm.bin_levels[level_i].bin_error[ub_interval_idx, entropy_interval_idx] = prev_error
value = lb_val
tree.bm.bin_levels[level_i][:bin_count] = Dict(key => bin_count + 1)
tree.bm.bin_levels[level_i].bin_count[ub_interval_idx, entropy_interval_idx] = 1
end
tree.bm.bin_levels[level_i][:bin_value][key] = value
tree.bm.bin_levels[level_i].bin_value[ub_interval_idx, entropy_interval_idx] = value
tree.bm.bin_levels_nodes[level_i][b_idx] = BinNode(key, prev_error)

end
tree.bm.previous_lowerbound[b_idx] = lb_val
end

"""
    update_bin_node!(tree::SARSOPTree, b_idx::Int)

Refresh the bin statistics of belief `b_idx` on every binning level after its
bounds in `tree.V_lower` / `tree.V_upper` have changed.

If the belief has no node on the first level yet, this defers to
`initialize_bin_node!` and returns its result. Otherwise, on each level, the
belief's previous squared-error contribution to its bin is replaced by a new
one, and its previous lower bound is swapped for the current one inside the
bin's average. Finally the current lower bound is stored in
`tree.bm.previous_lowerbound`.
"""
function update_bin_node!(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    up_val = tree.V_upper[b_idx]

    if !haskey(bm.bin_levels_nodes[1], b_idx)
        return initialize_bin_node!(tree, b_idx)
    end

    # Only rewritten after the loop, so it is the same for every level.
    prev_lb = bm.previous_lowerbound[b_idx]

    for level_i in 1:bm.num_levels
        node = bm.bin_levels_nodes[level_i][b_idx]
        i, j = node.key
        bins = bm.bin_levels[level_i]
        bin_count = bins.bin_count[i, j]

        if bin_count == 1
            # This belief is alone in its bin: the error is its own bound gap.
            err = up_val - lb_val
            prev_error = err * err
            bins.bin_error[i, j] = prev_error
        else
            # Error against the bin average, taken before the average is updated.
            err = bins.bin_value[i, j] - lb_val
            bins.bin_error[i, j] -= node.prev_error
            prev_error = err * err
            bins.bin_error[i, j] += prev_error
        end

        bm.bin_levels_nodes[level_i][b_idx] = BinNode(node.key, prev_error)
        # Replace this belief's old lower bound with the new one in the average.
        bins.bin_value[i, j] =
            (bins.bin_value[i, j] * bin_count + lb_val - prev_lb) / bin_count
    end
    bm.previous_lowerbound[b_idx] = lb_val
end

"""
    get_bin_value(tree::SARSOPTree, b_idx::Int)

Return the binned value estimate for belief `b_idx`.

If the belief's first-level bin has a count of exactly 1, the belief's own
upper bound is returned. Otherwise the level whose bin has the smallest
accumulated error is selected (a later level must beat the current best by more
than `1e-10`), and that bin's value is returned, limited to the belief's
`[V_lower, V_upper]` range with a `1e-10` tolerance.
"""
function get_bin_value(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    ub_val = tree.V_upper[b_idx]

    first_key = bm.bin_levels_nodes[1][b_idx].key
    i1, j1 = first_key
    if bm.bin_levels[1].bin_count[i1, j1] == 1
        return ub_val
    end

    # NOTE(review): `best_level` stays 0 if no level passes the comparison below
    # (e.g. every bin error is non-finite), and the lookup after the loop would
    # then throw a BoundsError — confirm this cannot happen in practice.
    smallest_error = Inf
    best_level = 0
    best_key = first_key
    for level_i in 1:bm.num_levels
        key = bm.bin_levels_nodes[level_i][b_idx].key
        i, j = key
        level_error = bm.bin_levels[level_i].bin_error[i, j]
        if level_error + 1e-10 < smallest_error
            best_level = level_i
            smallest_error = level_error
            best_key = key
        end
    end

    bi, bj = best_key
    best_value = bm.bin_levels[best_level].bin_value[bi, bj]
    if best_value > ub_val + 1e-10
        return ub_val
    elseif best_value + 1e-10 < lb_val
        return lb_val
    else
        return best_value
    end
end
Expand Down Expand Up @@ -441,5 +458,5 @@ function get_interval_idx(value::Float64, lower::Float64, interval::Float64, num
return 1
end
idx = Int(floor((value - lower) / interval) + 1)
return clamp(idx, 1, num_intervals)
return clamp(idx, 1, num_intervals)
end

0 comments on commit 3762018

Please sign in to comment.