Skip to content

Commit

Permalink
Fixed Tee, added another test
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonbot committed Nov 28, 2024
1 parent d15b783 commit c3bd53a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
32 changes: 32 additions & 0 deletions cookbook_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package chains

import (
"fmt"
"maps"
"slices"
"strings"
Expand Down Expand Up @@ -137,3 +138,34 @@ func TestAllStreetFighterMatches(t *testing.T) {
t.Fatalf("%v != %v", allFights, allExpectedFights)
}
}

func TestTeeAndMap(t *testing.T) {
numbersToCompute := []int{1, 2, 3, 4, 10, 20, 50, 100}
expectedValues := []string{
"2 + 10 = 12",
"4 + 20 = 24",
"6 + 30 = 36",
"8 + 40 = 48",
"20 + 100 = 120",
"40 + 200 = 240",
"100 + 500 = 600",
"200 + 1000 = 1200",
}
iter1, iter2 := Tee(Each(numbersToCompute))

doubler := ChainFromIterator(iter1).Map(func(i int) int { return i * 2 })
tenner := ChainFromIterator(iter2).Map(func(i int) int { return i * 10 })
calculatedValues := ChainJunction2[int, int, string](Chain2FromIterator(
Zip(
doubler.Each(),
tenner.Each(),
),
)).Map(
func(a int, b int) string {
return fmt.Sprintf("%v + %v = %v", a, b, a+b)
},
).Slice()
if !slices.Equal(calculatedValues, expectedValues) {
t.Fatalf("%v != %v", calculatedValues, expectedValues)
}
}
11 changes: 7 additions & 4 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,10 @@ func Tee[T any](in iter.Seq[T]) (iter.Seq[T], iter.Seq[T]) {
iter2Queue := []T{}

next, done := iter.Pull(in)
defer done()

iter1 := func(yield func(T) bool) {
defer done()

for {
if len(iter1Queue) == 0 {
if exhausted {
Expand All @@ -246,7 +247,7 @@ func Tee[T any](in iter.Seq[T]) (iter.Seq[T], iter.Seq[T]) {

iter1Queue = append(iter1Queue, nextval)
if !done2 {
iter2Queue = append(iter1Queue, nextval)
iter2Queue = append(iter2Queue, nextval)
}
}

Expand All @@ -261,6 +262,8 @@ func Tee[T any](in iter.Seq[T]) (iter.Seq[T], iter.Seq[T]) {
}

iter2 := func(yield func(T) bool) {
defer done()

for {
if len(iter2Queue) == 0 {
if exhausted {
Expand All @@ -276,11 +279,11 @@ func Tee[T any](in iter.Seq[T]) (iter.Seq[T], iter.Seq[T]) {
if !done1 {
iter1Queue = append(iter1Queue, nextval)
}
iter2Queue = append(iter1Queue, nextval)
iter2Queue = append(iter2Queue, nextval)
}

nextval := iter2Queue[0]
iter1Queue = iter2Queue[1:]
iter2Queue = iter2Queue[1:]

if !yield(nextval) {
done2 = true
Expand Down

0 comments on commit c3bd53a

Please sign in to comment.