diff --git a/packages/squiggle-lang/__tests__/E/splitContinuousAndDiscrete_test.res b/packages/squiggle-lang/__tests__/E/splitContinuousAndDiscrete_test.res index 8977d96148..ca3e724b1b 100644 --- a/packages/squiggle-lang/__tests__/E/splitContinuousAndDiscrete_test.res +++ b/packages/squiggle-lang/__tests__/E/splitContinuousAndDiscrete_test.res @@ -1,9 +1,12 @@ open Jest open TestHelpers +open FastCheck +open Arbitrary +open Property.Sync let prepareInputs = (ar, minWeight) => - E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight(ar, ~minDiscreteWeight=minWeight) |> ( - ((c, disc)) => (c, disc |> E.FloatFloatMap.toArray) + E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight(ar, ~minDiscreteWeight=minWeight)->( + ((c, disc)) => (c, disc->E.FloatFloatMap.toArray) ) describe("Continuous and discrete splits", () => { @@ -33,22 +36,52 @@ describe("Continuous and discrete splits", () => { let makeDuplicatedArray = count => { let arr = Belt.Array.range(1, count)->E.A.fmap(float_of_int) - let sorted = arr |> Belt.SortArray.stableSortBy(_, compare) - E.A.concatMany([sorted, sorted, sorted, sorted]) |> Belt.SortArray.stableSortBy(_, compare) + let sorted = arr->Belt.SortArray.stableSortBy(compare) + E.A.concatMany([sorted, sorted, sorted, sorted])->Belt.SortArray.stableSortBy(compare) } let (_, discrete1) = E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight( makeDuplicatedArray(10), ~minDiscreteWeight=2, ) - let toArr1 = discrete1 |> E.FloatFloatMap.toArray - makeTest("splitMedium at count=10", toArr1 |> E.A.length, 10) + let toArr1 = discrete1->E.FloatFloatMap.toArray + makeTest("splitMedium at count=10", toArr1->E.A.length, 10) let (_c, discrete2) = E.A.Floats.Sorted.splitContinuousAndDiscreteForMinWeight( makeDuplicatedArray(500), ~minDiscreteWeight=2, ) - let toArr2 = discrete2 |> E.FloatFloatMap.toArray - makeTest("splitMedium at count=500", toArr2 |> E.A.length, 500) - // makeTest("foo", [] |> E.A.length, 500) + let toArr2 = discrete2->E.FloatFloatMap.toArray + makeTest("splitMedium at count=500", toArr2->E.A.length, 500) + // makeTest("foo", [] -> E.A.length, 500) + + // Function for fast-check property testing + let testSegments = (counts, weight) => { + // Make random-length segments, join them, and split continuous/discrete + let random = _ => 0.01 +. Js.Math.random() // random() can produce 0 + let values = counts->E.A.length->E.A.makeBy(random)->E.A.Floats.cumSum + let segments = Belt.Array.zipBy(counts, values, Belt.Array.make) + let result = prepareInputs(segments->E.A.concatMany, weight) + + // Then split based on the segment length directly + let (contSegments, discSegments) = segments->Belt.Array.partition(s => E.A.length(s) < weight) + let expect = ( + contSegments->E.A.concatMany, + discSegments->E.A.fmap(a => (E.A.unsafe_get(a, 0), E.A.length(a)->Belt.Int.toFloat)), + ) + + makeTest("fast-check testing", result, expect) + true + } + + // rescript-fast-check's integerRange is broken, so we have to use nat plus a minimum + let testSegmentsCorrected = (counts, weight) => + testSegments(counts->E.A.fmap(c => 1 + c), weight + 2) + assert_( + property2( + Combinators.arrayWithLength(nat(~max=30, ()), 0, 50), + nat(~max=20, ()), + testSegmentsCorrected, + ), + ) }) diff --git a/packages/squiggle-lang/src/rescript/Utility/E/E_A.res b/packages/squiggle-lang/src/rescript/Utility/E/E_A.res index 6e04600cb3..2639838d47 100644 --- a/packages/squiggle-lang/src/rescript/Utility/E/E_A.res +++ b/packages/squiggle-lang/src/rescript/Utility/E/E_A.res @@ -254,15 +254,19 @@ module Floats = { /* This function goes through a sorted array and divides it into two different clusters: continuous samples and discrete samples. The discrete samples are stored in a mutable map. - Samples are thought to be discrete if they have at least `minDiscreteWight` duplicates. - - If the min discrete weight is 4, that would mean that at least four elements needed from a specific - value for that to be kept as discrete. This is important because in some cases, we can expect that - some common elements will be generated by regular operations. The final continuous array will be sorted. - - This function is performance-critical, don't change it significantly without benchmarking - SampleSet->PointSet conversion performance. + Samples are considered discrete if they have at least `minDiscreteWight` duplicates. + Using a `minDiscreteWight` higher than 2 is important because sometimes common elements + will be generated by regular operations. + The final continuous array will be sorted. + + The method here is designed for high performance for fairly small `minDiscreteWight` + values for both mostly-continuous and mostly-discrete inputs. + For each position i it visits, it compares it to the place where a run starting at i would end. + For continuous distributions, this comparison is always false, keeping branch prediction costs low. + If the comparison is true, it finds the complete run with recursive doubling then a binary search, + which skips over many elements for long runs. */ + exception BadWeight(string) let splitContinuousAndDiscreteForMinWeight = ( sortedArray: array, ~minDiscreteWeight: int, @@ -270,32 +274,56 @@ module Floats = { let continuous: array = [] let discrete = FloatFloatMap.empty() - let addData = (count: int, value: float): unit => { - if count >= minDiscreteWeight { - FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete) - } else { - for _ in 1 to count { - continuous->Js.Array2.push(value)->ignore - } - } + // Weight of 1 is pointless because it indicates only discrete values, + // and causes an infinite loop in the doubling search used here. + if minDiscreteWeight <= 1 { + raise(BadWeight("Minimum discrete weight must be at least 1")) } - let (finalCount, finalValue) = sortedArray->Belt.Array.reduce( - // initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything - (0, 0.), - ((count, prev), element) => { - if element == prev { - (count + 1, prev) - } else { - // new value, process previous ones - addData(count, prev) - (1, element) + // In a run of exactly minDiscreteWeight, the first and last + // element indices differ by minDistance. + let minDistance = minDiscreteWeight - 1 + + let len = length(sortedArray) + let i = ref(0) + while i.contents < len - minDistance { + // We are interested in runs of elements equal to value + let value = sortedArray[i.contents] + if value != sortedArray[i.contents + minDistance] { + // No long run starting at i, so it's continuous + Js.Array2.push(continuous, value)->ignore + i := i.contents + 1 + } else { + // Now we know that a run starts at i + // Move i forward to next unequal value + // That is, find iNext so that isEqualAt(iNext-1) and !isEqualAt(iNext) + let iOrig = i.contents + // Find base so that iNext is in (iOrig+base, iOrig+2*base] + // This is where we start the binary search + let base = ref(minDistance) + let isEqualAt = (ind: int) => ind < len && sortedArray[ind] == value + while isEqualAt(iOrig + base.contents * 2) { + base := base.contents * 2 + } + // Maintain iNext in (lo, i]. Once lo+1 == i, i is iNext. + let lo = ref(iOrig + base.contents) + i := Js.Math.min_int(lo.contents + base.contents, len) + while i.contents - lo.contents > 1 { + let mid = lo.contents + (i.contents - lo.contents) / 2 + if sortedArray[mid] == value { + lo := mid + } else { + i := mid + } } - }, - ) - // flush final values - addData(finalCount, finalValue) + let count = i.contents - iOrig + FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete) + } + } + // Remaining values are continuous + let tail = Belt.Array.sliceToEnd(sortedArray, i.contents) + Js.Array2.pushMany(continuous, tail)->ignore (continuous, discrete) }