Add LoadOrTryCompute method to Map/MapOf
puzpuzpuz committed Jan 25, 2025
1 parent 0a87a63 commit 8f45cbe
Showing 6 changed files with 215 additions and 46 deletions.
28 changes: 28 additions & 0 deletions map.go
@@ -294,6 +294,34 @@ func (m *Map) LoadAndStore(key string, value interface{}) (actual interface{}, l
)
}

// LoadOrTryCompute returns the existing value for the key if present.
// Otherwise, it tries to compute the value using the provided function
// and, if successful, returns the computed value. The loaded result is
// true if the value was loaded, false if it was stored. If the compute
// attempt was cancelled, nil is returned.
//
// This call locks a hash table bucket while the compute function
// is executed. This means that modifications to other entries in
// the bucket are blocked until valueFn returns. Consider this
// when the function includes long-running operations.
func (m *Map) LoadOrTryCompute(
key string,
valueFn func() (newValue interface{}, cancel bool),
) (value interface{}, loaded bool) {
return m.doCompute(
key,
func(interface{}, bool) (interface{}, bool) {
nv, c := valueFn()
if !c {
return nv, false
}
return nil, true
},
true,
false,
)
}

// LoadOrCompute returns the existing value for the key if present.
// Otherwise, it computes the value using the provided function and
// returns the computed value. The loaded result is true if the value
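For orientation, here is a minimal usage sketch of the new Map method; it is not part of the commit, it assumes the v3 module path github.com/puzpuzpuz/xsync/v3, and fetchUser is a hypothetical helper standing in for any fallible lookup whose failure should leave the map untouched.

package main

import (
    "errors"
    "fmt"

    "github.com/puzpuzpuz/xsync/v3"
)

// fetchUser is a hypothetical stand-in for a fallible lookup.
func fetchUser(id string) (interface{}, error) {
    if id == "" {
        return nil, errors.New("empty id")
    }
    return "user:" + id, nil
}

func main() {
    cache := xsync.NewMap()
    v, loaded := cache.LoadOrTryCompute("42", func() (newValue interface{}, cancel bool) {
        u, err := fetchUser("42")
        if err != nil {
            return nil, true // cancel: nothing is stored, nil is returned
        }
        return u, false
    })
    fmt.Println(v, loaded) // user:42 false (computed and stored)
}

Unlike LoadOrCompute, a cancelled attempt leaves the key absent, so a later call can retry the computation; the m.Size() == 0 check in the test below relies on exactly this.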
56 changes: 56 additions & 0 deletions map_test.go
@@ -353,7 +353,63 @@ func TestMapLoadOrCompute_FunctionCalledOnce(t *testing.T) {
return v
})
}
m.Range(func(k string, v interface{}) bool {
if vi, ok := v.(int); !ok || strconv.Itoa(vi) != k {
t.Fatalf("%sth key is not equal to value %d", k, v)
}
return true
})
}

func TestMapLoadOrTryCompute(t *testing.T) {
const numEntries = 1000
m := NewMap()
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
return i, true
})
if loaded {
t.Fatalf("value not computed for %d", i)
}
if v != nil {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
if m.Size() != 0 {
t.Fatalf("zero map size expected: %d", m.Size())
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
return i, false
})
if loaded {
t.Fatalf("value not computed for %d", i)
}
if v != i {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue interface{}, cancel bool) {
return i, false
})
if !loaded {
t.Fatalf("value not loaded for %d", i)
}
if v != i {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
}

func TestMapLoadOrTryCompute_FunctionCalledOnce(t *testing.T) {
m := NewMap()
for i := 0; i < 100; {
m.LoadOrTryCompute(strconv.Itoa(i), func() (v interface{}, cancel bool) {
v, i = i, i+1
return v, false
})
}
m.Range(func(k string, v interface{}) bool {
if vi, ok := v.(int); !ok || strconv.Itoa(vi) != k {
t.Fatalf("%sth key is not equal to value %d", k, v)
28 changes: 28 additions & 0 deletions mapof.go
@@ -258,6 +258,34 @@ func (m *MapOf[K, V]) LoadOrCompute(key K, valueFn func() V) (actual V, loaded b
)
}

// LoadOrTryCompute returns the existing value for the key if present.
// Otherwise, it tries to compute the value using the provided function
// and, if successful, returns the computed value. The loaded result is
// true if the value was loaded, false if it was stored. If the compute
// attempt was cancelled, the zero value of type V is returned.
//
// This call locks a hash table bucket while the compute function
// is executed. This means that modifications to other entries in
// the bucket are blocked until valueFn returns. Consider this
// when the function includes long-running operations.
func (m *MapOf[K, V]) LoadOrTryCompute(
key K,
valueFn func() (newValue V, cancel bool),
) (value V, loaded bool) {
return m.doCompute(
key,
func(V, bool) (V, bool) {
nv, c := valueFn()
if !c {
return nv, false
}
return nv, true // nv is ignored
},
true,
false,
)
}

// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
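Likewise, a hedged sketch for the generic variant; the memoizedAtoi helper is invented for illustration and the v3 module path is assumed. Successfully parsed values are cached, while unparsable keys cancel the store and are never inserted.

package main

import (
    "fmt"
    "strconv"

    "github.com/puzpuzpuz/xsync/v3"
)

// memoizedAtoi parses s and caches the result in m; a parse failure
// cancels the compute, so invalid keys never occupy a map slot.
func memoizedAtoi(m *xsync.MapOf[string, int], s string) (int, bool) {
    return m.LoadOrTryCompute(s, func() (newValue int, cancel bool) {
        n, err := strconv.Atoi(s)
        if err != nil {
            return 0, true // cancelled: zero value returned, nothing stored
        }
        return n, false
    })
}

func main() {
    m := xsync.NewMapOf[string, int]()
    fmt.Println(memoizedAtoi(m, "42"))   // 42 false (computed and stored)
    fmt.Println(memoizedAtoi(m, "42"))   // 42 true  (loaded from the map)
    fmt.Println(memoizedAtoi(m, "oops")) // 0 false  (cancelled, not stored)
}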
57 changes: 57 additions & 0 deletions mapof_test.go
@@ -373,6 +373,63 @@ func TestMapOfLoadOrCompute_FunctionCalledOnce(t *testing.T) {
})
}

func TestMapOfLoadOrTryCompute(t *testing.T) {
const numEntries = 1000
m := NewMapOf[string, int]()
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
return i, true
})
if loaded {
t.Fatalf("value not computed for %d", i)
}
if v != 0 {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
if m.Size() != 0 {
t.Fatalf("zero map size expected: %d", m.Size())
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
return i, false
})
if loaded {
t.Fatalf("value not computed for %d", i)
}
if v != i {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
for i := 0; i < numEntries; i++ {
v, loaded := m.LoadOrTryCompute(strconv.Itoa(i), func() (newValue int, cancel bool) {
return i, false
})
if !loaded {
t.Fatalf("value not loaded for %d", i)
}
if v != i {
t.Fatalf("values do not match for %d: %v", i, v)
}
}
}

func TestMapOfLoadOrTryCompute_FunctionCalledOnce(t *testing.T) {
m := NewMapOf[int, int]()
for i := 0; i < 100; {
m.LoadOrTryCompute(i, func() (newValue int, cancel bool) {
newValue, i = i, i+1
return newValue, false
})
}
m.Range(func(k, v int) bool {
if k != v {
t.Fatalf("%dth key is not equal to value %d", k, v)
}
return true
})
}

func TestMapOfCompute(t *testing.T) {
m := NewMapOf[string, int]()
// Store a new value.
56 changes: 28 additions & 28 deletions spscqueue.go
@@ -15,17 +15,17 @@ import (
// Based on the data structure from the following article:
// https://rigtorp.se/ringbuffer/
type SPSCQueue struct {
-    cap uint64
-    p_idx uint64
+    cap uint64
+    pidx uint64
     //lint:ignore U1000 prevents false sharing
-    pad0 [cacheLineSize - 8]byte
-    p_cached_idx uint64
+    pad0 [cacheLineSize - 8]byte
+    pcachedIdx uint64
     //lint:ignore U1000 prevents false sharing
-    pad1 [cacheLineSize - 8]byte
-    c_idx uint64
+    pad1 [cacheLineSize - 8]byte
+    cidx uint64
     //lint:ignore U1000 prevents false sharing
-    pad2 [cacheLineSize - 8]byte
-    c_cached_idx uint64
+    pad2 [cacheLineSize - 8]byte
+    ccachedIdx uint64
//lint:ignore U1000 prevents false sharing
pad3 [cacheLineSize - 8]byte
items []interface{}
@@ -48,21 +48,21 @@ func NewSPSCQueue(capacity int) *SPSCQueue {
// full and the item was inserted.
func (q *SPSCQueue) TryEnqueue(item interface{}) bool {
// relaxed memory order would be enough here
-    idx := atomic.LoadUint64(&q.p_idx)
-    next_idx := idx + 1
-    if next_idx == q.cap {
-        next_idx = 0
+    idx := atomic.LoadUint64(&q.pidx)
+    nextIdx := idx + 1
+    if nextIdx == q.cap {
+        nextIdx = 0
     }
-    cached_idx := q.c_cached_idx
-    if next_idx == cached_idx {
-        cached_idx = atomic.LoadUint64(&q.c_idx)
-        q.c_cached_idx = cached_idx
-        if next_idx == cached_idx {
+    cachedIdx := q.ccachedIdx
+    if nextIdx == cachedIdx {
+        cachedIdx = atomic.LoadUint64(&q.cidx)
+        q.ccachedIdx = cachedIdx
+        if nextIdx == cachedIdx {
             return false
         }
     }
     q.items[idx] = item
-    atomic.StoreUint64(&q.p_idx, next_idx)
+    atomic.StoreUint64(&q.pidx, nextIdx)
return true
}

@@ -71,22 +71,22 @@ func (q *SPSCQueue) TryEnqueue(item interface{}) bool {
// indicates that the queue isn't empty and an item was retrieved.
func (q *SPSCQueue) TryDequeue() (item interface{}, ok bool) {
// relaxed memory order would be enough here
-    idx := atomic.LoadUint64(&q.c_idx)
-    cached_idx := q.p_cached_idx
-    if idx == cached_idx {
-        cached_idx = atomic.LoadUint64(&q.p_idx)
-        q.p_cached_idx = cached_idx
-        if idx == cached_idx {
+    idx := atomic.LoadUint64(&q.cidx)
+    cachedIdx := q.pcachedIdx
+    if idx == cachedIdx {
+        cachedIdx = atomic.LoadUint64(&q.pidx)
+        q.pcachedIdx = cachedIdx
+        if idx == cachedIdx {
             return
         }
     }
     item = q.items[idx]
     q.items[idx] = nil
     ok = true
-    next_idx := idx + 1
-    if next_idx == q.cap {
-        next_idx = 0
+    nextIdx := idx + 1
+    if nextIdx == q.cap {
+        nextIdx = 0
     }
-    atomic.StoreUint64(&q.c_idx, next_idx)
+    atomic.StoreUint64(&q.cidx, nextIdx)
return
}
36 changes: 18 additions & 18 deletions spscqueueof.go
@@ -18,17 +18,17 @@ import (
// Based on the data structure from the following article:
// https://rigtorp.se/ringbuffer/
type SPSCQueueOf[I any] struct {
-    cap uint64
-    p_idx uint64
+    cap uint64
+    pidx uint64
     //lint:ignore U1000 prevents false sharing
-    pad0 [cacheLineSize - 8]byte
-    p_cached_idx uint64
+    pad0 [cacheLineSize - 8]byte
+    pcachedIdx uint64
     //lint:ignore U1000 prevents false sharing
-    pad1 [cacheLineSize - 8]byte
-    c_idx uint64
+    pad1 [cacheLineSize - 8]byte
+    cidx uint64
     //lint:ignore U1000 prevents false sharing
-    pad2 [cacheLineSize - 8]byte
-    c_cached_idx uint64
+    pad2 [cacheLineSize - 8]byte
+    ccachedIdx uint64
//lint:ignore U1000 prevents false sharing
pad3 [cacheLineSize - 8]byte
items []I
@@ -51,21 +51,21 @@ func NewSPSCQueueOf[I any](capacity int) *SPSCQueueOf[I] {
// full and the item was inserted.
func (q *SPSCQueueOf[I]) TryEnqueue(item I) bool {
// relaxed memory order would be enough here
-    idx := atomic.LoadUint64(&q.p_idx)
+    idx := atomic.LoadUint64(&q.pidx)
     next_idx := idx + 1
     if next_idx == q.cap {
         next_idx = 0
     }
-    cached_idx := q.c_cached_idx
+    cached_idx := q.ccachedIdx
     if next_idx == cached_idx {
-        cached_idx = atomic.LoadUint64(&q.c_idx)
-        q.c_cached_idx = cached_idx
+        cached_idx = atomic.LoadUint64(&q.cidx)
+        q.ccachedIdx = cached_idx
         if next_idx == cached_idx {
             return false
         }
     }
     q.items[idx] = item
-    atomic.StoreUint64(&q.p_idx, next_idx)
+    atomic.StoreUint64(&q.pidx, next_idx)
return true
}

@@ -74,11 +74,11 @@ func (q *SPSCQueueOf[I]) TryEnqueue(item I) bool {
// indicates that the queue isn't empty and an item was retrieved.
func (q *SPSCQueueOf[I]) TryDequeue() (item I, ok bool) {
// relaxed memory order would be enough here
-    idx := atomic.LoadUint64(&q.c_idx)
-    cached_idx := q.p_cached_idx
+    idx := atomic.LoadUint64(&q.cidx)
+    cached_idx := q.pcachedIdx
     if idx == cached_idx {
-        cached_idx = atomic.LoadUint64(&q.p_idx)
-        q.p_cached_idx = cached_idx
+        cached_idx = atomic.LoadUint64(&q.pidx)
+        q.pcachedIdx = cached_idx
if idx == cached_idx {
return
}
@@ -91,6 +91,6 @@ func (q *SPSCQueueOf[I]) TryDequeue() (item I, ok bool) {
if next_idx == q.cap {
next_idx = 0
}
-    atomic.StoreUint64(&q.c_idx, next_idx)
+    atomic.StoreUint64(&q.cidx, next_idx)
return
}
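As a side note to the field renames above, here is a minimal sketch of how the single-producer, single-consumer queue is driven from one goroutine; it is not part of the commit, the v3 module path is assumed, and the exact point at which the queue reports full depends on the capacity passed to the constructor.

package main

import (
    "fmt"

    "github.com/puzpuzpuz/xsync/v3"
)

func main() {
    // Bounded SPSC ring buffer; safe for one producer and one consumer goroutine.
    q := xsync.NewSPSCQueueOf[int](8)

    // Producer side: TryEnqueue never blocks; false means the queue is full.
    for i := 0; i < 10; i++ {
        if !q.TryEnqueue(i) {
            fmt.Println("queue full at", i)
            break
        }
    }

    // Consumer side: TryDequeue never blocks; ok is false once the queue is empty.
    for {
        v, ok := q.TryDequeue()
        if !ok {
            break
        }
        fmt.Println(v)
    }
}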
