From 5a83968bbd0e45a1244b7d1c39db4ecca6bfd9de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?okhowang=28=E7=8E=8B=E6=B2=9B=E6=96=87=29?= Date: Tue, 12 Nov 2019 20:35:19 +0800 Subject: [PATCH] fix dbscan bug of task dispatch add more comment and test --- csv_importer_test.go | 8 ++--- dbscan.go | 31 +++++++++++-------- dbscan_test.go | 72 ++++++++++++++++++++++++++++++++++++++++++++ kmeans_test.go | 2 +- 4 files changed, 96 insertions(+), 17 deletions(-) create mode 100644 dbscan_test.go diff --git a/csv_importer_test.go b/csv_importer_test.go index e48ade9..2dbbad9 100644 --- a/csv_importer_test.go +++ b/csv_importer_test.go @@ -29,9 +29,9 @@ func TestImportedLoadCorrectData(t *testing.T) { f = "data/test.csv" i = CsvImporter() s = [][]float64{ - []float64{0.1, 0.2, 0.3}, - []float64{0.4, 0.5, 0.6}, - []float64{0.7, 0.8, 0.9}, + {0.1, 0.2, 0.3}, + {0.4, 0.5, 0.6}, + {0.7, 0.8, 0.9}, } ) @@ -41,7 +41,7 @@ func TestImportedLoadCorrectData(t *testing.T) { } if !fsliceEqual(d, s) { - t.Error("Imported data mismatch: %v vs %v\n", d, s) + t.Errorf("Imported data mismatch: %v vs %v\n", d, s) } } diff --git a/dbscan.go b/dbscan.go index e36d656..26e5aa7 100644 --- a/dbscan.go +++ b/dbscan.go @@ -11,16 +11,25 @@ type dbscanClusterer struct { distance DistanceFunc // slices holding the cluster mapping and sizes. Access is synchronized to avoid read during computation. 
- mu sync.RWMutex - a, b []int + mu sync.RWMutex + // groups for dataset + a []int + b []int // variables used for concurrent computation of nearest neighbours - l, s, o, f int - j chan *rangeJob - m *sync.Mutex - w *sync.WaitGroup - r *[]int - p []float64 + // dataset len + l int + // number of workers + s int + // number of points per worker + f int + j chan *rangeJob + m *sync.Mutex + w *sync.WaitGroup + // neighbours of the current point + r *[]int + // current point + p []float64 // visited points v []bool @@ -78,7 +87,6 @@ func (c *dbscanClusterer) Learn(data [][]float64) error { c.l = len(data) c.s = c.numWorkers() - c.o = c.s - 1 c.f = c.l / c.s c.d = data @@ -198,15 +206,14 @@ func (c *dbscanClusterer) nearest(p int, l *int, r *[]int) { c.p = c.d[p] c.r = r - c.w.Add(c.s) - for i := 0; i < c.l; i += c.f { if c.l-i <= c.f { - b = c.l - 1 + b = c.l } else { b = i + c.f } + c.w.Add(1) c.j <- &rangeJob{ a: i, b: b, diff --git a/dbscan_test.go b/dbscan_test.go new file mode 100644 index 0000000..fe6636b --- /dev/null +++ b/dbscan_test.go @@ -0,0 +1,72 @@ +package clusters + +import ( + "reflect" + "testing" +) + +func TestDBSCANCluster(t *testing.T) { + tests := []struct { + MinPts int + Eps float64 + Points [][]float64 + Expected []int + }{ + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}}, + Expected: []int{1}, + }, + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}, {1.5}}, + Expected: []int{1, 1}, + }, + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}, {1}}, + Expected: []int{1, 1}, + }, + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}, {1}, {1}}, + Expected: []int{1, 1, 1}, + }, + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}, {1.5}, {2}}, + Expected: []int{1, 1, 1}, + }, + { + MinPts: 1, + Eps: 1, + Points: [][]float64{{1}, {1.5}, {3}}, + Expected: []int{1, 1, 2}, + }, + { + MinPts: 2, + Eps: 1, + Points: [][]float64{{1}, {3}}, + Expected: []int{-1, -1}, + }, + } + for _, test := range tests { + c, e := DBSCAN(test.MinPts, test.Eps, 0, EuclideanDistance) 
+ if e != nil { + t.Errorf("Error initializing kmeans clusterer: %s\n", e.Error()) + } + + if e = c.Learn(test.Points); e != nil { + t.Errorf("Error learning data: %s\n", e.Error()) + } + + if !reflect.DeepEqual(c.Guesses(), test.Expected) { + t.Errorf("guesses does not match: %d vs %d\n", c.Guesses(), test.Expected) + } + } +} diff --git a/kmeans_test.go b/kmeans_test.go index 75f1712..4c739ab 100644 --- a/kmeans_test.go +++ b/kmeans_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -func TestKmeansClusterNumerMatches(t *testing.T) { +func TestKmeansClusterNumberMatches(t *testing.T) { const ( C = 8 )