Fix lost updates on concurrent Map/MapOf resize (#111)
puzpuzpuz authored Nov 1, 2023
1 parent abcfab6 · commit 42e3390
Showing 4 changed files with 122 additions and 20 deletions.
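
In short, and as a hedged reading of the diffs below rather than the commit's own words: the lost updates came from resize() copying buckets from the table pointer its caller had observed. If another goroutine completed a resize between the caller reading m.table and this resize winning the in-progress flag, the copy started from an already-replaced table, and entries present only in the newer table were silently dropped. The fix re-reads m.table atomically after the flag is won, and reorders the writer-side checks to match. A minimal, self-contained sketch of the stale-snapshot pitfall and the corrected pattern; every name here is an illustrative stand-in, not xsync's real internals:

    package main

    import (
        "fmt"
        "sync/atomic"
        "unsafe"
    )

    // Toy model of the bug; not xsync's actual types.
    type table struct{ entries map[string]int }

    type toyMap struct {
        table    unsafe.Pointer // *table, swapped atomically on resize
        resizing int64          // 1 while a resize is in flight
    }

    // resizeStale mimics the OLD behavior: it copies from the table
    // pointer the caller observed. If another goroutine swapped m.table
    // between the caller's read and the CAS below, this copies a stale
    // table, and entries written only to the newer table are lost.
    func (m *toyMap) resizeStale(snapshot *table) {
        if !atomic.CompareAndSwapInt64(&m.resizing, 0, 1) {
            return
        }
        nt := &table{entries: make(map[string]int)}
        for k, v := range snapshot.entries { // BUG: snapshot may be stale
            nt.entries[k] = v
        }
        atomic.StorePointer(&m.table, unsafe.Pointer(nt))
        atomic.StoreInt64(&m.resizing, 0)
    }

    // resizeFixed mimics the NEW behavior: once the in-progress flag is
    // won, it re-reads the current table, so the copy can never start
    // from a table that has already been replaced.
    func (m *toyMap) resizeFixed() {
        if !atomic.CompareAndSwapInt64(&m.resizing, 0, 1) {
            return
        }
        cur := (*table)(atomic.LoadPointer(&m.table)) // always the latest
        nt := &table{entries: make(map[string]int)}
        for k, v := range cur.entries {
            nt.entries[k] = v
        }
        atomic.StorePointer(&m.table, unsafe.Pointer(nt))
        atomic.StoreInt64(&m.resizing, 0)
    }

    func main() {
        m := &toyMap{}
        atomic.StorePointer(&m.table, unsafe.Pointer(&table{entries: map[string]int{"a": 1}}))
        m.resizeFixed()
        fmt.Println((*table)(atomic.LoadPointer(&m.table)).entries) // map[a:1]
    }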
map.go (24 changes: 14 additions & 10 deletions)
@@ -316,17 +316,19 @@ func (m *Map) doCompute(
     bidx := uint64(len(table.buckets)-1) & hash
     rootb := &table.buckets[bidx]
     lockBucket(&rootb.topHashMutex)
-    if m.newerTableExists(table) {
-        // Someone resized the table. Go for another attempt.
-        unlockBucket(&rootb.topHashMutex)
-        goto compute_attempt
-    }
+    // The following two checks must go in reverse to what's
+    // in the resize method.
     if m.resizeInProgress() {
         // Resize is in progress. Wait, then go for another attempt.
         unlockBucket(&rootb.topHashMutex)
         m.waitForResize()
         goto compute_attempt
     }
+    if m.newerTableExists(table) {
+        // Someone resized the table. Go for another attempt.
+        unlockBucket(&rootb.topHashMutex)
+        goto compute_attempt
+    }
     b := rootb
     for {
         topHashes := atomic.LoadUint64(&b.topHashMutex)
@@ -454,13 +456,12 @@ func (m *Map) waitForResize() {
     m.resizeMu.Unlock()
 }

-func (m *Map) resize(table *mapTable, hint mapResizeHint) {
-    var shrinkThreshold int64
-    tableLen := len(table.buckets)
+func (m *Map) resize(knownTable *mapTable, hint mapResizeHint) {
+    knownTableLen := len(knownTable.buckets)
     // Fast path for shrink attempts.
     if hint == mapShrinkHint {
-        shrinkThreshold = int64((tableLen * entriesPerMapBucket) / mapShrinkFraction)
-        if tableLen == minMapTableLen || table.sumSize() > shrinkThreshold {
+        shrinkThreshold := int64((knownTableLen * entriesPerMapBucket) / mapShrinkFraction)
+        if knownTableLen == minMapTableLen || knownTable.sumSize() > shrinkThreshold {
             return
         }
     }
@@ -471,12 +472,15 @@ func (m *Map) resize(table *mapTable, hint mapResizeHint) {
         return
     }
     var newTable *mapTable
+    table := (*mapTable)(atomic.LoadPointer(&m.table))
+    tableLen := len(table.buckets)
     switch hint {
     case mapGrowHint:
         // Grow the table with factor of 2.
         atomic.AddInt64(&m.totalGrowths, 1)
         newTable = newMapTable(tableLen << 1)
     case mapShrinkHint:
+        shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction)
         if table.sumSize() <= shrinkThreshold {
             // Shrink the table with factor of 2.
             atomic.AddInt64(&m.totalShrinks, 1)
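
Why doCompute now checks resizeInProgress() before newerTableExists(table), the "reverse" the new comment refers to: the code around this hunk suggests the resize path publishes completion by installing the new table first and clearing the in-progress flag afterwards, so a writer must validate its snapshot in the opposite order. If the flag is clear and the table is still the one whose bucket the writer locked, any resize that had already installed a table would have been caught by the second read, and any resize starting later re-reads the current table (per the resize hunk above). A small stand-alone sketch of this reverse-order check, with illustrative names only:

    package main

    import (
        "sync/atomic"
        "unsafe"
    )

    type table struct{ size int }

    var (
        current  unsafe.Pointer // *table, swapped by the resizer
        resizing int32          // 1 while a resize is in flight
    )

    // The resizer publishes its result in this order:
    // set flag -> copy -> install new table -> clear flag.
    func swapTable(nt *table) {
        atomic.StoreInt32(&resizing, 1)
        // ... copy entries into nt under per-bucket locks ...
        atomic.StorePointer(&current, unsafe.Pointer(nt)) // publish step 1
        atomic.StoreInt32(&resizing, 0)                   // publish step 2
    }

    // A writer validates its table snapshot in the REVERSE order:
    // flag first, table identity second.
    func snapshotStillCurrent(seen *table) bool {
        if atomic.LoadInt32(&resizing) == 1 { // reverse of publish step 2
            return false // wait for the resize to finish, then retry
        }
        // Reverse of publish step 1: a resize that finished before the
        // flag check must already have installed its table, so this
        // read catches it.
        return (*table)(atomic.LoadPointer(&current)) == seen
    }

    func main() {
        t := &table{size: 8}
        atomic.StorePointer(&current, unsafe.Pointer(t))
        if snapshotStillCurrent(t) {
            // safe to mutate the bucket locked in t
        }
        swapTable(&table{size: 16})
    }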
map_test.go (47 changes: 47 additions & 0 deletions)
@@ -945,6 +945,53 @@ func TestMapParallelRange(t *testing.T) {
     <-cdone
 }

+func parallelShrinker(t *testing.T, m *Map, numIters, numEntries int, stopFlag *int64, cdone chan bool) {
+    for i := 0; i < numIters; i++ {
+        for j := 0; j < numEntries; j++ {
+            if p, loaded := m.LoadOrStore(strconv.Itoa(j), &point{int32(j), int32(j)}); loaded {
+                t.Errorf("value was present for %d: %v", j, p)
+            }
+        }
+        for j := 0; j < numEntries; j++ {
+            m.Delete(strconv.Itoa(j))
+        }
+    }
+    atomic.StoreInt64(stopFlag, 1)
+    cdone <- true
+}
+
+func parallelUpdater(t *testing.T, m *Map, idx int, stopFlag *int64, cdone chan bool) {
+    for atomic.LoadInt64(stopFlag) != 1 {
+        sleepUs := int(Fastrand() % 10)
+        if p, loaded := m.LoadOrStore(strconv.Itoa(idx), &point{int32(idx), int32(idx)}); loaded {
+            t.Errorf("value was present for %d: %v", idx, p)
+        }
+        time.Sleep(time.Duration(sleepUs) * time.Microsecond)
+        if _, ok := m.Load(strconv.Itoa(idx)); !ok {
+            t.Errorf("value was not found for %d", idx)
+        }
+        m.Delete(strconv.Itoa(idx))
+    }
+    cdone <- true
+}
+
+func TestMapDoesNotLoseEntriesOnResize(t *testing.T) {
+    const numIters = 10_000
+    const numEntries = 128
+    m := NewMap()
+    cdone := make(chan bool)
+    stopFlag := int64(0)
+    go parallelShrinker(t, m, numIters, numEntries, &stopFlag, cdone)
+    go parallelUpdater(t, m, numEntries, &stopFlag, cdone)
+    // Wait for the goroutines to finish.
+    <-cdone
+    <-cdone
+    // Verify map contents.
+    if s := m.Size(); s != 0 {
+        t.Fatalf("map is not empty: %d", s)
+    }
+}
+
 func TestMapTopHashMutex(t *testing.T) {
     const (
         numLockers = 4
mapof.go (24 changes: 14 additions & 10 deletions)
@@ -266,17 +266,19 @@ func (m *MapOf[K, V]) doCompute(
     bidx := uint64(len(table.buckets)-1) & hash
     rootb := &table.buckets[bidx]
     rootb.mu.Lock()
-    if m.newerTableExists(table) {
-        // Someone resized the table. Go for another attempt.
-        rootb.mu.Unlock()
-        goto compute_attempt
-    }
+    // The following two checks must go in reverse to what's
+    // in the resize method.
     if m.resizeInProgress() {
         // Resize is in progress. Wait, then go for another attempt.
         rootb.mu.Unlock()
         m.waitForResize()
         goto compute_attempt
     }
+    if m.newerTableExists(table) {
+        // Someone resized the table. Go for another attempt.
+        rootb.mu.Unlock()
+        goto compute_attempt
+    }
     b := rootb
     for {
         for i := 0; i < entriesPerMapBucket; i++ {
@@ -403,13 +405,12 @@ func (m *MapOf[K, V]) waitForResize() {
     m.resizeMu.Unlock()
 }

-func (m *MapOf[K, V]) resize(table *mapOfTable[K, V], hint mapResizeHint) {
-    var shrinkThreshold int64
-    tableLen := len(table.buckets)
+func (m *MapOf[K, V]) resize(knownTable *mapOfTable[K, V], hint mapResizeHint) {
+    knownTableLen := len(knownTable.buckets)
     // Fast path for shrink attempts.
     if hint == mapShrinkHint {
-        shrinkThreshold = int64((tableLen * entriesPerMapBucket) / mapShrinkFraction)
-        if tableLen == minMapTableLen || table.sumSize() > shrinkThreshold {
+        shrinkThreshold := int64((knownTableLen * entriesPerMapBucket) / mapShrinkFraction)
+        if knownTableLen == minMapTableLen || knownTable.sumSize() > shrinkThreshold {
             return
         }
     }
@@ -420,12 +421,15 @@ func (m *MapOf[K, V]) resize(table *mapOfTable[K, V], hint mapResizeHint) {
         return
     }
     var newTable *mapOfTable[K, V]
+    table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
+    tableLen := len(table.buckets)
     switch hint {
     case mapGrowHint:
         // Grow the table with factor of 2.
         atomic.AddInt64(&m.totalGrowths, 1)
         newTable = newMapOfTable[K, V](tableLen << 1)
     case mapShrinkHint:
+        shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction)
         if table.sumSize() <= shrinkThreshold {
             // Shrink the table with factor of 2.
             atomic.AddInt64(&m.totalShrinks, 1)
mapof_test.go (47 changes: 47 additions & 0 deletions)
@@ -989,6 +989,53 @@ func TestMapOfParallelRange(t *testing.T) {
     <-cdone
 }

+func parallelTypedShrinker(t *testing.T, m *MapOf[uint64, *point], numIters, numEntries int, stopFlag *int64, cdone chan bool) {
+    for i := 0; i < numIters; i++ {
+        for j := 0; j < numEntries; j++ {
+            if p, loaded := m.LoadOrStore(uint64(j), &point{int32(j), int32(j)}); loaded {
+                t.Errorf("value was present for %d: %v", j, p)
+            }
+        }
+        for j := 0; j < numEntries; j++ {
+            m.Delete(uint64(j))
+        }
+    }
+    atomic.StoreInt64(stopFlag, 1)
+    cdone <- true
+}
+
+func parallelTypedUpdater(t *testing.T, m *MapOf[uint64, *point], idx int, stopFlag *int64, cdone chan bool) {
+    for atomic.LoadInt64(stopFlag) != 1 {
+        sleepUs := int(Fastrand() % 10)
+        if p, loaded := m.LoadOrStore(uint64(idx), &point{int32(idx), int32(idx)}); loaded {
+            t.Errorf("value was present for %d: %v", idx, p)
+        }
+        time.Sleep(time.Duration(sleepUs) * time.Microsecond)
+        if _, ok := m.Load(uint64(idx)); !ok {
+            t.Errorf("value was not found for %d", idx)
+        }
+        m.Delete(uint64(idx))
+    }
+    cdone <- true
+}
+
+func TestMapOfDoesNotLoseEntriesOnResize(t *testing.T) {
+    const numIters = 10_000
+    const numEntries = 128
+    m := NewMapOf[uint64, *point]()
+    cdone := make(chan bool)
+    stopFlag := int64(0)
+    go parallelTypedShrinker(t, m, numIters, numEntries, &stopFlag, cdone)
+    go parallelTypedUpdater(t, m, numEntries, &stopFlag, cdone)
+    // Wait for the goroutines to finish.
+    <-cdone
+    <-cdone
+    // Verify map contents.
+    if s := m.Size(); s != 0 {
+        t.Fatalf("map is not empty: %d", s)
+    }
+}
+
 func BenchmarkMapOf_NoWarmUp(b *testing.B) {
     for _, bc := range benchmarkCases {
         if bc.readPercentage == 100 {
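
A note on the two new stress tests (TestMapDoesNotLoseEntriesOnResize and its MapOf twin): the shrinker goroutine repeatedly fills and drains the map to force grows and shrinks, while the updater keeps churning one extra key, so some of its updates are likely to land mid-resize. Because the bug is a logical lost update on atomically swapped tables rather than a data race, it surfaces as a non-empty map in the final Size() check, e.g. when running go test -run 'TestMap(Of)?DoesNotLoseEntriesOnResize' against the pre-fix code.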
