From 71107461acd271e54f5779e8f88aeb0e65512444 Mon Sep 17 00:00:00 2001 From: Max Haas-Heger Date: Mon, 17 Apr 2023 22:55:17 +0000 Subject: [PATCH] Add 'within_bounding_box()' method --- src/kdtree.rs | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/util.rs | 22 ++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/kdtree.rs b/src/kdtree.rs index b2823a3..446f0cc 100644 --- a/src/kdtree.rs +++ b/src/kdtree.rs @@ -113,6 +113,38 @@ impl + std::cmp::Pa Ok(evaluated.into_sorted_vec().into_iter().map(Into::into).collect()) } + pub fn within_bounding_box(&self, min_bounds: &[A], max_bounds: &[A]) -> Result, ErrorKind> { + self.check_point(min_bounds)?; + self.check_point(max_bounds)?; + if self.size == 0 { + return Ok(vec![]); + } + let mut pending = vec![]; + let mut evaluated = vec![]; + pending.push(self); + while !pending.is_empty() { + let curr = pending.pop().unwrap(); + if curr.is_leaf() { + let points = curr.points.as_ref().unwrap().iter(); + let bucket = curr.bucket.as_ref().unwrap().iter(); + for (p, b) in points.zip(bucket) { + if util::within_bounding_box(p.as_ref(), min_bounds, max_bounds) { + evaluated.push(b); + } + } + } else { + if curr.belongs_in_left(min_bounds) { + pending.push(curr.left.as_ref().unwrap()); + } + if !curr.belongs_in_left(max_bounds) { + pending.push(curr.right.as_ref().unwrap()); + } + } + } + + Ok(evaluated) + } + fn nearest_step<'b, F>( &self, point: &[A], @@ -557,4 +589,27 @@ mod tests { tree.add([max], ()).unwrap(); } } + + #[test] + fn test_within_bounding_box() { + let mut tree = KdTree::with_capacity(2, 2); + for i in 0..10 { + for j in 0..10 { + let id = i.to_string() + &j.to_string(); + tree.add([i as f64, j as f64], id).unwrap(); + } + } + + let within: Vec = tree.within_bounding_box(&[4.0, 4.0], &[6.0, 6.0]).unwrap().iter().cloned().cloned().collect(); + assert_eq!(within.len(), 9); + assert!(within.contains(&String::from("44"))); + assert!(within.contains(&String::from("45"))); + assert!(within.contains(&String::from("46"))); + assert!(within.contains(&String::from("54"))); + assert!(within.contains(&String::from("55"))); + assert!(within.contains(&String::from("56"))); + assert!(within.contains(&String::from("64"))); + assert!(within.contains(&String::from("65"))); + assert!(within.contains(&String::from("66"))); + } } diff --git a/src/util.rs b/src/util.rs index c849996..1616475 100644 --- a/src/util.rs +++ b/src/util.rs @@ -18,10 +18,21 @@ where distance(p1, &p2[..]) } +pub fn within_bounding_box(p: &[T], min_bounds: &[T], max_bounds: &[T]) -> bool +where T: Float, +{ + for ((l, h), v) in min_bounds.iter().zip(max_bounds.iter()).zip(p) { + if v < l || v > h { + return false; + } + } + true +} + #[cfg(test)] mod tests { use super::distance_to_space; - use crate::distance::squared_euclidean; + use crate::{distance::squared_euclidean, util::within_bounding_box}; use std::f64::{INFINITY, NEG_INFINITY}; #[test] @@ -63,4 +74,13 @@ mod tests { ); assert_eq!(dis, 4.0); } + + #[test] + fn test_within_bounding_box() { + assert!(within_bounding_box(&[1.0, 1.0], &[0.0, 0.0], &[2.0, 2.0])); + assert!(within_bounding_box(&[1.0, 1.0], &[1.0, 1.0], &[2.0, 2.0])); + assert!(within_bounding_box(&[1.0, 1.0], &[0.0, 0.0], &[1.0, 1.0])); + assert!(!within_bounding_box(&[2.0, 2.0], &[0.0, 0.0], &[1.0, 1.0])); + assert!(!within_bounding_box(&[0.0, 0.0], &[1.0, 1.0], &[2.0, 2.0])); + } }