# Patch summary: unified diff touching four Julia files under src/.
#
# src/alpha.jl
#   Adds a `witnesses::Set{Int}` field to `AlphaVec`.
#
# src/backup.jl (`backup!`)
#   Instead of always pushing a new `AlphaVec` onto Γ, looks up the first entry
#   of Γ that compares `==` to `best_α`. If none is found, a new `AlphaVec` is
#   pushed with witness set `Set(b_idx)`; otherwise `b_idx` is merged into the
#   existing entry's `witnesses` via `union!`.
#
# src/prune.jl
#   Removes `belief_space_domination` and adds
#   `recertify_witnesses!(tree, α1, α2, δ)`, which:
#     - when `α1 == α2`, unions α1's witnesses into α2, empties α1's set and
#       returns;
#     - otherwise, for each `b_idx` in `α1.witnesses`, deletes it when
#       `tree.b_pruned[b_idx]` is true, or moves it to `α2.witnesses` when
#       `intersection_distance(α2, α1, tree.b[b_idx]) > δ`.
#   `prune_alpha!` no longer builds `B_valid`. Its inner loop now skips only
#   `j == i` and already-pruned entries (previously every `j ≤ i` was skipped),
#   calls `recertify_witnesses!(tree, α_i, α_j, δ)`, and marks α_i (then breaks)
#   or α_j as pruned once the corresponding witness set is empty.
#   The removed NOTE comment about duplicate alpha vectors and NaN results from
#   `intersection_distance` goes away with the old implementation.
#   The new version of the file has no trailing newline.
#
# src/tree.jl (`insert_root!`)
#   Passes `Set(1)` as the witness set of every initial alpha vector.
#
# NOTE(review): `recertify_witnesses!` calls `delete!` on `α1.witnesses` while
#   iterating over that same set — confirm this is safe, or iterate over a copy.
# NOTE(review): `Set(b_idx)` and `union!(..., b_idx)` presumably rely on `b_idx`
#   being a single integer that iterates as one element — verify against callers.
# NOTE(review): the line structure of this patch appears collapsed (all hunks
#   run together), so it likely will not apply as-is — verify against the
#   original commit before use.
diff --git a/src/alpha.jl b/src/alpha.jl index 5364ac7..ce3bf0c 100644 --- a/src/alpha.jl +++ b/src/alpha.jl @@ -1,6 +1,7 @@ struct AlphaVec{A} <: AbstractVector{Float64} alpha::Vector{Float64} action::A + witnesses::Set{Int} end @inline Base.length(v::AlphaVec) = length(v.alpha) diff --git a/src/backup.jl b/src/backup.jl index ada1f62..b54d8b7 100644 --- a/src/backup.jl +++ b/src/backup.jl @@ -76,9 +76,13 @@ function backup!(tree, b_idx) best_action = a end end - - α = AlphaVec(best_α, best_action) - push!(Γ, α) + α_idx = findfirst(x->x == best_α, Γ) + if α_idx === nothing + α = AlphaVec(best_α, best_action, Set(b_idx)) + push!(Γ, α) + else + union!(Γ[α_idx].witnesses, b_idx) + end tree.V_lower[b_idx] = V end diff --git a/src/prune.jl b/src/prune.jl index 7f19834..2e1691c 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -48,16 +48,27 @@ function prune!(tree::SARSOPTree) end end -function belief_space_domination(α1, α2, B, δ) - a1_dominant = true - a2_dominant = true - for b ∈ B - !a1_dominant && !a2_dominant && return (false, false) - δV = intersection_distance(α1, α2, b) - δV ≤ δ && (a1_dominant = false) - δV ≥ -δ && (a2_dominant = false) +function recertify_witnesses!(tree, α1, α2, δ) + + if α1 == α2 + union!(α2.witnesses, α1.witnesses) + empty!(α1.witnesses) + return + end + + for b_idx in α1.witnesses + if tree.b_pruned[b_idx] + delete!(α1.witnesses, b_idx) + continue + end + + δV = intersection_distance(α2, α1, tree.b[b_idx]) + + if δV > δ + delete!(α1.witnesses, b_idx) + push!(α2.witnesses, b_idx) + end end - return a1_dominant, a2_dominant end @inline function intersection_distance(α1, α2, b) @@ -75,33 +86,21 @@ end function prune_alpha!(tree::SARSOPTree, δ) Γ = tree.Γ - B_valid = tree.b[map(!,tree.b_pruned)] pruned = falses(length(Γ)) - # checking if α_i dominates α_j - for (i,α_i) ∈ enumerate(Γ) + for (i, α_i) ∈ enumerate(Γ) pruned[i] && continue - for (j,α_j) ∈ enumerate(Γ) - (j ≤ i || pruned[j]) && continue - a1_dominant,a2_dominant = 
belief_space_domination(α_i, α_j, B_valid, δ) - #= - NOTE: α1 and α2 shouldn't technically be able to mutually dominate - i.e. a1_dominant and a2_dominant should never both be true. - But this does happen when α1 == α2 because intersection_distance returns NaN. - Current impl prunes α2 without doing an equality check, removing - the duplicate α. Could do equality check to short-circuit - belief_space_domination which would speed things up if we have - a lot of duplicates, but the equality check can slow things down - if α's are sufficiently diverse. - =# - if a1_dominant - pruned[j] = true - elseif a2_dominant + for (j, α_j) ∈ enumerate(Γ) + (pruned[j] || j == i) && continue + recertify_witnesses!(tree, α_i, α_j, δ) + if isempty(α_i.witnesses) pruned[i] = true break + elseif isempty(α_j.witnesses) + pruned[j] = true end end end deleteat!(Γ, pruned) tree.prune_data.last_Γ_size = length(Γ) -end +end \ No newline at end of file diff --git a/src/tree.jl b/src/tree.jl index 5443eb2..cef6fc7 100644 --- a/src/tree.jl +++ b/src/tree.jl @@ -95,7 +95,7 @@ function insert_root!(solver, tree::SARSOPTree, b) Γ_lower = solve(solver.init_lower, pomdp) for (α,a) ∈ alphapairs(Γ_lower) new_val = dot(α, b) - push!(tree.Γ, AlphaVec(α, a)) + push!(tree.Γ, AlphaVec(α, a, Set(1))) end tree.prune_data.last_Γ_size = length(tree.Γ)