Faster continuous/discrete splitting #1286

Merged
merged 9 commits into from
Oct 27, 2022
83 changes: 52 additions & 31 deletions packages/squiggle-lang/src/rescript/Utility/E/E_A.res
@@ -307,48 +307,69 @@ module Floats = {
/*
This function goes through a sorted array and divides it into two different clusters:
continuous samples and discrete samples. The discrete samples are stored in a mutable map.
Samples are thought to be discrete if they have at least `minDiscreteWeight` duplicates.

If the min discrete weight is 4, that would mean that at least four elements needed from a specific
value for that to be kept as discrete. This is important because in some cases, we can expect that
some common elements will be generated by regular operations. The final continuous array will be sorted.

This function is performance-critical, don't change it significantly without benchmarking
SampleSet->PointSet conversion performance.
*/
Samples are considered discrete if they have at least `minDiscreteWeight` duplicates.
Using a `minDiscreteWeight` higher than 2 is important because sometimes common elements
will be generated by regular operations.
The final continuous array will be sorted.

The method here is designed for high performance for fairly small `minDiscreteWeight`
values for both mostly-continuous and mostly-discrete inputs.
For each position i it visits, it compares it to the place where a run starting at i would end.
For continuous distributions, this comparison is always false, keeping branch prediction costs low.
If the comparison is true, it finds the complete run with recursive doubling then a binary search,
which skips over many elements for long runs.
*/
let splitContinuousAndDiscreteForMinWeight = (
sortedArray: array<float>,
~minDiscreteWeight: int,
) => {
let continuous: array<float> = []
let discrete = FloatFloatMap.empty()

let addData = (count: int, value: float): unit => {
if count >= minDiscreteWeight {
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
// In a run of exactly minDiscreteWeight, the first and last
// element indices differ by minDistance.
let minDistance = minDiscreteWeight - 1

let len = length(sortedArray)
@OAGr (Contributor), Oct 21, 2022:

This is a fairly long function that took me a while to wrap my head around. I'd probably be conservative and have things spelled out more, especially at the top level. (Tiny names in tiny functions are fine.)

Some examples include:

  • len -> sortedArrayLen
  • i -> sortedArrayIndex (or sortedArrayI)
  • minDistance -> minDistanceOfSameValue
  • value -> indexValue

Contributor Author:

I can change it, but I find that style harder to read, as sortedArrayLen and sortedArrayIndex aren't as obviously different as i and len. There's only one array. Is it really that hard to guess what i indexes into?

Contributor:

There are several indexes (i, i0, base, j, lo, mid). This, at the very least, seems like pretty terse naming to me.

I haven't seen much use of base, lo before. It seems likely that some of this follows lower level programming conventions I'm not used to.

Contributor Author:

Yes, lo, mid, hi is very common for binary searches (I use i for hi, partly because declaring mutable variables in ReScript is so annoying). I use i0 here to mean a saved copy of i; maybe iOrig for i0 and iNext for j would be better. base is the base of the binary search.
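
For readers unfamiliar with the lo/mid/hi convention, the doubling step followed by a binary search can be sketched in isolation. This is a TypeScript illustration, not the ReScript from the diff: `findRunEnd` is a hypothetical helper name, and it uses a separate `hi` variable where the diff reuses `i`. It assumes a sorted array, `minDistance >= 1`, and that a run is already known to start at `start`.

```typescript
// Hypothetical helper (not part of the PR): returns the index just past the
// run of equal values that begins at `start`.
// Assumptions: `arr` is sorted, minDistance >= 1, and
// arr[start + minDistance] === arr[start] (the run test already passed).
function findRunEnd(arr: number[], start: number, minDistance: number): number {
  const value = arr[start];
  const isEqualAt = (ind: number): boolean =>
    ind < arr.length && arr[ind] === value;

  // Doubling phase: grow `base` while the element 2 * base ahead is still
  // equal. Afterwards the run end lies in (start + base, start + 2 * base].
  let base = minDistance;
  while (isEqualAt(start + base * 2)) {
    base *= 2;
  }

  // Binary search phase: arr[lo] is known to equal `value`; index `hi` is
  // known to hold a different value or to be past the end of the array.
  let lo = start + base;
  let hi = Math.min(lo + base, arr.length);
  while (hi - lo > 1) {
    const mid = lo + Math.floor((hi - lo) / 2);
    if (arr[mid] === value) {
      lo = mid;
    } else {
      hi = mid;
    }
  }
  return hi;
}

// Ten 7s starting at index 1; the run ends just before index 11.
const runDemo = [1, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 9];
const runEnd = findRunEnd(runDemo, 1, 2);
```

For a run of length n, the doubling loop and the binary search each take on the order of log n comparisons, which is how long runs get skipped without touching every element.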

let i = ref(0)
while i.contents < len - minDistance {
// We are interested in runs of elements equal to value
let value = sortedArray[i.contents]
if value != sortedArray[i.contents + minDistance] {
// No long run starting at i, so it's continuous
Js.Array2.push(continuous, value)->ignore
i := i.contents + 1
@OAGr (Contributor), Oct 21, 2022:

Tiny optimization, but I assume at this point you could push all of the same values, up to sortedArray[i.contents + minDistance], to continuous. Not sure if this would actually speed it up though.

Contributor:

Thinking about it more, this is probably not worth doing, it would add too much complexity.

Contributor Author:

You can't, since a run might start anywhere in the middle: if you have 5 6 6 6 6 6 and test with minDistance of 3, you see that 5 doesn't start a run but the next 6 actually does. It's possible to skip values if you use a smaller stride. I'll explain this as a top-level comment.
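
The 5 6 6 6 6 6 case can be checked with a small TypeScript sketch of the stride test. `startsRunAt` is a hypothetical name used only for this illustration; it mirrors the comparison of an element with the one `minDistance` positions later.

```typescript
// Hypothetical helper (not part of the PR): true when a run of at least
// minDistance + 1 equal values starts at index i of a sorted array.
function startsRunAt(sorted: number[], i: number, minDistance: number): boolean {
  return i + minDistance < sorted.length && sorted[i] === sorted[i + minDistance];
}

const strideSample = [5, 6, 6, 6, 6, 6];
// Evaluate the test at every index with minDistance = 3.
const strideFlags = strideSample.map((_, i) => startsRunAt(strideSample, i, 3));
```

The test fails at index 0 (5 versus 6) but succeeds at index 1, so a failed test at i says nothing about whether a run of a different value starts at i + 1.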

@OAGr (Contributor), Oct 21, 2022:

I said same value.

If you have 55588888, and line 334 catches, then you rule all of the next 5s out. So start a second loop and increment until you get to a value that's not a 5. Or do a binary search here if you think that minDistance could be really high.

Contributor:

One annoying thing is that I presume we need to accept some heuristic of what the data would look like. If values were very likely to be unique, conditional on not having minDiscreteWeight duplicates, then the current code is probably close to optimal.

Contributor Author:

Oh, got it. It's better to search for values not equal to the last one, rather than those that are equal to the first one. The test in the second loop would run at pretty much the same speed as the main one, and there's a hard-to-predict branch when it stops, so I don't think it would be an improvement with a forwards loop. If you run the loop backwards it's fairly similar to my minDiscreteWeight / 2 version, but splitting it up as minDiscreteWeight - 1 and 1 instead.

} else {
for _ in 1 to count {
continuous->Js.Array2.push(value)->ignore
// Now we know that a run starts at i
// Move i forward to next unequal value
// That is, find iNext so that isEqualAt(iNext-1) and !isEqualAt(iNext)
let iOrig = i.contents
// Find base so that iNext is in (iOrig+base, iOrig+2*base]
// This is where we start the binary search
let base = ref(minDistance)
let isEqualAt = (ind: int) => ind < len && sortedArray[ind] == value
while isEqualAt(iOrig + base.contents * 2) {
base := base.contents * 2
}
}
}

let (finalCount, finalValue) = sortedArray->Belt.Array.reduce(
// initial prev value doesn't matter; if it collides with the first element of the array, flush won't do anything
(0, 0.),
((count, prev), element) => {
if element == prev {
(count + 1, prev)
} else {
// new value, process previous ones
addData(count, prev)
(1, element)
// Maintain iNext in (lo, i]. Once lo+1 == i, i is iNext.
let lo = ref(iOrig + base.contents)
i := Js.Math.min_int(lo.contents + base.contents, len)
while i.contents - lo.contents > 1 {
let mid = lo.contents + (i.contents - lo.contents) / 2
if sortedArray[mid] == value {
lo := mid
} else {
i := mid
}
}
},
)

// flush final values
addData(finalCount, finalValue)
let count = i.contents - iOrig
FloatFloatMap.add(value, count->Belt.Int.toFloat, discrete)
}
}
// Remaining values are continuous
let tail = Belt.Array.sliceToEnd(sortedArray, i.contents)
Js.Array2.pushMany(continuous, tail)->ignore

(continuous, discrete)
}
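
When benchmarking or modifying a function like this, a plain linear scan is a convenient reference to compare results against. The TypeScript sketch below is an illustration only: `splitByLinearScan` is a hypothetical name, a standard `Map` stands in for `FloatFloatMap`, and the logic follows the idea of the removed run-counting code (count each run, then classify it) rather than reproducing it line for line.

```typescript
// Hypothetical reference implementation (not part of the PR).
// Assumption: `sorted` is sorted ascending, so equal values are adjacent.
function splitByLinearScan(
  sorted: number[],
  minDiscreteWeight: number,
): [number[], Map<number, number>] {
  const continuous: number[] = [];
  const discrete = new Map<number, number>();
  let start = 0;
  while (start < sorted.length) {
    // Walk to the end of the run of values equal to sorted[start].
    let end = start + 1;
    while (end < sorted.length && sorted[end] === sorted[start]) {
      end += 1;
    }
    const count = end - start;
    if (count >= minDiscreteWeight) {
      // Long enough run: record the value with its count.
      discrete.set(sorted[start], count);
    } else {
      // Short run: every copy stays in the continuous part.
      for (let k = 0; k < count; k++) {
        continuous.push(sorted[start]);
      }
    }
    start = end;
  }
  return [continuous, discrete];
}

const [refContinuous, refDiscrete] = splitByLinearScan(
  [1, 2, 2, 2, 2, 3, 4, 4, 5],
  3,
);
```

With a minimum weight of 3, the four 2s become one discrete entry while the pair of 4s stays continuous. A faster implementation should produce the same pair of outputs on any sorted input.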