Skip to content

Commit

Permalink
Add (s *Set) From(int) iter.Seq[int]
Browse files Browse the repository at this point in the history
  • Loading branch information
takeyourhatoff committed Sep 5, 2024
1 parent fe1772d commit 2446064
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
language: go

go:
- "1.10.x"
- "1.23.x"
- "1.x"
- tip
46 changes: 32 additions & 14 deletions bitset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package bitset

import (
"encoding/binary"
"iter"
"math/bits"
"strconv"
"strings"
)

const maxUint = 1<<bits.UintSize - 1
const maxUint = ^uint(0)

// Set represents a set of positive integers. Memory usage is proportional to the largest integer in the Set.
type Set struct {
Expand Down Expand Up @@ -44,8 +45,8 @@ func (s *Set) AddRange(low, hi int) {
s.grow(hi)
w0, _ := idx(low)
w1, _ := idx(hi - 1)
leftMask := uint(maxUint) << (uint(low) % bits.UintSize)
rightMask := uint(maxUint) >> (uint(bits.UintSize-hi) % bits.UintSize)
leftMask := maxUint << (uint(low) % bits.UintSize)
rightMask := maxUint >> (uint(bits.UintSize-hi) % bits.UintSize)
if w1 == w0 {
s.s[w0] |= leftMask & rightMask
return
Expand Down Expand Up @@ -86,8 +87,8 @@ func (s *Set) RemoveRange(low, hi int) {
hi = len(s.s) * bits.UintSize
w1 = len(s.s) - 1
}
leftMask := uint(maxUint) << (uint(low) % bits.UintSize)
rightMask := uint(maxUint) >> (uint(bits.UintSize-hi) % bits.UintSize)
leftMask := maxUint << (uint(low) % bits.UintSize)
rightMask := maxUint >> (uint(bits.UintSize-hi) % bits.UintSize)
if w1 == w0 {
s.s[w0] &^= leftMask & rightMask
return
Expand Down Expand Up @@ -187,19 +188,43 @@ func (s *Set) Equal(ss *Set) bool {
return true
}

// From returns a sequence of integers in s starting at i.
func (s *Set) From(i int) iter.Seq[int] {
return func(yield func(int) bool) {
if i < 0 {
i = 0
}
si := i / bits.UintSize
for idx := si; idx < len(s.s); idx++ {
word := s.s[idx]
if idx == si {
word &= maxUint << (i % bits.UintSize)
}
for word != 0 {
j := bits.TrailingZeros(word)
if !yield(idx*bits.UintSize + j) {
return
}
word &^= 1 << j
}
}
}
}

// NextAfter returns the smallest integer in s greater than or equal to i or -1 if no such integer exists.
func (s *Set) NextAfter(i int) int {
if i < 0 {
// There can be no integers in s less than 0 by definition
i = 0
}
mask := uint(maxUint) << (uint(i) % bits.UintSize)
mask := maxUint << (uint(i) % bits.UintSize)
for j := i / bits.UintSize; j < len(s.s); j++ {
word := s.s[j] & mask
mask = maxUint
if word != 0 {
return j*bits.UintSize + bits.TrailingZeros(word)
}
word &^= 1 << j
}
return -1
}
Expand All @@ -221,7 +246,7 @@ func (s *Set) String() string {
var buf strings.Builder
buf.WriteByte('[')
first := true
for i := s.NextAfter(0); i >= 0; i = s.NextAfter(i + 1) {
for i := range s.From(0) {
if !first {
buf.WriteByte(' ')
}
Expand Down Expand Up @@ -285,10 +310,3 @@ func idx(i int) (w int, mask uint) {
mask = 1 << (uint(i) % bits.UintSize)
return
}

func min(i, j int) int {
if i < j {
return i
}
return j
}
54 changes: 50 additions & 4 deletions bitset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@ package bitset
import (
"bytes"
"fmt"
"iter"
"math/bits"
"math/rand"
"reflect"
"runtime"
"slices"
"sort"
"testing"
"testing/quick"
)

// NextAfter can be used to iterate over the elements of the set.
func ExampleSet_NextAfter() {
// From can be used to iterate over the elements of the set.
func ExampleSet_From() {
s := new(Set)
s.Add(2)
s.Add(42)
s.Add(13)
for i := s.NextAfter(0); i >= 0; i = s.NextAfter(i + 1) {
for i := range s.From(0) {
fmt.Println(i)
}
// Output:
Expand Down Expand Up @@ -211,7 +213,7 @@ func TestNextAfter(t *testing.T) {
}
var n int
var oldi int
for i := b.NextAfter(0); i >= 0; i = b.NextAfter(i + 1) {
for i := b.NextAfter(-1); i >= 0; i = b.NextAfter(i + 1) {
if l[n] != i {
t.Logf("b.NextAfter(%d) = %d, expected %d", oldi, i, l[n])
return false
Expand All @@ -226,6 +228,31 @@ func TestNextAfter(t *testing.T) {
}
}

func TestFrom(t *testing.T) {
f := func(l ascendingInts, fstart float64) bool {
b := new(Set)
for _, i := range l {
b.Add(i)
}
start := int(fstart*float64(len(l))) - 1
got := slices.Collect(b.From(start))
var want ascendingInts
for _, num := range l {
if num >= start {
want = append(want, num)
}
}
if !slices.Equal(got, want) {
t.Logf("b.From(%d) = %v, expected %v", start, got, want)
return false
}
return true
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}

func TestBytes(t *testing.T) {
f := func(data0 []byte) bool {
// Get rid of trailing zero bytes
Expand Down Expand Up @@ -429,6 +456,25 @@ func BenchmarkNextAfter(b *testing.B) {
}
}

func BenchmarkFrom(b *testing.B) {
buf := make([]byte, 10000)
rand.Read(buf)
s := new(Set)
s.FromBytes(buf)
next, stop := iter.Pull(s.From(0))
defer stop()
var x int
b.ResetTimer()
for i := 0; i < b.N; i++ {
var ok bool
x, ok = next()
if !ok {
x = 0
}
}
_ = x
}

func bitwiseF(f func(p, q bool) bool, l0, l1 []int) []int {
var x []int
lim := max(l0, l1)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
module github.com/takeyourhatoff/bitset

go 1.23.0

0 comments on commit 2446064

Please sign in to comment.