From 42e3390a5c0ba5977b76f3c2233189220fc404a4 Mon Sep 17 00:00:00 2001 From: Andrei Pechkurov <37772591+puzpuzpuz@users.noreply.github.com> Date: Wed, 1 Nov 2023 22:05:56 +0300 Subject: [PATCH] Fix lost updates on concurrent Map/MapOf resize (#111) --- map.go | 24 ++++++++++++++---------- map_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ mapof.go | 24 ++++++++++++++---------- mapof_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 20 deletions(-) diff --git a/map.go b/map.go index 05e0231..8d33d7a 100644 --- a/map.go +++ b/map.go @@ -316,17 +316,19 @@ func (m *Map) doCompute( bidx := uint64(len(table.buckets)-1) & hash rootb := &table.buckets[bidx] lockBucket(&rootb.topHashMutex) - if m.newerTableExists(table) { - // Someone resized the table. Go for another attempt. - unlockBucket(&rootb.topHashMutex) - goto compute_attempt - } + // The following two checks must go in reverse to what's + // in the resize method. if m.resizeInProgress() { // Resize is in progress. Wait, then go for another attempt. unlockBucket(&rootb.topHashMutex) m.waitForResize() goto compute_attempt } + if m.newerTableExists(table) { + // Someone resized the table. Go for another attempt. + unlockBucket(&rootb.topHashMutex) + goto compute_attempt + } b := rootb for { topHashes := atomic.LoadUint64(&b.topHashMutex) @@ -454,13 +456,12 @@ func (m *Map) waitForResize() { m.resizeMu.Unlock() } -func (m *Map) resize(table *mapTable, hint mapResizeHint) { - var shrinkThreshold int64 - tableLen := len(table.buckets) +func (m *Map) resize(knownTable *mapTable, hint mapResizeHint) { + knownTableLen := len(knownTable.buckets) // Fast path for shrink attempts. if hint == mapShrinkHint { - shrinkThreshold = int64((tableLen * entriesPerMapBucket) / mapShrinkFraction) - if tableLen == minMapTableLen || table.sumSize() > shrinkThreshold { + shrinkThreshold := int64((knownTableLen * entriesPerMapBucket) / mapShrinkFraction) + if knownTableLen == minMapTableLen || knownTable.sumSize() > shrinkThreshold { return } } @@ -471,12 +472,15 @@ func (m *Map) resize(table *mapTable, hint mapResizeHint) { return } var newTable *mapTable + table := (*mapTable)(atomic.LoadPointer(&m.table)) + tableLen := len(table.buckets) switch hint { case mapGrowHint: // Grow the table with factor of 2. atomic.AddInt64(&m.totalGrowths, 1) newTable = newMapTable(tableLen << 1) case mapShrinkHint: + shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction) if table.sumSize() <= shrinkThreshold { // Shrink the table with factor of 2. atomic.AddInt64(&m.totalShrinks, 1) diff --git a/map_test.go b/map_test.go index a16b39a..ce4228b 100644 --- a/map_test.go +++ b/map_test.go @@ -945,6 +945,53 @@ func TestMapParallelRange(t *testing.T) { <-cdone } +func parallelShrinker(t *testing.T, m *Map, numIters, numEntries int, stopFlag *int64, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + if p, loaded := m.LoadOrStore(strconv.Itoa(j), &point{int32(j), int32(j)}); loaded { + t.Errorf("value was present for %d: %v", j, p) + } + } + for j := 0; j < numEntries; j++ { + m.Delete(strconv.Itoa(j)) + } + } + atomic.StoreInt64(stopFlag, 1) + cdone <- true +} + +func parallelUpdater(t *testing.T, m *Map, idx int, stopFlag *int64, cdone chan bool) { + for atomic.LoadInt64(stopFlag) != 1 { + sleepUs := int(Fastrand() % 10) + if p, loaded := m.LoadOrStore(strconv.Itoa(idx), &point{int32(idx), int32(idx)}); loaded { + t.Errorf("value was present for %d: %v", idx, p) + } + time.Sleep(time.Duration(sleepUs) * time.Microsecond) + if _, ok := m.Load(strconv.Itoa(idx)); !ok { + t.Errorf("value was not found for %d", idx) + } + m.Delete(strconv.Itoa(idx)) + } + cdone <- true +} + +func TestMapDoesNotLoseEntriesOnResize(t *testing.T) { + const numIters = 10_000 + const numEntries = 128 + m := NewMap() + cdone := make(chan bool) + stopFlag := int64(0) + go parallelShrinker(t, m, numIters, numEntries, &stopFlag, cdone) + go parallelUpdater(t, m, numEntries, &stopFlag, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map contents. + if s := m.Size(); s != 0 { + t.Fatalf("map is not empty: %d", s) + } +} + func TestMapTopHashMutex(t *testing.T) { const ( numLockers = 4 diff --git a/mapof.go b/mapof.go index 2d40f09..596e6d6 100644 --- a/mapof.go +++ b/mapof.go @@ -266,17 +266,19 @@ func (m *MapOf[K, V]) doCompute( bidx := uint64(len(table.buckets)-1) & hash rootb := &table.buckets[bidx] rootb.mu.Lock() - if m.newerTableExists(table) { - // Someone resized the table. Go for another attempt. - rootb.mu.Unlock() - goto compute_attempt - } + // The following two checks must go in reverse to what's + // in the resize method. if m.resizeInProgress() { // Resize is in progress. Wait, then go for another attempt. rootb.mu.Unlock() m.waitForResize() goto compute_attempt } + if m.newerTableExists(table) { + // Someone resized the table. Go for another attempt. + rootb.mu.Unlock() + goto compute_attempt + } b := rootb for { for i := 0; i < entriesPerMapBucket; i++ { @@ -403,13 +405,12 @@ func (m *MapOf[K, V]) waitForResize() { m.resizeMu.Unlock() } -func (m *MapOf[K, V]) resize(table *mapOfTable[K, V], hint mapResizeHint) { - var shrinkThreshold int64 - tableLen := len(table.buckets) +func (m *MapOf[K, V]) resize(knownTable *mapOfTable[K, V], hint mapResizeHint) { + knownTableLen := len(knownTable.buckets) // Fast path for shrink attempts. if hint == mapShrinkHint { - shrinkThreshold = int64((tableLen * entriesPerMapBucket) / mapShrinkFraction) - if tableLen == minMapTableLen || table.sumSize() > shrinkThreshold { + shrinkThreshold := int64((knownTableLen * entriesPerMapBucket) / mapShrinkFraction) + if knownTableLen == minMapTableLen || knownTable.sumSize() > shrinkThreshold { return } } @@ -420,12 +421,15 @@ func (m *MapOf[K, V]) resize(table *mapOfTable[K, V], hint mapResizeHint) { return } var newTable *mapOfTable[K, V] + table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table)) + tableLen := len(table.buckets) switch hint { case mapGrowHint: // Grow the table with factor of 2. atomic.AddInt64(&m.totalGrowths, 1) newTable = newMapOfTable[K, V](tableLen << 1) case mapShrinkHint: + shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction) if table.sumSize() <= shrinkThreshold { // Shrink the table with factor of 2. atomic.AddInt64(&m.totalShrinks, 1) diff --git a/mapof_test.go b/mapof_test.go index b27816e..41a01de 100644 --- a/mapof_test.go +++ b/mapof_test.go @@ -989,6 +989,53 @@ func TestMapOfParallelRange(t *testing.T) { <-cdone } +func parallelTypedShrinker(t *testing.T, m *MapOf[uint64, *point], numIters, numEntries int, stopFlag *int64, cdone chan bool) { + for i := 0; i < numIters; i++ { + for j := 0; j < numEntries; j++ { + if p, loaded := m.LoadOrStore(uint64(j), &point{int32(j), int32(j)}); loaded { + t.Errorf("value was present for %d: %v", j, p) + } + } + for j := 0; j < numEntries; j++ { + m.Delete(uint64(j)) + } + } + atomic.StoreInt64(stopFlag, 1) + cdone <- true +} + +func parallelTypedUpdater(t *testing.T, m *MapOf[uint64, *point], idx int, stopFlag *int64, cdone chan bool) { + for atomic.LoadInt64(stopFlag) != 1 { + sleepUs := int(Fastrand() % 10) + if p, loaded := m.LoadOrStore(uint64(idx), &point{int32(idx), int32(idx)}); loaded { + t.Errorf("value was present for %d: %v", idx, p) + } + time.Sleep(time.Duration(sleepUs) * time.Microsecond) + if _, ok := m.Load(uint64(idx)); !ok { + t.Errorf("value was not found for %d", idx) + } + m.Delete(uint64(idx)) + } + cdone <- true +} + +func TestMapOfDoesNotLoseEntriesOnResize(t *testing.T) { + const numIters = 10_000 + const numEntries = 128 + m := NewMapOf[uint64, *point]() + cdone := make(chan bool) + stopFlag := int64(0) + go parallelTypedShrinker(t, m, numIters, numEntries, &stopFlag, cdone) + go parallelTypedUpdater(t, m, numEntries, &stopFlag, cdone) + // Wait for the goroutines to finish. + <-cdone + <-cdone + // Verify map contents. + if s := m.Size(); s != 0 { + t.Fatalf("map is not empty: %d", s) + } +} + func BenchmarkMapOf_NoWarmUp(b *testing.B) { for _, bc := range benchmarkCases { if bc.readPercentage == 100 {