From 3d697fc7b7891a9d35dde56719facb7c38fbdd93 Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 13:41:30 -0600 Subject: [PATCH 1/9] Added pruning of alpha vectors for strictly dominated vectors at each call to `prune` --- src/prune.jl | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/src/prune.jl b/src/prune.jl index 7f19834..e38339f 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -5,6 +5,7 @@ end function prune!(solver::SARSOPSolver, tree::SARSOPTree) prune!(tree) + prune_strictly_dominated!(tree::SARSOPTree) if should_prune_alphas(tree) prune_alpha!(tree, solver.delta) end @@ -105,3 +106,39 @@ function prune_alpha!(tree::SARSOPTree, δ) deleteat!(Γ, pruned) tree.prune_data.last_Γ_size = length(Γ) end + +function strictly_dominates(α1, α2, eps) + for ii in 1:length(α1) + if α1[ii] < α2[ii] - eps + return false + end + end + return true +end + +function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) + Γ = tree.Γ + Γ_new_idxs = [] + + for (α_try_idx, α_try) in enumerate(Γ) + marked_for_deletion = falses(length(Γ_new_idxs)) + dominated = false + for (jj, α_in_idx) in enumerate(Γ_new_idxs) + α_in = Γ[α_in_idx] + if strictly_dominates(α_try, α_in, eps) + marked_for_deletion[jj] = true + elseif strictly_dominates(α_in, α_try, eps) + dominated = true + break + end + end + if !dominated + Γ_new_idxs = Γ_new_idxs[.!marked_for_deletion] + push!(Γ_new_idxs, α_try_idx) + end + end + + Γ_idxs_to_delete = setdiff(1:length(Γ), Γ_new_idxs) + deleteat!(Γ, Γ_idxs_to_delete) + tree.prune_data.last_Γ_size = length(Γ) +end From 17df2134a7fbc03790cd5df3ae1047d11e60f64f Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 15:16:22 -0600 Subject: [PATCH 2/9] Chnage to not update Gamma size --- src/prune.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/prune.jl b/src/prune.jl index e38339f..2a96a63 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -140,5 +140,4 @@ function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) Γ_idxs_to_delete = setdiff(1:length(Γ), Γ_new_idxs) deleteat!(Γ, Γ_idxs_to_delete) - tree.prune_data.last_Γ_size = length(Γ) end From 10c99ecb00ea3ccdd6a301b65c2a7cc6d72a7ee3 Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 16:55:59 -0600 Subject: [PATCH 3/9] Update to `prune_alpha!` --- src/prune.jl | 93 +++++++++++++++++++++++++--------------------------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/src/prune.jl b/src/prune.jl index 2a96a63..6427a7f 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -49,61 +49,58 @@ function prune!(tree::SARSOPTree) end end -function belief_space_domination(α1, α2, B, δ) - a1_dominant = true - a2_dominant = true - for b ∈ B - !a1_dominant && !a2_dominant && return (false, false) - δV = intersection_distance(α1, α2, b) - δV ≤ δ && (a1_dominant = false) - δV ≥ -δ && (a2_dominant = false) - end - return a1_dominant, a2_dominant -end - -@inline function intersection_distance(α1, α2, b) - s = 0.0 - dot_sum = 0.0 - I,B = b.nzind, b.nzval - @inbounds for _i ∈ eachindex(I) - i = I[_i] - diff = α1[i] - α2[i] - s += abs2(diff) - dot_sum += diff*B[_i] - end - return dot_sum / sqrt(s) +@inline function intersection_distance_new(α1, α2, b) + diff = α1 - α2 + dot_b = dot(diff, b) + dot_diff = dot(diff, diff) + d = dot_b / sqrt(dot_diff) + return d end - -function prune_alpha!(tree::SARSOPTree, δ) +function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) Γ = tree.Γ B_valid = tree.b[map(!,tree.b_pruned)] - pruned = falses(length(Γ)) + + n_Γ = length(Γ) + n_B = length(B_valid) + + dominant_indices_bools = falses(n_Γ) + dominant_vector_indices = Vector{Int}(undef, n_B) + + # First, identify dominant alpha vectors + for b_idx in 1:n_B + max_value = -Inf + max_index = -1 + for i in 1:n_Γ + value = dot(Γ[i], B_valid[b_idx]) + if value > max_value + max_value = value + max_index = i + end + end + dominant_indices_bools[max_index] = true + dominant_vector_indices[b_idx] = max_index + end - # checking if α_i dominates α_j - for (i,α_i) ∈ enumerate(Γ) - pruned[i] && continue - for (j,α_j) ∈ enumerate(Γ) - (j ≤ i || pruned[j]) && continue - a1_dominant,a2_dominant = belief_space_domination(α_i, α_j, B_valid, δ) - #= - NOTE: α1 and α2 shouldn't technically be able to mutually dominate - i.e. a1_dominant and a2_dominant should never both be true. - But this does happen when α1 == α2 because intersection_distance returns NaN. - Current impl prunes α2 without doing an equality check, removing - the duplicate α. Could do equality check to short-circuit - belief_space_domination which would speed things up if we have - a lot of duplicates, but the equality check can slow things down - if α's are sufficiently diverse. - =# - if a1_dominant - pruned[j] = true - elseif a2_dominant - pruned[i] = true - break + non_dominant_indices = findall(!, dominant_indices_bools) + n_non_dom = length(non_dominant_indices) + keep_non_dom = falses(n_non_dom) + + for b_idx in 1:n_B + dom_vec_idx = dominant_vector_indices[b_idx] + for j in 1:n_non_dom + non_dom_idx = non_dominant_indices[j] + if keep_non_dom[j] + continue + end + intx_dist = intersection_distance_new(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx]) + if !isnan(intx_dist) && (intx_dist + eps ≤ δ) + keep_non_dom[j] = true end end end - deleteat!(Γ, pruned) + + non_dominant_indices = non_dominant_indices[.!keep_non_dom] + deleteat!(Γ, non_dominant_indices) tree.prune_data.last_Γ_size = length(Γ) end From ff3e8ccc961063c122d034ecad21a2b9817375d8 Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 16:57:47 -0600 Subject: [PATCH 4/9] function name update --- src/prune.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/prune.jl b/src/prune.jl index 6427a7f..01c1b8e 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -49,7 +49,7 @@ function prune!(tree::SARSOPTree) end end -@inline function intersection_distance_new(α1, α2, b) +@inline function intersection_distance(α1, α2, b) diff = α1 - α2 dot_b = dot(diff, b) dot_diff = dot(diff, diff) @@ -92,7 +92,7 @@ function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) if keep_non_dom[j] continue end - intx_dist = intersection_distance_new(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx]) + intx_dist = intersection_distance(Γ[dom_vec_idx], Γ[non_dom_idx], B_valid[b_idx]) if !isnan(intx_dist) && (intx_dist + eps ≤ δ) keep_non_dom[j] = true end From c94bdf7bf87715951ac32779f431701b177e75ca Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 17:13:56 -0600 Subject: [PATCH 5/9] formatting update --- src/prune.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/prune.jl b/src/prune.jl index 01c1b8e..667fa4b 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -56,16 +56,17 @@ end d = dot_b / sqrt(dot_diff) return d end + function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) Γ = tree.Γ - B_valid = tree.b[map(!,tree.b_pruned)] - + B_valid = tree.b[map(!, tree.b_pruned)] + n_Γ = length(Γ) n_B = length(B_valid) - + dominant_indices_bools = falses(n_Γ) dominant_vector_indices = Vector{Int}(undef, n_B) - + # First, identify dominant alpha vectors for b_idx in 1:n_B max_value = -Inf @@ -84,7 +85,7 @@ function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) non_dominant_indices = findall(!, dominant_indices_bools) n_non_dom = length(non_dominant_indices) keep_non_dom = falses(n_non_dom) - + for b_idx in 1:n_B dom_vec_idx = dominant_vector_indices[b_idx] for j in 1:n_non_dom @@ -98,7 +99,7 @@ function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) end end end - + non_dominant_indices = non_dominant_indices[.!keep_non_dom] deleteat!(Γ, non_dominant_indices) tree.prune_data.last_Γ_size = length(Γ) @@ -116,7 +117,7 @@ end function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) Γ = tree.Γ Γ_new_idxs = [] - + for (α_try_idx, α_try) in enumerate(Γ) marked_for_deletion = falses(length(Γ_new_idxs)) dominated = false @@ -134,7 +135,7 @@ function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) push!(Γ_new_idxs, α_try_idx) end end - + Γ_idxs_to_delete = setdiff(1:length(Γ), Γ_new_idxs) deleteat!(Γ, Γ_idxs_to_delete) end From 03a751ed191e7f2cff3b7fadc7df44abadcba345 Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Fri, 28 Jun 2024 17:29:59 -0600 Subject: [PATCH 6/9] Added tests for prune functions --- test/prune.jl | 19 +++++++++++++++++++ test/runtests.jl | 2 ++ 2 files changed, 21 insertions(+) create mode 100644 test/prune.jl diff --git a/test/prune.jl b/test/prune.jl new file mode 100644 index 0000000..f55fb36 --- /dev/null +++ b/test/prune.jl @@ -0,0 +1,19 @@ +@testset "prune" begin + # NativeSARSOP.strictly_dominates + a1 = [1.0, 2.0, 3.0] + a2 = [1.0, 2.1, 2.9] + a3 = [0.9, 1.9, 2.9] + @test !NativeSARSOP.strictly_dominates(a1, a2, 1e-10) + @test NativeSARSOP.strictly_dominates(a1, a1, 1e-10) + @test NativeSARSOP.strictly_dominates(a1, a3, 1e-10) + + # NativeSARSOP.intersection_distance + b = SparseVector([1.0, 0.0]) + a1 = [1.0, 0.0] + a2 = [0.0, 1.0] + @test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), + sqrt(0.5^2 + 0.5^2), atol=1e-10) + + b = SparseVector([0.5, 0.5]) + @test isapprox(NativeSARSOP.intersection_distance(a1, a2, b), 0.0, atol=1e-10) +end diff --git a/test/runtests.jl b/test/runtests.jl index 806fca1..044dfde 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,8 @@ include("sample.jl") include("updater.jl") +include("prune.jl") + include("tree.jl") @testset "Tiger POMDP" begin From 4794b0e3aeaf9ed542633e6add8aea1c83740f06 Mon Sep 17 00:00:00 2001 From: Dylan Asmar <91484811+dylan-asmar@users.noreply.github.com> Date: Tue, 9 Jul 2024 22:32:30 -0700 Subject: [PATCH 7/9] Update src/prune.jl Co-authored-by: Tyler Becker <52610169+WhiffleFish@users.noreply.github.com> --- src/prune.jl | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/prune.jl b/src/prune.jl index 667fa4b..7b1812c 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -49,12 +49,18 @@ function prune!(tree::SARSOPTree) end end -@inline function intersection_distance(α1, α2, b) - diff = α1 - α2 - dot_b = dot(diff, b) - dot_diff = dot(diff, diff) - d = dot_b / sqrt(dot_diff) - return d +function intersection_distance(α1, α2, b) + dot_sum = 0.0 + I,B = b.nzind, b.nzval + for _i ∈ eachindex(I) + i = I[_i] + dot_sum += (α1[i] - α2[i])*B[_i] + end + s = 0.0 + for i ∈ eachindex(α1, α2) + s += (α1[i] - α2[i])^2 + end + return dot_sum / sqrt(s) end function prune_alpha!(tree::SARSOPTree, δ, eps=0.0) From edbe96016fce6c60a2f22b0bbe9c38a2da5a5b91 Mon Sep 17 00:00:00 2001 From: Dylan Asmar <91484811+dylan-asmar@users.noreply.github.com> Date: Tue, 9 Jul 2024 22:32:55 -0700 Subject: [PATCH 8/9] Update src/prune.jl Co-authored-by: Tyler Becker <52610169+WhiffleFish@users.noreply.github.com> --- src/prune.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prune.jl b/src/prune.jl index 7b1812c..58b664b 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -122,7 +122,7 @@ end function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) Γ = tree.Γ - Γ_new_idxs = [] + Γ_new_idxs = Int[] for (α_try_idx, α_try) in enumerate(Γ) marked_for_deletion = falses(length(Γ_new_idxs)) From 67133d0e7cfeadb14d53ba5c0f8966f0d6b14425 Mon Sep 17 00:00:00 2001 From: Dylan Asmar Date: Mon, 22 Jul 2024 14:31:49 -0700 Subject: [PATCH 9/9] Updated `prune_strictly_dominated!` to reduce allocations --- src/prune.jl | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/prune.jl b/src/prune.jl index 58b664b..395da7d 100644 --- a/src/prune.jl +++ b/src/prune.jl @@ -51,10 +51,10 @@ end function intersection_distance(α1, α2, b) dot_sum = 0.0 - I,B = b.nzind, b.nzval + I, B = b.nzind, b.nzval for _i ∈ eachindex(I) i = I[_i] - dot_sum += (α1[i] - α2[i])*B[_i] + dot_sum += (α1[i] - α2[i]) * B[_i] end s = 0.0 for i ∈ eachindex(α1, α2) @@ -122,26 +122,47 @@ end function prune_strictly_dominated!(tree::SARSOPTree, eps=1e-10) Γ = tree.Γ - Γ_new_idxs = Int[] + Γ_new_idxs = Vector{Int}(undef, length(Γ)) + keep = trues(length(Γ)) + idx_count = 0 for (α_try_idx, α_try) in enumerate(Γ) - marked_for_deletion = falses(length(Γ_new_idxs)) dominated = false - for (jj, α_in_idx) in enumerate(Γ_new_idxs) + for jj in 1:idx_count + α_in_idx = Γ_new_idxs[jj] α_in = Γ[α_in_idx] if strictly_dominates(α_try, α_in, eps) - marked_for_deletion[jj] = true + keep[jj] = false elseif strictly_dominates(α_in, α_try, eps) dominated = true break end end if !dominated - Γ_new_idxs = Γ_new_idxs[.!marked_for_deletion] - push!(Γ_new_idxs, α_try_idx) + new_idx_count = 0 + for jj in 1:idx_count + if keep[jj] + new_idx_count += 1 + Γ_new_idxs[new_idx_count] = Γ_new_idxs[jj] + end + end + new_idx_count += 1 + Γ_new_idxs[new_idx_count] = α_try_idx + idx_count = new_idx_count + fill!(keep, true) end end - Γ_idxs_to_delete = setdiff(1:length(Γ), Γ_new_idxs) - deleteat!(Γ, Γ_idxs_to_delete) + resize!(Γ_new_idxs, idx_count) + + to_delete = trues(length(Γ)) + for idx in Γ_new_idxs + to_delete[idx] = false + end + + for ii in length(Γ):-1:1 + if to_delete[ii] + deleteat!(Γ, ii) + end + end end