diff --git a/src/prune.jl b/src/prune.jl index 7f19834..395da7d 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -5,6 +5,7 @@ end function prune!(solver::SARSOPSolver, tree::SARSOPTree) prune!(tree) + prune_strictly_dominated!(tree::SARSOPTree) if should_prune_alphas(tree) prune_alpha!(tree, solver.delta) end @@ -48,60 +49,120 @@ function prune!(tree::SARSOPTree) end end -function belief_space_domination(α1, α2, B, δ) - a1_dominant = true - a2_dominant = true - for b ∈ B - !a1_dominant && !a2_dominant && return (false, false) - δV = intersection_distance(α1, α2, b) - δV ≤ δ && (a1_dominant = false) - δV ≥ -δ && (a2_dominant = false) - end - return a1_dominant, a2_dominant -end - -@inline function intersection_distance(α1, α2, b) - s = 0.0 +function intersection_distance(α1, α2, b) dot_sum = 0.0 - I,B = b.nzind, b.nzval - @inbounds for _i ∈ eachindex(I) + I, B = b.nzind, b.nzval + for _i ∈ eachindex(I) i = I[_i] - diff = α1[i] - α2[i] - s += abs2(diff) - dot_sum += diff*B[_i] + dot_sum += (α1[i] - α2[i]) * B[_i] + end + s = 0.0 + for i ∈ eachindex(α1, α2) + s += (α1[i] - α2[i])^2 end return dot_sum / sqrt(s) end -function prune_alpha!(tree::SARSOPTree, δ) +function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) Γ = tree.Γ - B_valid = tree.b[map(!,tree.b_pruned)] - pruned = falses(length(Γ)) - - # checking if α_i dominates α_j - for (i,α_i) ∈ enumerate(Γ) - pruned[i] && continue - for (j,α_j) ∈ enumerate(Γ) - (j ≤ i || pruned[j]) && continue - a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ) - #= - NOTE: α1 and α2 shouldn't technically be able to mutually dominate - i.e. a1_dominant and a2_dominant should never both be true. - But this does happen when α1 == α2 because intersection_distance returns NaN. - Current impl prunes α2 without doing an equality check, removing - the duplicate α. Could do equality check to short-circuit - belief_space_domination which would speed things up if we have - a lot of duplicates, but the equality check can slow things down - if α's are sufficiently diverse. - =# - if a1_dominant - pruned[j] = true - elseif a2_dominant - pruned[i] = true - break + B_valid = tree.b[map(!, tree.b_pruned)] + + n_Γ = length(Γ) + n_B = length(B_valid) + + dominant_indices_bools = falses(n_Γ) + dominant_vector_indices = Vector{Int}(undef, n_B) + + # First, identify dominant alpha vectors + for b_idx in 1:n_B + max_value = -Inf + max_index = -1 + for i in 1:n_Γ + value = dot(Γ[i], B_valid[b_idx]) + if value > max_value + max_value = value + max_index = i end end + dominant_indices_bools[max_index] = true + dominant_vector_indices[b_idx] = max_index end - deleteat!(Γ, pruned) + + non_dominant_indices = findall(!, dominant_indices_bools) + n_non_dom = length(non_dominant_indices) + keep_non_dom = falses(n_non_dom) + + for b_idx in 1:n_B + dom_vec_idx = dominant_vector_indices[b_idx] + for j in 1:n_non_dom + non_dom_idx = non_dominant_indices[j] + if keep_non_dom[j] + continue + end + intx_dist = intersection_distance(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx]) + if !isnan(intx_dist) && (intx_dist + eps ≤ δ) + keep_non_dom[j] = true + end + end + end + + non_dominant_indices = non_dominant_indices[.!keep_non_dom] + deleteat!(Γ, non_dominant_indices) tree.prune_data.last_Γ_size = length(Γ) end + +function strictly_dominates(α1, α2, eps) + for ii in 1:length(α1) + if α1[ii] < α2[ii] - eps + return false + end + end + return true +end + +function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) + Γ = tree.Γ + Γ_new_idxs = Vector{Int}(undef, length(Γ)) + keep = trues(length(Γ)) + + idx_count = 0 + for (α_try_idx, α_try) in enumerate(Γ) + dominated = false + for jj in 1:idx_count + α_in_idx = Γ_new_idxs[jj] + α_in = Γ[α_in_idx] + if strictly_dominates(α_try, α_in, eps) + keep[jj] = false + elseif strictly_dominates(α_in, α_try, eps) + dominated = true + break + end + end + if !dominated + new_idx_count = 0 + for jj in 1:idx_count + if keep[jj] + new_idx_count += 1 + Γ_new_idxs[new_idx_count] = Γ_new_idxs[jj] + end + end + new_idx_count += 1 + Γ_new_idxs[new_idx_count] = α_try_idx + idx_count = new_idx_count + fill!(keep, true) + end + end + + resize!(Γ_new_idxs, idx_count) + + to_delete = trues(length(Γ)) + for idx in Γ_new_idxs + to_delete[idx] = false + end + + for ii in length(Γ):-1:1 + if to_delete[ii] + deleteat!(Γ, ii) + end + end +end diff --git a/test/prune.jl b/test/prune.jl new file mode 100644 index 0000000..f55fb36 --- /dev/null +++ b/test/prune.jl @@ -0,0 +1,19 @@ +@testset "prune" begin + # NativeSARSOP.strictly_dominates + a1 = [1.0, 2.0, 3.0] + a2 = [1.0, 2.1, 2.9] + a3 = [0.9, 1.9, 2.9] + @test !NativeSARSOP.strictly_dominates(a1, a2, 1e-10) + @test NativeSARSOP.strictly_dominates(a1, a1, 1e-10) + @test NativeSARSOP.strictly_dominates(a1, a3, 1e-10) + + # NativeSARSOP.intersection_distance + b = SparseVector([1.0, 0.0]) + a1 = [1.0, 0.0] + a2 = [0.0, 1.0] + @test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), + sqrt(0.5^2 + 0.5^2), atol=1e-10) + + b = SparseVector([0.5, 0.5]) + @test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), 0.0, atol=1e-10) +end diff --git a/test/runtests.jl b/test/runtests.jl index 806fca1..044dfde 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,8 @@ include("sample.jl") include("updater.jl") +include("prune.jl") + include("tree.jl") @testset "Tiger POMDP" begin