diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6d1767f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +**/.idea/ diff --git a/README.md b/README.md index dbd0821..5c696fe 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,3 @@ # rtreego -对https://github.com/dhconnelly/rtreego 精度进行了更改 +对https://github.com/dhconnelly/rtreego 精度进行了更改,改为float32 + diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..3e037bd --- /dev/null +++ b/filter.go @@ -0,0 +1,32 @@ +package rtreego + +// Filter is an interface for filtering leaves during search. The parameters +// should be treated as read-only. If refuse is true, the current entry will +// not be added to the result set. If abort is true, the search is aborted and +// the current result set will be returned. +type Filter func(results []Spatial, object Spatial) (refuse, abort bool) + +// ApplyFilters applies the given filters and returns whether the entry is +// refused and/or the search should be aborted. If a filter refuses an entry, +// the following filters are not applied for the entry. If a filter aborts, the +// search terminates without further applying any filter. +func applyFilters(results []Spatial, object Spatial, filters []Filter) (bool, bool) { + for _, filter := range filters { + refuse, abort := filter(results, object) + if refuse || abort { + return refuse, abort + } + } + return false, false +} + +// LimitFilter checks if the results have reached the limit size and aborts if so. +func LimitFilter(limit int) Filter { + return func(results []Spatial, object Spatial) (refuse, abort bool) { + if len(results) >= limit { + return true, true + } + + return false, false + } +} diff --git a/geom.go b/geom.go new file mode 100644 index 0000000..a0de660 --- /dev/null +++ b/geom.go @@ -0,0 +1,390 @@ +// Copyright 2012 Daniel Connelly. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rtreego + +import ( + "fmt" + "math" + "strings" +) + +// DimError represents a failure due to mismatched dimensions. +type DimError struct { + Expected int + Actual int +} + +func (err DimError) Error() string { + return "rtreego: dimension mismatch" +} + +// DistError is an improper distance measurement. It implements the error +// and is generated when a distance-related assertion fails. +type DistError float32 + +func (err DistError) Error() string { + return "rtreego: improper distance" +} + +// Point represents a point in n-dimensional Euclidean space. +type Point []float32 + +func (p Point) Copy() Point { + result := make(Point, len(p)) + copy(result, p) + return result +} + +// Dist computes the Euclidean distance between two points p and q. +func (p Point) dist(q Point) float32 { + if len(p) != len(q) { + panic(DimError{len(p), len(q)}) + } + sum := float32(0.0) + for i := range p { + dx := p[i] - q[i] + sum += dx * dx + } + return float32(math.Sqrt(float64(sum))) +} + +// minDist computes the square of the distance from a point to a rectangle. +// If the point is contained in the rectangle then the distance is zero. +// +// Implemented per Definition 2 of "Nearest Neighbor Queries" by +// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. +func (p Point) minDist(r Rect) float32 { + if len(p) != len(r.p) { + panic(DimError{len(p), len(r.p)}) + } + + sum := float32(0.0) + for i, pi := range p { + if pi < r.p[i] { + d := pi - r.p[i] + sum += d * d + } else if pi > r.q[i] { + d := pi - r.q[i] + sum += d * d + } else { + sum += 0 + } + } + return sum +} + +// minMaxDist computes the minimum of the maximum distances from p to points +// on r. If r is the bounding box of some geometric objects, then there is +// at least one object contained in r within minMaxDist(p, r) of p. +// +// Implemented per Definition 4 of "Nearest Neighbor Queries" by +// N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. +func (p Point) minMaxDist(r Rect) float32 { + if len(p) != len(r.p) { + panic(DimError{len(p), len(r.p)}) + } + + // by definition, MinMaxDist(p, r) = + // min{1<=k<=n}(|pk - rmk|^2 + sum{1<=i<=n, i != k}(|pi - rMi|^2)) + // where rmk and rMk are defined as follows: + + rm := func(k int) float32 { + if p[k] <= (r.p[k]+r.q[k])/2 { + return r.p[k] + } + return r.q[k] + } + + rM := func(k int) float32 { + if p[k] >= (r.p[k]+r.q[k])/2 { + return r.p[k] + } + return r.q[k] + } + + // This formula can be computed in linear time by precomputing + // S = sum{1<=i<=n}(|pi - rMi|^2). + + S := float32(0.0) + for i := range p { + d := p[i] - rM(i) + S += d * d + } + + // todo 这里的精度问题? + // Compute MinMaxDist using the precomputed S. + min := float32(math.MaxFloat32) + for k := range p { + d1 := p[k] - rM(k) + d2 := p[k] - rm(k) + d := S - d1*d1 + d2*d2 + if d < min { + min = d + } + } + + return min +} + +// Rect represents a subset of n-dimensional Euclidean space of the form +// [a1, b1] x [a2, b2] x ... x [an, bn], where ai < bi for all 1 <= i <= n. +type Rect struct { + p, q Point // Enforced by NewRect: p[i] <= q[i] for all i. +} + +// 详细输出 rect +func (r Rect) RectDetail() { + if r.p != nil && len(r.p) != 0 { + // 就输出 + for i, v := range r.p { + fmt.Printf("p[%d]: %v ", i, v) + } + } + fmt.Println() + if r.q != nil && len(r.q) != 0 { + // 就输出 + for i, v := range r.q { + fmt.Printf("q[%d]: %v ", i, v) + } + } + fmt.Println() +} + +// PointCoord returns the coordinate of the point of the rectangle at i +func (r Rect) PointCoord(i int) float32 { + return r.p[i] +} + +// LengthsCoord returns the coordinate of the lengths of the rectangle at i +func (r Rect) LengthsCoord(i int) float32 { + return r.q[i] - r.p[i] +} + +// Equal returns true if the two rectangles are equal +func (r Rect) Equal(other Rect) bool { + for i, e := range r.p { + if e != other.p[i] { + return false + } + } + for i, e := range r.q { + if e != other.q[i] { + return false + } + } + return true +} + +func (r Rect) String() string { + s := make([]string, len(r.p)) + for i, a := range r.p { + b := r.q[i] + s[i] = fmt.Sprintf("[%.2f, %.2f]", a, b) + } + return strings.Join(s, "x") +} + +// NewRect constructs and returns a pointer to a Rect given a corner point and +// the lengths of each dimension. The point p should be the most-negative point +// on the rectangle (in every dimension) and every length should be positive. +func NewRect(p Point, lengths []float32) (r Rect, err error) { + r.p = p + if len(p) != len(lengths) { + err = &DimError{len(p), len(lengths)} + return + } + r.q = make([]float32, len(p)) + //fmt.Println(" test") + for i := range p { + if lengths[i] <= 0 { + err = DistError(lengths[i]) + return + } + r.q[i] = p[i] + lengths[i] + //fmt.Printf("q: %v, p:%v, length: %v ", r.q[i], p[i], lengths[i]) + } + //fmt.Println(" test") + //fmt.Println("new rect:") + //r.RectDetail() + return +} + +// NewRectFromPoints constructs and returns a pointer to a Rect given a corner points. +func NewRectFromPoints(minPoint, maxPoint Point) (r Rect, err error) { + if len(minPoint) != len(maxPoint) { + err = &DimError{len(minPoint), len(maxPoint)} + return + } + + // check that min and max point coordinates require swapping + copied := false + for i, p := range minPoint { + if minPoint[i] > maxPoint[i] { + if !copied { + minPoint = minPoint.Copy() + maxPoint = maxPoint.Copy() + copied = true + } + minPoint[i] = maxPoint[i] + maxPoint[i] = p + } + } + + r = Rect{p: minPoint, q: maxPoint} + return +} + +// Size computes the measure of a rectangle (the product of its side lengths). +func (r Rect) Size() float32 { + size := float32(1.0) + for i, a := range r.p { + b := r.q[i] + size *= b - a + } + return size +} + +// margin computes the sum of the edge lengths of a rectangle. +func (r Rect) margin() float32 { + // The number of edges in an n-dimensional rectangle is n * 2^(n-1) + // (http://en.wikipedia.org/wiki/Hypercube_graph). Thus the number + // of edges of length (ai - bi), where the rectangle is determined + // by p = (a1, a2, ..., an) and q = (b1, b2, ..., bn), is 2^(n-1). + // + // The margin of the rectangle, then, is given by the formula + // 2^(n-1) * [(b1 - a1) + (b2 - a2) + ... + (bn - an)]. + dim := len(r.p) + sum := float32(0.0) + for i, a := range r.p { + b := r.q[i] + sum += b - a + } + return float32(math.Pow(2, float64(dim-1))) * sum +} + +// containsPoint tests whether p is located inside or on the boundary of r. +func (r Rect) containsPoint(p Point) bool { + if len(p) != len(r.p) { + panic(DimError{len(r.p), len(p)}) + } + + for i, a := range p { + // p is contained in (or on) r if and only if p <= a <= q for + // every dimension. + if a < r.p[i] || a > r.q[i] { + return false + } + } + + return true +} + +// containsRect tests whether r2 is is located inside r1. +func (r Rect) containsRect(r2 Rect) bool { + if len(r.p) != len(r2.p) { + panic(DimError{len(r.p), len(r2.p)}) + } + //fmt.Printf("rect1: %v \n rect2 %v\n", r, r2) + //fmt.Printf("key:%v\n", r.q[1]) + //fmt.Printf("key2:%v\n", r2.q[1]) + for i, a1 := range r.p { + b1, a2, b2 := r.q[i], r2.p[i], r2.q[i] + // enforced by constructor: a1 <= b1 and a2 <= b2. + // so containment holds if and only if a1 <= a2 <= b2 <= b1 + // for every dimension. + //fmt.Printf("%v %v %v %v i:%d \n", a1, r2.p[i], r.q[i], r2.q[i], i) + + // 精度检验 + if a1-a2 > 0.000001 || b2-b1 > 0.000001 { + return false + } + // 原本的 + //if a1 > a2 || b2 > b1 { + // return false + //} + } + + return true +} + +// intersect computes the intersection of two rectangles. If no intersection +// exists, the intersection is nil. +func intersect(r1, r2 Rect) bool { + dim := len(r1.p) + if len(r2.p) != dim { + panic(DimError{dim, len(r2.p)}) + } + + // There are four cases of overlap: + // + // 1. a1------------b1 + // a2------------b2 + // p--------q + // + // 2. a1------------b1 + // a2------------b2 + // p--------q + // + // 3. a1-----------------b1 + // a2-------b2 + // p--------q + // + // 4. a1-------b1 + // a2-----------------b2 + // p--------q + // + // Thus there are only two cases of non-overlap: + // + // 1. a1------b1 + // a2------b2 + // + // 2. a1------b1 + // a2------b2 + // + // Enforced by constructor: a1 <= b1 and a2 <= b2. So we can just + // check the endpoints. + + for i := range r1.p { + a1, b1, a2, b2 := r1.p[i], r1.q[i], r2.p[i], r2.q[i] + if b2 <= a1 || b1 <= a2 { + return false + } + } + return true +} + +// ToRect constructs a rectangle containing p with side lengths 2*tol. +func (p Point) ToRect(tol float32) Rect { + dim := len(p) + a, b := make([]float32, dim), make([]float32, dim) + for i := range p { + a[i] = p[i] - tol + b[i] = p[i] + tol + } + return Rect{a, b} +} + +// boundingBox constructs the smallest rectangle containing both r1 and r2. +func boundingBox(r1, r2 Rect) (bb Rect) { + dim := len(r1.p) + bb.p = make([]float32, dim) + bb.q = make([]float32, dim) + if len(r2.p) != dim { + panic(DimError{dim, len(r2.p)}) + } + for i := 0; i < dim; i++ { + if r1.p[i] <= r2.p[i] { + bb.p[i] = r1.p[i] + } else { + bb.p[i] = r2.p[i] + } + if r1.q[i] <= r2.q[i] { + bb.q[i] = r2.q[i] + } else { + bb.q[i] = r1.q[i] + } + } + return +} diff --git a/geom_test.go b/geom_test.go new file mode 100644 index 0000000..c84328e --- /dev/null +++ b/geom_test.go @@ -0,0 +1,379 @@ +// Copyright 2012 Daniel Connelly. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rtreego + +import ( + "math" + "testing" +) + +const EPS = 0.000001 + +func TestDist(t *testing.T) { + p := Point{1, 2, 3} + q := Point{4, 5, 6} + dist := float32(math.Sqrt(27)) + if d := p.dist(q); d != dist { + t.Errorf("dist(%v, %v) = %v; expected %v", p, q, d, dist) + } +} + +func TestNewRect(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + q := Point{3.5, 5.5, 4.5} + lengths := []float32{2.5, 8.0, 1.5} + + rect, err := NewRect(p, lengths) + if err != nil { + t.Errorf("Error on NewRect(%v, %v): %v", p, lengths, err) + } + if d := p.dist(rect.p); d > EPS { + t.Errorf("Expected p == rect.p") + } + if d := q.dist(rect.q); d > EPS { + t.Errorf("Expected q == rect.q") + } +} + +func TestNewRectFromPoints(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + q := Point{3.5, 5.5, 4.5} + + rect, err := NewRectFromPoints(p, q) + if err != nil { + t.Errorf("Error on NewRect(%v, %v): %v", p, q, err) + } + if d := p.dist(rect.p); d > EPS { + t.Errorf("Expected p == rect.p") + } + if d := q.dist(rect.q); d > EPS { + t.Errorf("Expected q == rect.q") + } +} + +func TestNewRectFromPointsWithSwapPoints(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + q := Point{3.5, 5.5, 4.5} + + rect, err := NewRectFromPoints(q, p) + if err != nil { + t.Errorf("Error on NewRect(%v, %v): %v", q, p, err) + } + + if d := p.dist(rect.p); d > EPS { + t.Errorf("Expected p == rect.") + } + if d := q.dist(rect.q); d > EPS { + t.Errorf("Expected q == rect.q") + } +} + +func TestNewRectDimMismatch(t *testing.T) { + p := Point{-7.0, 10.0} + lengths := []float32{2.5, 8.0, 1.5} + _, err := NewRect(p, lengths) + if _, ok := err.(*DimError); !ok { + t.Errorf("Expected DimError on NewRect(%v, %v)", p, lengths) + } +} + +func TestNewRectDistError(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + lengths := []float32{2.5, -8.0, 1.5} + _, err := NewRect(p, lengths) + if _, ok := err.(DistError); !ok { + t.Errorf("Expected distError on NewRect(%v, %v)", p, lengths) + } +} + +func TestRectPointCoord(t *testing.T) { + p := Point{1.0, -2.5} + lengths := []float32{2.5, 8.0} + rect, _ := NewRect(p, lengths) + + f := rect.PointCoord(0) + if f != 1.0 { + t.Errorf("Expected %v.PointCoord(0) == 1.0, got %v", rect, f) + } + f = rect.PointCoord(1) + if f != -2.5 { + t.Errorf("Expected %v.PointCoord(1) == -2.5, got %v", rect, f) + } +} + +func TestRectLengthsCoord(t *testing.T) { + p := Point{1.0, -2.5} + lengths := []float32{2.5, 8.0} + rect, _ := NewRect(p, lengths) + + f := rect.LengthsCoord(0) + if f != 2.5 { + t.Errorf("Expected %v.LengthsCoord(0) == 2.5, got %v", rect, f) + } + f = rect.LengthsCoord(1) + if f != 8.0 { + t.Errorf("Expected %v.LengthsCoord(1) == 8.0, got %v", rect, f) + } +} + +func TestRectEqual(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + lengths := []float32{2.5, 8.0, 1.5} + a, _ := NewRect(p, lengths) + b, _ := NewRect(p, lengths) + c, _ := NewRect(Point{0.0, -2.5, 3.0}, lengths) + if !a.Equal(b) { + t.Errorf("Expected %v.Equal(%v) to return true", a, b) + } + if a.Equal(c) { + t.Errorf("Expected %v.Equal(%v) to return false", a, c) + } +} + +func TestRectSize(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + lengths := []float32{2.5, 8.0, 1.5} + rect, _ := NewRect(p, lengths) + size := lengths[0] * lengths[1] * lengths[2] + actual := rect.Size() + if size != actual { + t.Errorf("Expected %v.Size() == %v, got %v", rect, size, actual) + } +} + +func TestRectMargin(t *testing.T) { + p := Point{1.0, -2.5, 3.0} + lengths := []float32{2.5, 8.0, 1.5} + rect, _ := NewRect(p, lengths) + size := float32(4*2.5 + 4*8.0 + 4*1.5) + actual := rect.margin() + if size != actual { + t.Errorf("self : Expected %f.margin() == %f, got %f", rect, size, actual) + t.Errorf("Expected %v.margin() == %v, got %v", rect, size, actual) + } +} + +func TestContainsPoint(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths := []float32{6.2, 1.1, 4.9} + rect, _ := NewRect(p, lengths) + + q := Point{4.5, -1.7, 4.8} + if yes := rect.containsPoint(q); !yes { + t.Errorf("Expected %v contains %v", rect, q) + } +} + +func TestDoesNotContainPoint(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths := []float32{6.2, 1.1, 4.9} + rect, _ := NewRect(p, lengths) + + q := Point{4.5, -1.7, -3.2} + if yes := rect.containsPoint(q); yes { + t.Errorf("Expected %v doesn't contain %v", rect, q) + } +} + +func TestContainsRect(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths1 := []float32{6.2, 1.1, 4.9} + rect1, _ := NewRect(p, lengths1) + + q := Point{4.1, -1.9, 1.0} + lengths2 := []float32{3.2, 0.6, 3.7} + rect2, _ := NewRect(q, lengths2) + //fmt.Println("rect1") + //rect1.RectDetail() + //fmt.Println("rect2") + //rect2.RectDetail() + //fmt.Println("+==========") + if yes := rect1.containsRect(rect2); !yes { + t.Errorf("Expected %v.containsRect(%v", rect1, rect2) + } +} + +func TestDoesNotContainRectOverlaps(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths1 := []float32{6.2, 1.1, 4.9} + rect1, _ := NewRect(p, lengths1) + + q := Point{4.1, -1.9, 1.0} + lengths2 := []float32{3.2, 1.4, 3.7} + rect2, _ := NewRect(q, lengths2) + + if yes := rect1.containsRect(rect2); yes { + t.Errorf("Expected %v doesn't contain %v", rect1, rect2) + } +} + +func TestDoesNotContainRectDisjoint(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths1 := []float32{6.2, 1.1, 4.9} + rect1, _ := NewRect(p, lengths1) + + q := Point{1.2, -19.6, -4.0} + lengths2 := []float32{2.2, 5.9, 0.5} + rect2, _ := NewRect(q, lengths2) + + if yes := rect1.containsRect(rect2); yes { + t.Errorf("Expected %v doesn't contain %v", rect1, rect2) + } +} + +func TestNoIntersection(t *testing.T) { + p := Point{1, 2, 3} + lengths1 := []float32{1, 1, 1} + rect1, _ := NewRect(p, lengths1) + + q := Point{-1, -2, -3} + lengths2 := []float32{2.5, 3, 6.5} + rect2, _ := NewRect(q, lengths2) + + // rect1 and rect2 fail to overlap in just one dimension (second) + + if intersect(rect1, rect2) { + t.Errorf("Expected intersect(%v, %v) == false", rect1, rect2) + } +} + +func TestNoIntersectionJustTouches(t *testing.T) { + p := Point{1, 2, 3} + lengths1 := []float32{1, 1, 1} + rect1, _ := NewRect(p, lengths1) + + q := Point{-1, -2, -3} + lengths2 := []float32{2.5, 4, 6.5} + rect2, _ := NewRect(q, lengths2) + + // rect1 and rect2 fail to overlap in just one dimension (second) + + if intersect(rect1, rect2) { + t.Errorf("Expected intersect(%v, %v) == false", rect1, rect2) + } +} + +func TestContainmentIntersection(t *testing.T) { + p := Point{1, 2, 3} + lengths1 := []float32{1, 1, 1} + rect1, _ := NewRect(p, lengths1) + + q := Point{1, 2.2, 3.3} + lengths2 := []float32{0.5, 0.5, 0.5} + rect2, _ := NewRect(q, lengths2) + + r := Point{1, 2.2, 3.3} + s := Point{1.5, 2.7, 3.8} + + if !intersect(rect1, rect2) { + t.Errorf("intersect(%v, %v) != %v, %v", rect1, rect2, r, s) + } +} + +func TestOverlapIntersection(t *testing.T) { + p := Point{1, 2, 3} + lengths1 := []float32{1, 2.5, 1} + rect1, _ := NewRect(p, lengths1) + + q := Point{1, 4, -3} + lengths2 := []float32{3, 2, 6.5} + rect2, _ := NewRect(q, lengths2) + + r := Point{1, 4, 3} + s := Point{2, 4.5, 3.5} + + if !intersect(rect1, rect2) { + t.Errorf("intersect(%v, %v) != %v, %v", rect1, rect2, r, s) + } +} + +func TestToRect(t *testing.T) { + x := Point{3.7, -2.4, 0.0} + tol := float32(0.05) + rect := x.ToRect(tol) + + p := Point{3.65, -2.45, -0.05} + q := Point{3.75, -2.35, 0.05} + d1 := p.dist(rect.p) + d2 := q.dist(rect.q) + //fmt.Printf("d1: %f , d2: %f\n", d1, d2) + if d1 > EPS || d2 > EPS { + t.Errorf("Expected %v.ToRect(%v) == %v, %v, got %v", x, tol, p, q, rect) + } +} + +func TestBoundingBox(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths1 := []float32{1, 15, 3} + rect1, _ := NewRect(p, lengths1) + + q := Point{-6.5, 4.7, 2.5} + lengths2 := []float32{4, 5, 6} + rect2, _ := NewRect(q, lengths2) + + r := Point{-6.5, -2.4, 0.0} + s := Point{4.7, 12.6, 8.5} + + bb := boundingBox(rect1, rect2) + d1 := r.dist(bb.p) + d2 := s.dist(bb.q) + if d1 > EPS || d2 > EPS { + t.Errorf("boundingBox(%v, %v) != %v, %v, got %v", rect1, rect2, r, s, bb) + } +} + +func TestBoundingBoxContains(t *testing.T) { + p := Point{3.7, -2.4, 0.0} + lengths1 := []float32{1, 15, 3} + rect1, _ := NewRect(p, lengths1) + + q := Point{4.0, 0.0, 1.5} + lengths2 := []float32{0.56, 6.222222, 0.946} + rect2, _ := NewRect(q, lengths2) + + bb := boundingBox(rect1, rect2) + d1 := rect1.p.dist(bb.p) + d2 := rect1.q.dist(bb.q) + if d1 > EPS || d2 > EPS { + t.Errorf("boundingBox(%v, %v) != %v, got %v", rect1, rect2, rect1, bb) + } +} + +func TestMinDistZero(t *testing.T) { + p := Point{1, 2, 3} + r := p.ToRect(1) + if d := p.minDist(r); d > EPS { + t.Errorf("Expected %v.minDist(%v) == 0, got %v", p, r, d) + } +} + +func TestMinDistPositive(t *testing.T) { + p := Point{1, 2, 3} + r := Rect{Point{-1, -4, 7}, Point{2, -2, 9}} + expected := float32((-2-2)*(-2-2) + (7-3)*(7-3)) + if d := p.minDist(r); math.Abs(float64(d-expected)) > EPS { + t.Errorf("Expected %v.minDist(%v) == %v, got %v", p, r, expected, d) + } +} + +func TestMinMaxdist(t *testing.T) { + p := Point{-3, -2, -1} + r := Rect{Point{0, 0, 0}, Point{1, 2, 3}} + + // furthest points from p on the faces closest to p in each dimension + q1 := Point{0, 2, 3} + q2 := Point{1, 0, 3} + q3 := Point{1, 2, 0} + + // find the closest distance from p to one of these furthest points + d1 := p.dist(q1) + d2 := p.dist(q2) + d3 := p.dist(q3) + expected := math.Min(float64(d1*d1), math.Min(float64(d2*d2), float64(d3*d3))) + + if d := p.minMaxDist(r); math.Abs(float64(d)-expected) > EPS { + t.Errorf("Expected %v.minMaxDist(%v) == %v, got %v", p, r, expected, d) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6f70a66 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/dbjtech/rtreego + +go 1.20 diff --git a/rtree.go b/rtree.go new file mode 100644 index 0000000..ee1e387 --- /dev/null +++ b/rtree.go @@ -0,0 +1,879 @@ +// Copyright 2012 Daniel Connelly. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package rtreego is a library for efficiently storing and querying spatial data. +package rtreego + +import ( + "fmt" + "math" + "sort" +) + +// Comparator compares two spatials and returns whether they are equal. +type Comparator func(obj1, obj2 Spatial) (equal bool) + +func defaultComparator(obj1, obj2 Spatial) bool { + return obj1 == obj2 +} + +// Rtree represents an R-tree, a balanced search tree for storing and querying +// spatial objects. Dim specifies the number of spatial dimensions and +// MinChildren/MaxChildren specify the minimum/maximum branching factors. +type Rtree struct { + Dim int + MinChildren int + MaxChildren int + root *node + size int + height int + + // deleted is a temporary buffer to avoid memory allocations in Delete. + // It is just an optimization and not part of the data structure. + deleted []*node +} + +// NewTree returns an Rtree. If the number of objects given on initialization +// is larger than max, the Rtree will be initialized using the Overlap +// Minimizing Top-down bulk-loading algorithm. +func NewTree(dim, min, max int, objs ...Spatial) *Rtree { + rt := &Rtree{ + Dim: dim, + MinChildren: min, + MaxChildren: max, + height: 1, + root: &node{ + entries: []entry{}, + leaf: true, + level: 1, + }, + } + + if len(objs) <= rt.MaxChildren { + for _, obj := range objs { + rt.Insert(obj) + } + } else { + rt.bulkLoad(objs) + } + + return rt +} + +// Size returns the number of objects currently stored in tree. +func (tree *Rtree) Size() int { + return tree.size +} + +func (tree *Rtree) String() string { + return "foo" +} + +// Depth returns the maximum depth of tree. +func (tree *Rtree) Depth() int { + return tree.height +} + +type dimSorter struct { + dim int + objs []entry +} + +func (s *dimSorter) Len() int { + return len(s.objs) +} + +func (s *dimSorter) Swap(i, j int) { + s.objs[i], s.objs[j] = s.objs[j], s.objs[i] +} + +func (s *dimSorter) Less(i, j int) bool { + return s.objs[i].bb.p[s.dim] < s.objs[j].bb.p[s.dim] +} + +// walkPartitions splits objs into slices of maximum k elements and +// iterates over these partitions. +func walkPartitions(k int, objs []entry, iter func(parts []entry)) { + n := (len(objs) + k - 1) / k // ceil(len(objs) / k) + + for i := 1; i < n; i++ { + iter(objs[(i-1)*k : i*k]) + } + iter(objs[(n-1)*k:]) +} + +func sortByDim(dim int, objs []entry) { + sort.Sort(&dimSorter{dim, objs}) +} + +// bulkLoad bulk loads the Rtree using OMT algorithm. bulkLoad contains special +// handling for the root node. +func (tree *Rtree) bulkLoad(objs []Spatial) { + n := len(objs) + + // create entries for all the objects + entries := make([]entry, n) + for i := range objs { + entries[i] = entry{ + bb: objs[i].Bounds(), + obj: objs[i], + } + } + + // following equations are defined in the paper describing OMT + var ( + N = float32(n) + M = float32(tree.MaxChildren) + ) + // Eq1: height of the tree + // use log2 instead of log due to rounding errors with log, + // eg, math.Log(9) / math.Log(3) > 2 + h := math.Ceil(math.Log2(float64(N)) / math.Log2(float64(M))) + + // Eq2: size of subtrees at the root + nsub := math.Pow(float64(M), h-1) + + // Inner Eq3: number of subtrees at the root + s := math.Ceil(float64(N) / nsub) + + // Eq3: number of slices + S := math.Floor(math.Sqrt(s)) + + // sort all entries by first dimension + sortByDim(0, entries) + + tree.height = int(h) + tree.size = n + tree.root = tree.omt(int(h), int(S), entries, int(s)) +} + +// omt is the recursive part of the Overlap Minimizing Top-loading bulk- +// load approach. Returns the root node of a subtree. +func (tree *Rtree) omt(level, nSlices int, objs []entry, m int) *node { + // if number of objects is less than or equal than max children per leaf, + // we need to create a leaf node + if len(objs) <= m { + // as long as the recursion is not at the leaf, call it again + if level > 1 { + child := tree.omt(level-1, nSlices, objs, m) + n := &node{ + level: level, + entries: []entry{{ + bb: child.computeBoundingBox(), + child: child, + }}, + } + child.parent = n + return n + } + entries := make([]entry, len(objs)) + copy(entries, objs) + return &node{ + leaf: true, + entries: entries, + level: level, + } + } + + n := &node{ + level: level, + entries: make([]entry, 0, m), + } + + // maximum node size given at most M nodes at this level + k := (len(objs) + m - 1) / m // = ceil(N / M) + + // In the root level, split objs in nSlices. In all other levels, + // we use a single slice. + vertSize := len(objs) + if nSlices > 1 { + vertSize = nSlices * k + } + + // create sub trees + walkPartitions(vertSize, objs, func(vert []entry) { + // sort vertical slice by a different dimension on every level + sortByDim((tree.height-level+1)%tree.Dim, vert) + + // split slice into groups of size k + walkPartitions(k, vert, func(part []entry) { + child := tree.omt(level-1, 1, part, tree.MaxChildren) + child.parent = n + + n.entries = append(n.entries, entry{ + bb: child.computeBoundingBox(), + child: child, + }) + }) + }) + return n +} + +// node represents a tree node of an Rtree. +type node struct { + parent *node + entries []entry + level int // node depth in the Rtree + leaf bool +} + +func (n *node) String() string { + return fmt.Sprintf("node{leaf: %v, entries: %v}", n.leaf, n.entries) +} + +// entry represents a spatial index record stored in a tree node. +type entry struct { + bb Rect // bounding-box of all children of this entry + child *node + obj Spatial +} + +func (e entry) String() string { + if e.child != nil { + return fmt.Sprintf("entry{bb: %v, child: %v}", e.bb, e.child) + } + return fmt.Sprintf("entry{bb: %v, obj: %v}", e.bb, e.obj) +} + +// Spatial is an interface for objects that can be stored in an Rtree and queried. +type Spatial interface { + Bounds() Rect +} + +// Insertion + +// Insert inserts a spatial object into the tree. If insertion +// causes a leaf node to overflow, the tree is rebalanced automatically. +// +// Implemented per Section 3.2 of "R-trees: A Dynamic Index Structure for +// Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. +func (tree *Rtree) Insert(obj Spatial) { + e := entry{obj.Bounds(), nil, obj} + tree.insert(e, 1) + tree.size++ +} + +// insert adds the specified entry to the tree at the specified level. +func (tree *Rtree) insert(e entry, level int) { + leaf := tree.chooseNode(tree.root, e, level) + leaf.entries = append(leaf.entries, e) + + // update parent pointer if necessary + if e.child != nil { + e.child.parent = leaf + } + + // split leaf if overflows + var split *node + if len(leaf.entries) > tree.MaxChildren { + leaf, split = leaf.split(tree.MinChildren) + } + root, splitRoot := tree.adjustTree(leaf, split) + if splitRoot != nil { + oldRoot := root + tree.height++ + tree.root = &node{ + parent: nil, + level: tree.height, + entries: []entry{ + {bb: oldRoot.computeBoundingBox(), child: oldRoot}, + {bb: splitRoot.computeBoundingBox(), child: splitRoot}, + }, + } + oldRoot.parent = tree.root + splitRoot.parent = tree.root + } +} + +// chooseNode finds the node at the specified level to which e should be added. +func (tree *Rtree) chooseNode(n *node, e entry, level int) *node { + if n.leaf || n.level == level { + return n + } + + // find the entry whose bb needs least enlargement to include obj + diff := float32(math.MaxFloat32) + var chosen entry + for _, en := range n.entries { + bb := boundingBox(en.bb, e.bb) + d := bb.Size() - en.bb.Size() + if d < diff || (d == diff && en.bb.Size() < chosen.bb.Size()) { + diff = d + chosen = en + } + } + + return tree.chooseNode(chosen.child, e, level) +} + +// adjustTree splits overflowing nodes and propagates the changes upwards. +func (tree *Rtree) adjustTree(n, nn *node) (*node, *node) { + // Let the caller handle root adjustments. + if n == tree.root { + return n, nn + } + + // Re-size the bounding box of n to account for lower-level changes. + en := n.getEntry() + prevBox := en.bb + en.bb = n.computeBoundingBox() + + // If nn is nil, then we're just propagating changes upwards. + if nn == nil { + // Optimize for the case where nothing is changed + // to avoid computeBoundingBox which is expensive. + if en.bb.Equal(prevBox) { + return tree.root, nil + } + return tree.adjustTree(n.parent, nil) + } + + // Otherwise, these are two nodes resulting from a split. + // n was reused as the "left" node, but we need to add nn to n.parent. + enn := entry{nn.computeBoundingBox(), nn, nil} + n.parent.entries = append(n.parent.entries, enn) + + // If the new entry overflows the parent, split the parent and propagate. + if len(n.parent.entries) > tree.MaxChildren { + return tree.adjustTree(n.parent.split(tree.MinChildren)) + } + + // Otherwise keep propagating changes upwards. + return tree.adjustTree(n.parent, nil) +} + +// getEntry returns a pointer to the entry for the node n from n's parent. +func (n *node) getEntry() *entry { + var e *entry + for i := range n.parent.entries { + if n.parent.entries[i].child == n { + e = &n.parent.entries[i] + break + } + } + return e +} + +// computeBoundingBox finds the MBR of the children of n. +func (n *node) computeBoundingBox() (bb Rect) { + if len(n.entries) == 1 { + bb = n.entries[0].bb + return + } + + bb = boundingBox(n.entries[0].bb, n.entries[1].bb) + for _, e := range n.entries[2:] { + bb = boundingBox(bb, e.bb) + } + return +} + +// split splits a node into two groups while attempting to minimize the +// bounding-box area of the resulting groups. +func (n *node) split(minGroupSize int) (left, right *node) { + // find the initial split + l, r := n.pickSeeds() + leftSeed, rightSeed := n.entries[l], n.entries[r] + + // get the entries to be divided between left and right + remaining := append(n.entries[:l], n.entries[l+1:r]...) + remaining = append(remaining, n.entries[r+1:]...) + + // setup the new split nodes, but re-use n as the left node + left = n + left.entries = []entry{leftSeed} + right = &node{ + parent: n.parent, + leaf: n.leaf, + level: n.level, + entries: []entry{rightSeed}, + } + + // TODO + if rightSeed.child != nil { + rightSeed.child.parent = right + } + if leftSeed.child != nil { + leftSeed.child.parent = left + } + + // distribute all of n's old entries into left and right. + for len(remaining) > 0 { + next := pickNext(left, right, remaining) + e := remaining[next] + + if len(remaining)+len(left.entries) <= minGroupSize { + assign(e, left) + } else if len(remaining)+len(right.entries) <= minGroupSize { + assign(e, right) + } else { + assignGroup(e, left, right) + } + + remaining = append(remaining[:next], remaining[next+1:]...) + } + + return +} + +// getAllBoundingBoxes traverses tree populating slice of bounding boxes of non-leaf nodes. +func (n *node) getAllBoundingBoxes() []Rect { + var rects []Rect + if n.leaf { + return rects + } + for _, e := range n.entries { + if e.child == nil { + return rects + } + rectsInter := append(e.child.getAllBoundingBoxes(), e.bb) + rects = append(rects, rectsInter...) + } + return rects +} + +func assign(e entry, group *node) { + if e.child != nil { + e.child.parent = group + } + group.entries = append(group.entries, e) +} + +// assignGroup chooses one of two groups to which a node should be added. +func assignGroup(e entry, left, right *node) { + leftBB := left.computeBoundingBox() + rightBB := right.computeBoundingBox() + leftEnlarged := boundingBox(leftBB, e.bb) + rightEnlarged := boundingBox(rightBB, e.bb) + + // first, choose the group that needs the least enlargement + leftDiff := leftEnlarged.Size() - leftBB.Size() + rightDiff := rightEnlarged.Size() - rightBB.Size() + if diff := leftDiff - rightDiff; diff < 0 { + assign(e, left) + return + } else if diff > 0 { + assign(e, right) + return + } + + // next, choose the group that has smaller area + if diff := leftBB.Size() - rightBB.Size(); diff < 0 { + assign(e, left) + return + } else if diff > 0 { + assign(e, right) + return + } + + // next, choose the group with fewer entries + if diff := len(left.entries) - len(right.entries); diff <= 0 { + assign(e, left) + return + } + assign(e, right) +} + +// pickSeeds chooses two child entries of n to start a split. +func (n *node) pickSeeds() (int, int) { + left, right := 0, 1 + maxWastedSpace := float32(-1.0) + for i, e1 := range n.entries { + for j, e2 := range n.entries[i+1:] { + d := boundingBox(e1.bb, e2.bb).Size() - e1.bb.Size() - e2.bb.Size() + if d > maxWastedSpace { + maxWastedSpace = d + left, right = i, j+i+1 + } + } + } + return left, right +} + +// pickNext chooses an entry to be added to an entry group. +func pickNext(left, right *node, entries []entry) (next int) { + maxDiff := -1.0 + leftBB := left.computeBoundingBox() + rightBB := right.computeBoundingBox() + for i, e := range entries { + d1 := boundingBox(leftBB, e.bb).Size() - leftBB.Size() + d2 := boundingBox(rightBB, e.bb).Size() - rightBB.Size() + d := math.Abs(float64(d1 - d2)) + if d > maxDiff { + maxDiff = d + next = i + } + } + return +} + +// Deletion + +// Delete removes an object from the tree. If the object is not found, returns +// false, otherwise returns true. Uses the default comparator when checking +// equality. +// +// Implemented per Section 3.3 of "R-trees: A Dynamic Index Structure for +// Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. +func (tree *Rtree) Delete(obj Spatial) bool { + return tree.DeleteWithComparator(obj, defaultComparator) +} + +// DeleteWithComparator removes an object from the tree using a custom +// comparator for evaluating equalness. This is useful when you want to remove +// an object from a tree but don't have a pointer to the original object +// anymore. +func (tree *Rtree) DeleteWithComparator(obj Spatial, cmp Comparator) bool { + n := tree.findLeaf(tree.root, obj, cmp) + if n == nil { + return false + } + + ind := -1 + for i, e := range n.entries { + if cmp(e.obj, obj) { + ind = i + } + } + if ind < 0 { + return false + } + + n.entries = append(n.entries[:ind], n.entries[ind+1:]...) + + tree.condenseTree(n) + tree.size-- + + if !tree.root.leaf && len(tree.root.entries) == 1 { + tree.root = tree.root.entries[0].child + } + + tree.height = tree.root.level + + return true +} + +// findLeaf finds the leaf node containing obj. +func (tree *Rtree) findLeaf(n *node, obj Spatial, cmp Comparator) *node { + if n.leaf { + return n + } + // if not leaf, search all candidate subtrees + for _, e := range n.entries { + if e.bb.containsRect(obj.Bounds()) { + leaf := tree.findLeaf(e.child, obj, cmp) + if leaf == nil { + continue + } + // check if the leaf actually contains the object + for _, leafEntry := range leaf.entries { + if cmp(leafEntry.obj, obj) { + return leaf + } + } + } + } + return nil +} + +// condenseTree deletes underflowing nodes and propagates the changes upwards. +func (tree *Rtree) condenseTree(n *node) { + // reset the deleted buffer + tree.deleted = tree.deleted[:0] + + for n != tree.root { + if len(n.entries) < tree.MinChildren { + // find n and delete it by swapping the last entry into its place + idx := -1 + for i, e := range n.parent.entries { + if e.child == n { + idx = i + break + } + } + if idx == -1 { + panic(fmt.Errorf("Failed to remove entry from parent")) + } + l := len(n.parent.entries) + n.parent.entries[idx] = n.parent.entries[l-1] + n.parent.entries = n.parent.entries[:l-1] + + // only add n to deleted if it still has children + if len(n.entries) > 0 { + tree.deleted = append(tree.deleted, n) + } + } else { + // just a child entry deletion, no underflow + en := n.getEntry() + prevBox := en.bb + en.bb = n.computeBoundingBox() + + if en.bb.Equal(prevBox) { + // Optimize for the case where nothing is changed + // to avoid computeBoundingBox which is expensive. + break + } + } + n = n.parent + } + + for i := len(tree.deleted) - 1; i >= 0; i-- { + n := tree.deleted[i] + // reinsert entry so that it will remain at the same level as before + e := entry{n.computeBoundingBox(), n, nil} + tree.insert(e, n.level+1) + } +} + +// Searching + +// SearchIntersect returns all objects that intersect the specified rectangle. +// Implemented per Section 3.1 of "R-trees: A Dynamic Index Structure for +// Spatial Searching" by A. Guttman, Proceedings of ACM SIGMOD, p. 47-57, 1984. +func (tree *Rtree) SearchIntersect(bb Rect, filters ...Filter) []Spatial { + return tree.searchIntersect([]Spatial{}, tree.root, bb, filters) +} + +// SearchIntersectWithLimit is similar to SearchIntersect, but returns +// immediately when the first k results are found. A negative k behaves exactly +// like SearchIntersect and returns all the results. +// +// Kept for backwards compatibility, please use SearchIntersect with a +// LimitFilter. +func (tree *Rtree) SearchIntersectWithLimit(k int, bb Rect) []Spatial { + // backwards compatibility, previous implementation didn't limit results if + // k was negative. + if k < 0 { + return tree.SearchIntersect(bb) + } + return tree.SearchIntersect(bb, LimitFilter(k)) +} + +func (tree *Rtree) searchIntersect(results []Spatial, n *node, bb Rect, filters []Filter) []Spatial { + for _, e := range n.entries { + if !intersect(e.bb, bb) { + continue + } + + if !n.leaf { + results = tree.searchIntersect(results, e.child, bb, filters) + continue + } + + refuse, abort := applyFilters(results, e.obj, filters) + if !refuse { + results = append(results, e.obj) + } + + if abort { + break + } + } + return results +} + +// NearestNeighbor returns the closest object to the specified point. +// Implemented per "Nearest Neighbor Queries" by Roussopoulos et al +func (tree *Rtree) NearestNeighbor(p Point) Spatial { + obj, _ := tree.nearestNeighbor(p, tree.root, math.MaxFloat32, nil) + return obj +} + +// GetAllBoundingBoxes returning slice of bounding boxes by traversing tree. Slice +// includes bounding boxes from all non-leaf nodes. +func (tree *Rtree) GetAllBoundingBoxes() []Rect { + var rects []Rect + if tree.root != nil { + rects = tree.root.getAllBoundingBoxes() + } + return rects +} + +// utilities for sorting slices of entries + +type entrySlice struct { + entries []entry + dists []float32 +} + +func (s entrySlice) Len() int { return len(s.entries) } + +func (s entrySlice) Swap(i, j int) { + s.entries[i], s.entries[j] = s.entries[j], s.entries[i] + s.dists[i], s.dists[j] = s.dists[j], s.dists[i] +} + +func (s entrySlice) Less(i, j int) bool { + return s.dists[i] < s.dists[j] +} + +func sortEntries(p Point, entries []entry) ([]entry, []float32) { + sorted := make([]entry, len(entries)) + dists := make([]float32, len(entries)) + return sortPreallocEntries(p, entries, sorted, dists) +} + +func sortPreallocEntries(p Point, entries, sorted []entry, dists []float32) ([]entry, []float32) { + // use preallocated slices + sorted = sorted[:len(entries)] + dists = dists[:len(entries)] + + for i := 0; i < len(entries); i++ { + sorted[i] = entries[i] + dists[i] = p.minDist(entries[i].bb) + } + sort.Sort(entrySlice{sorted, dists}) + return sorted, dists +} + +func pruneEntries(p Point, entries []entry, minDists []float32) []entry { + minMinMaxDist := float32(math.MaxFloat32) + for i := range entries { + minMaxDist := p.minMaxDist(entries[i].bb) + if minMaxDist < minMinMaxDist { + minMinMaxDist = minMaxDist + } + } + // remove all entries with minDist > minMinMaxDist + pruned := []entry{} + for i := range entries { + if minDists[i] <= minMinMaxDist { + pruned = append(pruned, entries[i]) + } + } + return pruned +} + +func pruneEntriesMinDist(d float32, entries []entry, minDists []float32) []entry { + var i int + for ; i < len(entries); i++ { + if minDists[i] > d { + break + } + } + return entries[:i] +} + +func (tree *Rtree) nearestNeighbor(p Point, n *node, d float32, nearest Spatial) (Spatial, float32) { + if n.leaf { + for _, e := range n.entries { + dist := float32(math.Sqrt(float64(p.minDist(e.bb)))) + if dist < d { + d = dist + nearest = e.obj + } + } + } else { + // Search only through entries with minDist <= minMinMaxDist, + // where minDist is the distance between a point and a rectangle, + // and minMaxDist is the smallest value among the maximum distance across all axes. + // + // Entries with minDist > minMinMaxDist are guaranteed to be farther away than some other entry. + // + // For more details, please consult + // N. Roussopoulos, S. Kelley and F. Vincent, ACM SIGMOD, pages 71-79, 1995. + minMinMaxDist := float32(math.MaxFloat32) + for _, e := range n.entries { + minMaxDist := p.minMaxDist(e.bb) + if minMaxDist < minMinMaxDist { + minMinMaxDist = minMaxDist + } + } + + for _, e := range n.entries { + minDist := p.minDist(e.bb) + if minDist > minMinMaxDist { + continue + } + + subNearest, dist := tree.nearestNeighbor(p, e.child, d, nearest) + if dist < d { + d = dist + nearest = subNearest + } + } + } + + return nearest, d +} + +// NearestNeighbors gets the closest Spatials to the Point. +func (tree *Rtree) NearestNeighbors(k int, p Point, filters ...Filter) []Spatial { + // preallocate the buffers for sortings the branches. At each level of the + // tree, we slide the buffer by the number of entries in the node. + maxBufSize := tree.MaxChildren * tree.Depth() + branches := make([]entry, maxBufSize) + branchDists := make([]float32, maxBufSize) + + // allocate the buffers for the results + dists := make([]float32, 0, k) + objs := make([]Spatial, 0, k) + + objs, _, _ = tree.nearestNeighbors(k, p, tree.root, dists, objs, filters, branches, branchDists) + return objs +} + +// todo 准确性检验 sort包没有32位对应的方法,直接改了一个 +func SearchFloat32s(a []float32, x float32) int { + return sort.Search(len(a), func(i int) bool { return a[i] >= x }) +} + +// insert obj into nearest and return the first k elements in increasing order. +func insertNearest(k int, dists []float32, nearest []Spatial, dist float32, obj Spatial, filters []Filter) ([]float32, []Spatial, bool) { + i := SearchFloat32s(dists, dist) + for i < len(nearest) && dist >= dists[i] { + i++ + } + if i >= k { + return dists, nearest, false + } + + if refuse, abort := applyFilters(nearest, obj, filters); refuse || abort { + return dists, nearest, abort + } + + // no resize since cap = k + if len(nearest) < k { + dists = append(dists, 0) + nearest = append(nearest, nil) + } + + left, right := dists[:i], dists[i:len(dists)-1] + copy(dists, left) + copy(dists[i+1:], right) + dists[i] = dist + + leftObjs, rightObjs := nearest[:i], nearest[i:len(nearest)-1] + copy(nearest, leftObjs) + copy(nearest[i+1:], rightObjs) + nearest[i] = obj + + return dists, nearest, false +} + +func (tree *Rtree) nearestNeighbors(k int, p Point, n *node, dists []float32, nearest []Spatial, filters []Filter, b []entry, bd []float32) ([]Spatial, []float32, bool) { + var abort bool + if n.leaf { + for _, e := range n.entries { + dist := p.minDist(e.bb) + dists, nearest, abort = insertNearest(k, dists, nearest, dist, e.obj, filters) + if abort { + break + } + } + } else { + branches, branchDists := sortPreallocEntries(p, n.entries, b, bd) + // only prune if buffer has k elements + if l := len(dists); l >= k { + branches = pruneEntriesMinDist(dists[l-1], branches, branchDists) + } + for _, e := range branches { + nearest, dists, abort = tree.nearestNeighbors(k, p, e.child, dists, nearest, filters, b[len(n.entries):], bd[len(n.entries):]) + if abort { + break + } + } + } + return nearest, dists, abort +} diff --git a/rtree_test.go b/rtree_test.go new file mode 100644 index 0000000..1a9017e --- /dev/null +++ b/rtree_test.go @@ -0,0 +1,1392 @@ +package rtreego + +import ( + "fmt" + "log" + "math/rand" + "sort" + "strconv" + "strings" + "testing" +) + +type testCase struct { + name string + build func() *Rtree +} + +func tests(dim, min, max int, objs ...Spatial) []*testCase { + return []*testCase{ + { + "dynamically built", + func() *Rtree { + rt := NewTree(dim, min, max) + for _, thing := range objs { + rt.Insert(thing) + } + return rt + }, + }, + { + "bulk-loaded", + func() *Rtree { + return NewTree(dim, min, max, objs...) + }, + }, + } +} + +func (r Rect) Bounds() Rect { + return r +} + +func rectEq(a, b Rect) bool { + if len(a.p) != len(b.p) { + return false + } + for i := 0; i < len(a.p); i++ { + if a.p[i] != b.p[i] { + return false + } + } + + if len(a.q) != len(b.q) { + return false + } + for i := 0; i < len(a.q); i++ { + if a.q[i] != b.q[i] { + return false + } + } + + return true +} + +func entryEq(a, b entry) bool { + if !rectEq(a.bb, b.bb) { + return false + } + if a.child != b.child { + return false + } + if a.obj != b.obj { + return false + } + return true +} + +func mustRect(p Point, widths []float32) Rect { + r, err := NewRect(p, widths) + if err != nil { + panic(err) + } + return r +} + +func printNode(n *node, level int) { + padding := strings.Repeat("\t", level) + fmt.Printf("%sNode: %p\n", padding, n) + fmt.Printf("%sParent: %p\n", padding, n.parent) + fmt.Printf("%sLevel: %d\n", padding, n.level) + fmt.Printf("%sLeaf: %t\n%sEntries:\n", padding, n.leaf, padding) + for _, e := range n.entries { + printEntry(e, level+1) + } +} + +func printEntry(e entry, level int) { + padding := strings.Repeat("\t", level) + fmt.Printf("%sBB: %v\n", padding, e.bb) + if e.child != nil { + printNode(e.child, level) + } else { + fmt.Printf("%sObject: %v\n", padding, e.obj) + } + fmt.Println() +} + +func items(n *node) chan Spatial { + ch := make(chan Spatial) + go func() { + for _, e := range n.entries { + if n.leaf { + ch <- e.obj + } else { + for obj := range items(e.child) { + ch <- obj + } + } + } + close(ch) + }() + return ch +} + +func validate(n *node, height, max int) error { + if n.level != height { + return fmt.Errorf("level %d != height %d", n.level, height) + } + if len(n.entries) > max { + return fmt.Errorf("node with too many entries at level %d/%d (actual: %d max: %d)", n.level, height, len(n.entries), max) + } + if n.leaf { + if n.level != 1 { + return fmt.Errorf("leaf node at level %d", n.level) + } + return nil + } + for _, e := range n.entries { + if e.child.level != n.level-1 { + return fmt.Errorf("failed to preserve level order") + } + if e.child.parent != n { + return fmt.Errorf("failed to update parent pointer") + } + if err := validate(e.child, height-1, max); err != nil { + return err + } + } + return nil +} + +func verify(t *testing.T, rt *Rtree) { + if rt.height != rt.root.level { + t.Errorf("invalid tree: height %d differs root level %d", rt.height, rt.root.level) + } + + if err := validate(rt.root, rt.height, rt.MaxChildren); err != nil { + printNode(rt.root, 0) + t.Errorf("invalid tree: %v", err) + } +} + +func indexOf(objs []Spatial, obj Spatial) int { + ind := -1 + for i, r := range objs { + if r == obj { + ind = i + break + } + } + return ind +} + +var chooseLeafNodeTests = []struct { + bb0, bb1, bb2 Rect // leaf bounding boxes + exp int // expected chosen leaf + desc string + level int +}{ + { + mustRect(Point{1, 1, 1}, []float32{1, 1, 1}), + mustRect(Point{-1, -1, -1}, []float32{0.5, 0.5, 0.5}), + mustRect(Point{3, 4, -5}, []float32{2, 0.9, 8}), + 1, + "clear winner", + 1, + }, + { + mustRect(Point{-1, -1.5, -1}, []float32{0.5, 2.5025, 0.5}), + mustRect(Point{0.5, 1, 0.5}, []float32{0.5, 0.815, 0.5}), + mustRect(Point{3, 4, -5}, []float32{2, 0.9, 8}), + 1, + "leaves tie", + 1, + }, + { + mustRect(Point{-1, -1.5, -1}, []float32{0.5, 2.5025, 0.5}), + mustRect(Point{0.5, 1, 0.5}, []float32{0.5, 0.815, 0.5}), + mustRect(Point{-1, -2, -3}, []float32{2, 4, 6}), + 2, + "leaf contains obj", + 1, + }, +} + +func TestChooseLeafNodeEmpty(t *testing.T) { + rt := NewTree(3, 5, 10) + obj := Point{0, 0, 0}.ToRect(0.5) + e := entry{obj, nil, obj} + if leaf := rt.chooseNode(rt.root, e, 1); leaf != rt.root { + t.Errorf("expected chooseLeaf of empty tree to return root") + } +} + +func TestChooseLeafNode(t *testing.T) { + for _, test := range chooseLeafNodeTests { + rt := Rtree{} + rt.root = &node{} + + leaf0 := &node{rt.root, []entry{}, 1, true} + entry0 := entry{test.bb0, leaf0, nil} + + leaf1 := &node{rt.root, []entry{}, 1, true} + entry1 := entry{test.bb1, leaf1, nil} + + leaf2 := &node{rt.root, []entry{}, 1, true} + entry2 := entry{test.bb2, leaf2, nil} + + rt.root.entries = []entry{entry0, entry1, entry2} + + obj := Point{0, 0, 0}.ToRect(0.5) + e := entry{obj, nil, obj} + + expected := rt.root.entries[test.exp].child + if leaf := rt.chooseNode(rt.root, e, 1); leaf != expected { + t.Errorf("%s: expected %d", test.desc, test.exp) + } + } +} + +func TestPickSeeds(t *testing.T) { + entry1 := entry{bb: mustRect(Point{1, 1}, []float32{1, 1})} + entry2 := entry{bb: mustRect(Point{1, -1}, []float32{2, 1})} + entry3 := entry{bb: mustRect(Point{-1, -1}, []float32{1, 2})} + n := node{entries: []entry{entry1, entry2, entry3}} + left, right := n.pickSeeds() + if !entryEq(n.entries[left], entry1) || !entryEq(n.entries[right], entry3) { + t.Errorf("expected entries %d, %d", 1, 3) + } +} + +func TestPickNext(t *testing.T) { + leftEntry := entry{bb: mustRect(Point{1, 1}, []float32{1, 1})} + left := &node{entries: []entry{leftEntry}} + + rightEntry := entry{bb: mustRect(Point{-1, -1}, []float32{1, 2})} + right := &node{entries: []entry{rightEntry}} + + entry1 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + entry2 := entry{bb: mustRect(Point{-2, -2}, []float32{1, 1})} + entry3 := entry{bb: mustRect(Point{1, 2}, []float32{1, 1})} + entries := []entry{entry1, entry2, entry3} + + chosen := pickNext(left, right, entries) + if !entryEq(entries[chosen], entry2) { + t.Errorf("expected entry %d", 3) + } +} + +func TestSplit(t *testing.T) { + entry1 := entry{bb: mustRect(Point{-3, -1}, []float32{2, 1})} + entry2 := entry{bb: mustRect(Point{1, 2}, []float32{1, 1})} + entry3 := entry{bb: mustRect(Point{-1, 0}, []float32{1, 1})} + entry4 := entry{bb: mustRect(Point{-3, -3}, []float32{1, 1})} + entry5 := entry{bb: mustRect(Point{1, -1}, []float32{2, 2})} + entries := []entry{entry1, entry2, entry3, entry4, entry5} + n := &node{entries: entries} + + l, r := n.split(0) // left=entry2, right=entry4 + expLeft := mustRect(Point{1, -1}, []float32{2, 4}) + expRight := mustRect(Point{-3, -3}, []float32{3, 4}) + + lbb := l.computeBoundingBox() + rbb := r.computeBoundingBox() + if lbb.p.dist(expLeft.p) >= EPS || lbb.q.dist(expLeft.q) >= EPS { + t.Errorf("expected left.bb = %s, got %s", expLeft, lbb) + } + if rbb.p.dist(expRight.p) >= EPS || rbb.q.dist(expRight.q) >= EPS { + t.Errorf("expected right.bb = %s, got %s", expRight, rbb) + } +} + +func TestSplitUnderflow(t *testing.T) { + entry1 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + entry2 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + entry3 := entry{bb: mustRect(Point{0, 2}, []float32{1, 1})} + entry4 := entry{bb: mustRect(Point{0, 3}, []float32{1, 1})} + entry5 := entry{bb: mustRect(Point{-50, -50}, []float32{1, 1})} + entries := []entry{entry1, entry2, entry3, entry4, entry5} + n := &node{entries: entries} + + l, r := n.split(2) + + if len(l.entries) != 3 || len(r.entries) != 2 { + t.Errorf("expected underflow assignment for right group") + } +} + +func TestAssignGroupLeastEnlargement(t *testing.T) { + r00 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + r01 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + r10 := entry{bb: mustRect(Point{1, 0}, []float32{1, 1})} + r11 := entry{bb: mustRect(Point{1, 1}, []float32{1, 1})} + r02 := entry{bb: mustRect(Point{0, 2}, []float32{1, 1})} + + group1 := &node{entries: []entry{r00, r01}} + group2 := &node{entries: []entry{r10, r11}} + + assignGroup(r02, group1, group2) + if len(group1.entries) != 3 || len(group2.entries) != 2 { + t.Errorf("expected r02 added to group 1") + } +} + +func TestAssignGroupSmallerArea(t *testing.T) { + r00 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + r01 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + r12 := entry{bb: mustRect(Point{1, 2}, []float32{1, 1})} + r02 := entry{bb: mustRect(Point{0, 2}, []float32{1, 1})} + + group1 := &node{entries: []entry{r00, r01}} + group2 := &node{entries: []entry{r12}} + + assignGroup(r02, group1, group2) + if len(group2.entries) != 2 || len(group1.entries) != 2 { + t.Errorf("expected r02 added to group 2") + } +} + +func TestAssignGroupFewerEntries(t *testing.T) { + r0001 := entry{bb: mustRect(Point{0, 0}, []float32{1, 2})} + r12 := entry{bb: mustRect(Point{1, 2}, []float32{1, 1})} + r22 := entry{bb: mustRect(Point{2, 2}, []float32{1, 1})} + r02 := entry{bb: mustRect(Point{0, 2}, []float32{1, 1})} + + group1 := &node{entries: []entry{r0001}} + group2 := &node{entries: []entry{r12, r22}} + + assignGroup(r02, group1, group2) + if len(group2.entries) != 2 || len(group1.entries) != 2 { + t.Errorf("expected r02 added to group 2") + } +} + +func TestAdjustTreeNoPreviousSplit(t *testing.T) { + rt := Rtree{root: &node{}} + + r00 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + r01 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + r10 := entry{bb: mustRect(Point{1, 0}, []float32{1, 1})} + entries := []entry{r00, r01, r10} + n := node{rt.root, entries, 1, false} + rt.root.entries = []entry{{bb: Point{0, 0}.ToRect(0), child: &n}} + + rt.adjustTree(&n, nil) + + e := rt.root.entries[0] + p, q := Point{0, 0}, Point{2, 2} + if p.dist(e.bb.p) >= EPS || q.dist(e.bb.q) >= EPS { + t.Errorf("Expected adjustTree to fit %v,%v,%v", r00.bb, r01.bb, r10.bb) + } +} + +func TestAdjustTreeNoSplit(t *testing.T) { + rt := NewTree(2, 3, 3) + + r00 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + r01 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + left := node{rt.root, []entry{r00, r01}, 1, false} + leftEntry := entry{bb: Point{0, 0}.ToRect(0), child: &left} + + r10 := entry{bb: mustRect(Point{1, 0}, []float32{1, 1})} + r11 := entry{bb: mustRect(Point{1, 1}, []float32{1, 1})} + right := node{rt.root, []entry{r10, r11}, 1, false} + + rt.root.entries = []entry{leftEntry} + retl, retr := rt.adjustTree(&left, &right) + + if retl != rt.root || retr != nil { + t.Errorf("Expected adjustTree didn't split the root") + } + + entries := rt.root.entries + if entries[0].child != &left || entries[1].child != &right { + t.Errorf("Expected adjustTree keeps left and adds n in parent") + } + + lbb, rbb := entries[0].bb, entries[1].bb + if lbb.p.dist(Point{0, 0}) >= EPS || lbb.q.dist(Point{1, 2}) >= EPS { + t.Errorf("Expected adjustTree to adjust left bb") + } + if rbb.p.dist(Point{1, 0}) >= EPS || rbb.q.dist(Point{2, 2}) >= EPS { + t.Errorf("Expected adjustTree to adjust right bb") + } +} + +func TestAdjustTreeSplitParent(t *testing.T) { + rt := NewTree(2, 1, 1) + + r00 := entry{bb: mustRect(Point{0, 0}, []float32{1, 1})} + r01 := entry{bb: mustRect(Point{0, 1}, []float32{1, 1})} + left := node{rt.root, []entry{r00, r01}, 1, false} + leftEntry := entry{bb: Point{0, 0}.ToRect(0), child: &left} + + r10 := entry{bb: mustRect(Point{1, 0}, []float32{1, 1})} + r11 := entry{bb: mustRect(Point{1, 1}, []float32{1, 1})} + right := node{rt.root, []entry{r10, r11}, 1, false} + + rt.root.entries = []entry{leftEntry} + retl, retr := rt.adjustTree(&left, &right) + + if len(retl.entries) != 1 || len(retr.entries) != 1 { + t.Errorf("Expected adjustTree distributed the entries") + } + + lbb, rbb := retl.entries[0].bb, retr.entries[0].bb + if lbb.p.dist(Point{0, 0}) >= EPS || lbb.q.dist(Point{1, 2}) >= EPS { + t.Errorf("Expected left split got left entry") + } + if rbb.p.dist(Point{1, 0}) >= EPS || rbb.q.dist(Point{2, 2}) >= EPS { + t.Errorf("Expected right split got right entry") + } +} + +func TestInsertRepeated(t *testing.T) { + var things []Spatial + for i := 0; i < 10; i++ { + things = append(things, mustRect(Point{0, 0}, []float32{2, 1})) + } + + for _, tc := range tests(2, 3, 5, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + rt.Insert(mustRect(Point{0, 0}, []float32{2, 1})) + }) + } +} + +func TestInsertNoSplit(t *testing.T) { + rt := NewTree(2, 3, 3) + thing := mustRect(Point{0, 0}, []float32{2, 1}) + rt.Insert(thing) + + if rt.Size() != 1 { + t.Errorf("Insert failed to increase tree size") + } + + if len(rt.root.entries) != 1 || !rectEq(rt.root.entries[0].obj.(Rect), thing) { + t.Errorf("Insert failed to insert thing into root entries") + } +} + +func TestInsertSplitRoot(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + } + for _, thing := range things { + rt.Insert(thing) + } + + if rt.Size() != 6 { + t.Errorf("Insert failed to insert") + } + + if len(rt.root.entries) != 2 { + t.Errorf("Insert failed to split") + } + + left, right := rt.root.entries[0].child, rt.root.entries[1].child + if len(left.entries) != 3 || len(right.entries) != 3 { + t.Errorf("Insert failed to split evenly") + } +} + +func TestInsertSplit(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{10, 10}, []float32{2, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + if rt.Size() != 7 { + t.Errorf("Insert failed to insert") + } + + if len(rt.root.entries) != 3 { + t.Errorf("Insert failed to split") + } + + a, b, c := rt.root.entries[0], rt.root.entries[1], rt.root.entries[2] + if len(a.child.entries) != 3 || + len(b.child.entries) != 3 || + len(c.child.entries) != 1 { + t.Errorf("Insert failed to split evenly") + } +} + +func TestInsertSplitSecondLevel(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + if rt.Size() != 10 { + t.Errorf("Insert failed to insert") + } + + // should split root + if len(rt.root.entries) != 2 { + t.Errorf("Insert failed to split the root") + } + + // split level + entries level + objs level + if rt.Depth() != 3 { + t.Errorf("Insert failed to adjust properly") + } + + var checkParents func(n *node) + checkParents = func(n *node) { + if n.leaf { + return + } + for _, e := range n.entries { + if e.child.parent != n { + t.Errorf("Insert failed to update parent pointers") + } + checkParents(e.child) + } + } + checkParents(rt.root) +} + +func TestBulkLoadingValidity(t *testing.T) { + var things []Spatial + for i := float32(0); i < float32(100); i++ { + things = append(things, mustRect(Point{i, i}, []float32{1, 1})) + } + + testCases := []struct { + count int + max int + }{ + { + count: 5, + max: 2, + }, + { + count: 33, + max: 5, + }, + { + count: 34, + max: 7, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("count=%d-max=%d", tc.count, tc.max), func(t *testing.T) { + rt := NewTree(2, 1, tc.max, things[:tc.count]...) + verify(t, rt) + }) + } +} + +func TestFindLeaf(t *testing.T) { + rt := NewTree(2, 3, 3) + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, thing := range things { + rt.Insert(thing) + } + verify(t, rt) + for _, thing := range things { + leaf := rt.findLeaf(rt.root, thing, defaultComparator) + if leaf == nil { + printNode(rt.root, 0) + t.Fatalf("Unable to find leaf containing an entry after insertion!") + } + var found *Rect + for _, other := range leaf.entries { + if other.obj == thing { + found = other.obj.(*Rect) + break + } + } + if found == nil { + printNode(rt.root, 0) + printNode(leaf, 0) + t.Errorf("Entry %v not found in leaf node %v!", thing, leaf) + } + } +} + +func TestFindLeafDoesNotExist(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + obj := mustRect(Point{99, 99}, []float32{99, 99}) + leaf := rt.findLeaf(rt.root, obj, defaultComparator) + if leaf != nil { + t.Errorf("findLeaf failed to return nil for non-existent object") + } +} + +func TestCondenseTreeEliminate(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + // delete entry 2 from parent entries + parent := rt.root.entries[0].child.entries[1].child + parent.entries = append(parent.entries[:2], parent.entries[3:]...) + rt.condenseTree(parent) + + retrieved := []Spatial{} + for obj := range items(rt.root) { + retrieved = append(retrieved, obj) + } + + if len(retrieved) != len(things)-1 { + t.Errorf("condenseTree failed to reinsert upstream elements") + } + + verify(t, rt) +} + +func TestChooseNodeNonLeaf(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + obj := mustRect(Point{0, 10}, []float32{1, 2}) + e := entry{obj, nil, obj} + n := rt.chooseNode(rt.root, e, 2) + if n.level != 2 { + t.Errorf("chooseNode failed to stop at desired level") + } +} + +func TestInsertNonLeaf(t *testing.T) { + rt := NewTree(2, 3, 3) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + for _, thing := range things { + rt.Insert(thing) + } + + obj := mustRect(Point{99, 99}, []float32{99, 99}) + e := entry{obj, nil, obj} + rt.insert(e, 2) + + expected := rt.root.entries[1].child + if !rectEq(expected.entries[1].obj.(Rect), obj) { + t.Errorf("insert failed to insert entry at correct level") + } +} + +func TestDeleteFlatten(t *testing.T) { + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + // make sure flattening didn't nuke the tree + rt.Delete(things[0]) + verify(t, rt) + }) + } +} + +func TestDelete(t *testing.T) { + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{0, 6}, []float32{1, 2}), + mustRect(Point{1, 6}, []float32{1, 2}), + mustRect(Point{0, 8}, []float32{1, 2}), + mustRect(Point{1, 8}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + verify(t, rt) + + things2 := []Spatial{} + for len(things) > 0 { + i := rand.Int() % len(things) + things2 = append(things2, things[i]) + things = append(things[:i], things[i+1:]...) + } + + for i, thing := range things2 { + ok := rt.Delete(thing) + if !ok { + t.Errorf("Thing %v was not found in tree during deletion", thing) + return + } + + if rt.Size() != len(things2)-i-1 { + t.Errorf("Delete failed to remove %v", thing) + return + } + verify(t, rt) + } + }) + } +} + +func TestDeleteWithDepthChange(t *testing.T) { + rt := NewTree(2, 3, 3) + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, thing := range things { + rt.Insert(thing) + } + + // delete last item and condense nodes + rt.Delete(things[3]) + + // rt.height should be 1 otherwise insert increases height to 3 + rt.Insert(things[3]) + + // and verify would fail + verify(t, rt) +} + +func TestDeleteWithComparator(t *testing.T) { + type IDRect struct { + ID string + Rect + } + + things := []Spatial{ + &IDRect{"1", mustRect(Point{0, 0}, []float32{2, 1})}, + &IDRect{"2", mustRect(Point{3, 1}, []float32{1, 2})}, + &IDRect{"3", mustRect(Point{1, 2}, []float32{2, 2})}, + &IDRect{"4", mustRect(Point{8, 6}, []float32{1, 1})}, + &IDRect{"5", mustRect(Point{10, 3}, []float32{1, 2})}, + &IDRect{"6", mustRect(Point{11, 7}, []float32{1, 1})}, + &IDRect{"7", mustRect(Point{0, 6}, []float32{1, 2})}, + &IDRect{"8", mustRect(Point{1, 6}, []float32{1, 2})}, + &IDRect{"9", mustRect(Point{0, 8}, []float32{1, 2})}, + &IDRect{"10", mustRect(Point{1, 8}, []float32{1, 2})}, + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + verify(t, rt) + + cmp := func(obj1, obj2 Spatial) bool { + idr1 := obj1.(*IDRect) + idr2 := obj2.(*IDRect) + return idr1.ID == idr2.ID + } + + things2 := []*IDRect{} + for len(things) > 0 { + i := rand.Int() % len(things) + // make a deep copy + copy := &IDRect{things[i].(*IDRect).ID, things[i].(*IDRect).Rect} + things2 = append(things2, copy) + + if !cmp(things[i], copy) { + log.Fatalf("expected copy to be equal to the original, original: %v, copy: %v", things[i], copy) + } + + things = append(things[:i], things[i+1:]...) + } + + for i, thing := range things2 { + ok := rt.DeleteWithComparator(thing, cmp) + if !ok { + t.Errorf("Thing %v was not found in tree during deletion", thing) + return + } + + if rt.Size() != len(things2)-i-1 { + t.Errorf("Delete failed to remove %v", thing) + return + } + verify(t, rt) + } + }) + } +} + +func TestDeleteThenInsert(t *testing.T) { + tol := float32(1e-3) + rects := []Rect{ + mustRect(Point{3, 1}, []float32{tol, tol}), + mustRect(Point{1, 2}, []float32{tol, tol}), + mustRect(Point{2, 6}, []float32{tol, tol}), + mustRect(Point{3, 6}, []float32{tol, tol}), + mustRect(Point{2, 8}, []float32{tol, tol}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + rt := NewTree(2, 2, 2, things...) + + if ok := rt.Delete(things[3]); !ok { + t.Fatalf("%#v", things[3]) + } + rt.Insert(things[3]) + + // Deleting and then inserting things[3] should not affect things[4]. + if ok := rt.Delete(things[4]); !ok { + t.Fatalf("%#v", things[4]) + } +} + +func TestSearchIntersect(t *testing.T) { + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{2, 6}, []float32{1, 2}), + mustRect(Point{3, 6}, []float32{1, 2}), + mustRect(Point{2, 8}, []float32{1, 2}), + mustRect(Point{3, 8}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + p := Point{2, 1.5} + bb := mustRect(p, []float32{10, 5.5}) + q := rt.SearchIntersect(bb) + + var expected []Spatial + for _, i := range []int{1, 2, 3, 4, 6, 7} { + expected = append(expected, things[i]) + } + + ensureDisorderedSubset(t, q, expected) + }) + } + +} + +func TestSearchIntersectWithLimit(t *testing.T) { + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{2, 6}, []float32{1, 2}), + mustRect(Point{3, 6}, []float32{1, 2}), + mustRect(Point{2, 8}, []float32{1, 2}), + mustRect(Point{3, 8}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + bb := mustRect(Point{2, 1.5}, []float32{10, 5.5}) + + // expected contains all the intersecting things + var expected []Spatial + for _, i := range []int{1, 2, 6, 7, 3, 4} { + expected = append(expected, things[i]) + } + + // Loop through all possible limits k of SearchIntersectWithLimit, + // and test that the results are as expected. + for k := -1; k <= len(things); k++ { + q := rt.SearchIntersectWithLimit(k, bb) + + if k == -1 { + ensureDisorderedSubset(t, q, expected) + if len(q) != len(expected) { + t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) + } + } else if k == 0 { + if len(q) != 0 { + t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) + } + } else if k <= len(expected) { + ensureDisorderedSubset(t, q, expected) + if len(q) != k { + t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) + } + } else { + ensureDisorderedSubset(t, q, expected) + if len(q) != len(expected) { + t.Fatalf("length of actual (%v) was different from expected (%v)", len(q), len(expected)) + } + } + } + }) + } +} + +func TestSearchIntersectWithTestFilter(t *testing.T) { + rects := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{2, 6}, []float32{1, 2}), + mustRect(Point{3, 6}, []float32{1, 2}), + mustRect(Point{2, 8}, []float32{1, 2}), + mustRect(Point{3, 8}, []float32{1, 2}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + bb := mustRect(Point{2, 1.5}, []float32{10, 5.5}) + + // intersecting indexes are 1, 2, 6, 7, 3, 4 + // rects which we do not filter out + var expected []Spatial + for _, i := range []int{1, 6, 4} { + expected = append(expected, things[i]) + } + + // this test filter will only pick the objects that are in expected + objects := rt.SearchIntersect(bb, func(results []Spatial, object Spatial) (bool, bool) { + for _, exp := range expected { + if exp == object { + return false, false + } + } + return true, false + }) + + ensureDisorderedSubset(t, objects, expected) + }) + } +} + +func TestSearchIntersectNoResults(t *testing.T) { + things := []Spatial{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{2, 6}, []float32{1, 2}), + mustRect(Point{3, 6}, []float32{1, 2}), + mustRect(Point{2, 8}, []float32{1, 2}), + mustRect(Point{3, 8}, []float32{1, 2}), + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + bb := mustRect(Point{99, 99}, []float32{10, 5.5}) + q := rt.SearchIntersect(bb) + if len(q) != 0 { + t.Errorf("SearchIntersect failed to return nil slice on failing query") + } + }) + } +} + +func TestSortEntries(t *testing.T) { + objs := []Rect{ + mustRect(Point{1, 1}, []float32{1, 1}), + mustRect(Point{2, 2}, []float32{1, 1}), + mustRect(Point{3, 3}, []float32{1, 1})} + entries := []entry{ + {objs[2], nil, &objs[2]}, + {objs[1], nil, &objs[1]}, + {objs[0], nil, &objs[0]}, + } + sorted, dists := sortEntries(Point{0, 0}, entries) + if !entryEq(sorted[0], entries[2]) || !entryEq(sorted[1], entries[1]) || !entryEq(sorted[2], entries[0]) { + t.Errorf("sortEntries failed") + } + if dists[0] != 2 || dists[1] != 8 || dists[2] != 18 { + t.Errorf("sortEntries failed to calculate proper distances") + } +} + +func TestNearestNeighbor(t *testing.T) { + rects := []Rect{ + mustRect(Point{1, 1}, []float32{1, 1}), + mustRect(Point{1, 3}, []float32{1, 1}), + mustRect(Point{3, 2}, []float32{1, 1}), + mustRect(Point{-7, -7}, []float32{1, 1}), + mustRect(Point{7, 7}, []float32{1, 1}), + mustRect(Point{10, 2}, []float32{1, 1}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + obj1 := rt.NearestNeighbor(Point{0.5, 0.5}) + obj2 := rt.NearestNeighbor(Point{1.5, 4.5}) + obj3 := rt.NearestNeighbor(Point{5, 2.5}) + obj4 := rt.NearestNeighbor(Point{3.5, 2.5}) + + if obj1 != things[0] || obj2 != things[1] || obj3 != things[2] || obj4 != things[2] { + t.Errorf("NearestNeighbor failed") + } + }) + } +} + +func TestComputeBoundingBox(t *testing.T) { + rect1, _ := NewRect(Point{0, 0}, []float32{1, 1}) + rect2, _ := NewRect(Point{0, 1}, []float32{1, 1}) + rect3, _ := NewRect(Point{1, 0}, []float32{1, 1}) + n := &node{} + n.entries = append(n.entries, entry{bb: rect1}) + n.entries = append(n.entries, entry{bb: rect2}) + n.entries = append(n.entries, entry{bb: rect3}) + + exp, _ := NewRect(Point{0, 0}, []float32{2, 2}) + bb := n.computeBoundingBox() + d1 := bb.p.dist(exp.p) + d2 := bb.q.dist(exp.q) + if d1 > EPS || d2 > EPS { + t.Errorf("boundingBoxN(%v, %v, %v) != %v, got %v", rect1, rect2, rect3, exp, bb) + } +} + +func TestGetAllBoundingBoxes(t *testing.T) { + rt1 := NewTree(2, 3, 3) + rt2 := NewTree(2, 2, 4) + rt3 := NewTree(2, 4, 8) + things := []Rect{ + mustRect(Point{0, 0}, []float32{2, 1}), + mustRect(Point{3, 1}, []float32{1, 2}), + mustRect(Point{1, 2}, []float32{2, 2}), + mustRect(Point{8, 6}, []float32{1, 1}), + mustRect(Point{10, 3}, []float32{1, 2}), + mustRect(Point{11, 7}, []float32{1, 1}), + mustRect(Point{10, 10}, []float32{2, 2}), + mustRect(Point{2, 3}, []float32{0.5, 1}), + mustRect(Point{3, 5}, []float32{1.5, 2}), + mustRect(Point{7, 14}, []float32{2.5, 2}), + mustRect(Point{15, 6}, []float32{1, 1}), + mustRect(Point{4, 3}, []float32{1, 2}), + mustRect(Point{1, 7}, []float32{1, 1}), + mustRect(Point{10, 5}, []float32{2, 2}), + } + for _, thing := range things { + rt1.Insert(thing) + } + for _, thing := range things { + rt2.Insert(thing) + } + for _, thing := range things { + rt3.Insert(thing) + } + + if rt1.Size() != 14 { + t.Errorf("Insert failed to insert") + } + if rt2.Size() != 14 { + t.Errorf("Insert failed to insert") + } + if rt3.Size() != 14 { + t.Errorf("Insert failed to insert") + } + + rtbb1 := rt1.GetAllBoundingBoxes() + rtbb2 := rt2.GetAllBoundingBoxes() + rtbb3 := rt3.GetAllBoundingBoxes() + + if len(rtbb1) != 13 { + t.Errorf("Failed bounding box traversal expected 13 got " + strconv.Itoa(len(rtbb1))) + } + if len(rtbb2) != 7 { + t.Errorf("Failed bounding box traversal expected 7 got " + strconv.Itoa(len(rtbb2))) + } + if len(rtbb3) != 2 { + t.Errorf("Failed bounding box traversal expected 2 got " + strconv.Itoa(len(rtbb3))) + } +} + +type byMinDist struct { + r []Spatial + p Point +} + +func (r byMinDist) Less(i, j int) bool { + return r.p.minDist(r.r[i].Bounds()) < r.p.minDist(r.r[j].Bounds()) +} + +func (r byMinDist) Len() int { + return len(r.r) +} + +func (r byMinDist) Swap(i, j int) { + r.r[i], r.r[j] = r.r[j], r.r[i] +} + +func TestNearestNeighborsAll(t *testing.T) { + rects := []Rect{ + mustRect(Point{1, 1}, []float32{1, 1}), + mustRect(Point{-7, -7}, []float32{1, 1}), + mustRect(Point{1, 3}, []float32{1, 1}), + mustRect(Point{7, 7}, []float32{1, 1}), + mustRect(Point{10, 2}, []float32{1, 1}), + mustRect(Point{3, 3}, []float32{1, 1}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + verify(t, rt) + + p := Point{0.5, 0.5} + sort.Sort(byMinDist{things, p}) + + objs := rt.NearestNeighbors(len(things), p) + for i := range things { + if objs[i] != things[i] { + t.Errorf("NearestNeighbors failed at index %d: %v != %v", i, objs[i], things[i]) + } + } + + objs = rt.NearestNeighbors(len(things)+2, p) + if len(objs) > len(things) { + t.Errorf("NearestNeighbors failed: too many elements") + } + if len(objs) < len(things) { + t.Errorf("NearestNeighbors failed: not enough elements") + } + + }) + } +} + +func TestNearestNeighborsFilters(t *testing.T) { + rects := []Rect{ + mustRect(Point{1, 1}, []float32{1, 1}), + mustRect(Point{-7, -7}, []float32{1, 1}), + mustRect(Point{1, 3}, []float32{1, 1}), + mustRect(Point{7, 7}, []float32{1, 1}), + mustRect(Point{10, 2}, []float32{1, 1}), + mustRect(Point{3, 3}, []float32{1, 1}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + expected := []Spatial{things[0], things[2], things[3]} + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + p := Point{0.5, 0.5} + sort.Sort(byMinDist{expected, p}) + + objs := rt.NearestNeighbors(len(things), p, func(r []Spatial, obj Spatial) (bool, bool) { + for _, ex := range expected { + if ex == obj { + return false, false + } + } + + return true, false + }) + + ensureOrderedSubset(t, objs, expected) + }) + } +} + +func TestNearestNeighborsHalf(t *testing.T) { + rects := []Rect{ + mustRect(Point{1, 1}, []float32{1, 1}), + mustRect(Point{-7, -7}, []float32{1, 1}), + mustRect(Point{1, 3}, []float32{1, 1}), + mustRect(Point{7, 7}, []float32{1, 1}), + mustRect(Point{10, 2}, []float32{1, 1}), + mustRect(Point{3, 3}, []float32{1, 1}), + } + things := []Spatial{} + for i := range rects { + things = append(things, &rects[i]) + } + + p := Point{0.5, 0.5} + sort.Sort(byMinDist{things, p}) + + for _, tc := range tests(2, 3, 3, things...) { + t.Run(tc.name, func(t *testing.T) { + rt := tc.build() + + objs := rt.NearestNeighbors(3, p) + for i := range objs { + if objs[i] != things[i] { + t.Errorf("NearestNeighbors failed at index %d: %v != %v", i, objs[i], things[i]) + } + } + + objs = rt.NearestNeighbors(len(things)+2, p) + if len(objs) > len(things) { + t.Errorf("NearestNeighbors failed: too many elements") + } + }) + } +} + +func ensureOrderedSubset(t *testing.T, actual []Spatial, expected []Spatial) { + for i := range actual { + if len(expected)-1 < i || actual[i] != expected[i] { + t.Fatalf("actual is not an ordered subset of expected") + } + } +} + +func ensureDisorderedSubset(t *testing.T, actual []Spatial, expected []Spatial) { + for _, obj := range actual { + if !contains(obj, expected) { + t.Fatalf("actual contained an object that was not expected: %+v", obj) + } + } +} + +func contains(obj Spatial, slice []Spatial) bool { + for _, s := range slice { + if s == obj { + return true + } + } + + return false +}