Skip to content

Commit

Permalink
Swap argument order for most point-set distribution code
Browse files Browse the repository at this point in the history
  • Loading branch information
mlochbaum committed Oct 24, 2022
1 parent 1e9cec6 commit 7e360b2
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ describe("klDivergence: continuous -> continuous -> float", () => {
let stdev1 = 4.0
let stdev2 = 1.0

let prediction =
normalMakeR(mean1, stdev1)->E.R.errMap(s => DistributionTypes.ArgumentError(s))
let prediction = normalMakeR(mean1, stdev1)->E.R.errMap(s => DistributionTypes.ArgumentError(s))
let answer = normalMakeR(mean2, stdev2)->E.R.errMap(s => DistributionTypes.ArgumentError(s))
// https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians
let analyticalKl =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ let pointwiseCombinationFloat = (
PointSetDist.T.mapYResult(
~integralSumCacheFn=integralSumCacheFn(f),
~integralCacheFn=integralCacheFn(f),
~fn=fn(f),
t,
fn(f),
)->E.R.errMap(x => DistributionTypes.OperationError(x))
})
let m = switch algebraicCombination {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ let toLinear = (t: t): option<t> =>
xyShape->XYShape.Range.stepsToContinuous->E.O.fmap(make(~integralSumCache, ~integralCache))
| {interpolation: #Linear} => Some(t)
}
let shapeFn = (fn, t: t) => t->getShape->fn
let shapeFn = (t: t, fn) => t->getShape->fn

let updateIntegralSumCache = (integralSumCache, t: t): t => {
let updateIntegralSumCache = (t: t, integralSumCache): t => {
...t,
integralSumCache,
}

let updateIntegralCache = (integralCache, t: t): t => {...t, integralCache}
let updateIntegralCache = (t: t, integralCache): t => {...t, integralCache}

let sum = (
~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
Expand All @@ -151,9 +151,9 @@ let sum = (
)

let reduce = (
continuousShapes,
~integralSumCachesFn: (float, float) => option<float>=(_, _) => None,
fn: (float, float) => result<float, 'e>,
continuousShapes,
): result<t, 'e> => {
let merge = combinePointwise(~integralSumCachesFn, fn)
continuousShapes->E.A.R.foldM(merge, empty, _)
Expand All @@ -162,10 +162,12 @@ let reduce = (
let mapYResult = (
~integralSumCacheFn=_ => None,
~integralCacheFn=_ => None,
~fn: float => result<float, 'e>,
t: t,
fn: float => result<float, 'e>,
): result<t, 'e> =>
XYShape.T.mapYResult(fn, getShape(t))->E.R.fmap(x =>
getShape(t)
->XYShape.T.mapYResult(fn)
->E.R.fmap(x =>
make(
~interpolation=t.interpolation,
~integralSumCache=t.integralSumCache->E.O.bind(integralSumCacheFn),
Expand All @@ -177,47 +179,47 @@ let mapYResult = (
let mapY = (
~integralSumCacheFn=_ => None,
~integralCacheFn=_ => None,
~fn: float => float,
t: t,
fn: float => float,
): t =>
make(
~interpolation=t.interpolation,
~integralSumCache=t.integralSumCache->E.O.bind(integralSumCacheFn),
~integralCache=t.integralCache->E.O.bind(integralCacheFn),
t->getShape->XYShape.T.mapY(fn, _),
t->getShape->XYShape.T.mapY(fn),
)

let rec scaleBy = (~scale=1.0, t: t): t => {
let rec scaleBy = (t: t, scale): t => {
let scaledIntegralSumCache = E.O.bind(t.integralSumCache, v => Some(scale *. v))
let scaledIntegralCache = E.O.bind(t.integralCache, v => Some(scaleBy(~scale, v)))
let scaledIntegralCache = E.O.bind(t.integralCache, v => Some(scaleBy(v, scale)))

t
->mapY(~fn=(r: float) => r *. scale, _)
->updateIntegralSumCache(scaledIntegralSumCache, _)
->updateIntegralCache(scaledIntegralCache, _)
->mapY((r: float) => r *. scale)
->updateIntegralSumCache(scaledIntegralSumCache)
->updateIntegralCache(scaledIntegralCache)
}

module T = Dist({
type t = PointSetTypes.continuousShape
type integral = PointSetTypes.continuousShape
let minX = shapeFn(XYShape.T.minX)
let maxX = shapeFn(XYShape.T.maxX)
let minX = shapeFn(_, XYShape.T.minX)
let maxX = shapeFn(_, XYShape.T.maxX)
let mapY = mapY
let mapYResult = mapYResult
let updateIntegralCache = updateIntegralCache
let toDiscreteProbabilityMassFraction = _ => 0.0
let toPointSetDist = (t: t): PointSetTypes.pointSetDist => Continuous(t)
let xToY = (f, {interpolation, xyShape}: t) =>
switch interpolation {
| #Stepwise => xyShape->XYShape.XtoY.stepwiseIncremental(f, _)->E.O.default(0.0)
| #Linear => xyShape->XYShape.XtoY.linear(f, _)
| #Stepwise => xyShape->XYShape.XtoY.stepwiseIncremental(f)->E.O.default(0.0)
| #Linear => xyShape->XYShape.XtoY.linear(f)
}->PointSetTypes.MixedPoint.makeContinuous

let truncate = (leftCutoff: option<float>, rightCutoff: option<float>, t: t) => {
let lc = E.O.default(leftCutoff, neg_infinity)
let rc = E.O.default(rightCutoff, infinity)
let truncatedZippedPairs =
t->getShape->XYShape.T.zip->XYShape.Zipped.filterByX(x => x >= lc && x <= rc, _)
t->getShape->XYShape.T.zip->XYShape.Zipped.filterByX(x => x >= lc && x <= rc)

let leftNewPoint = leftCutoff->E.O.dimap(lc => [(lc -. epsilon_float, 0.)], _ => [])
let rightNewPoint = rightCutoff->E.O.dimap(rc => [(rc +. epsilon_float, 0.)], _ => [])
Expand Down Expand Up @@ -246,18 +248,18 @@ module T = Dist({
}

let downsample = (length, t): t =>
t->shapeMap(XYShape.XsConversion.proportionByProbabilityMass(length, integral(t).xyShape))
t->shapeMap(XYShape.XsConversion.proportionByProbabilityMass(_, length, integral(t).xyShape))
let integralEndY = (t: t) => t.integralSumCache->E.O.defaultFn(() => t->integral->lastY)
let integralXtoY = (f, t: t) => t->integral->shapeFn(XYShape.XtoY.linear(f), _)
let integralYtoX = (f, t: t) => t->integral->shapeFn(XYShape.YtoX.linear(f), _)
let integralXtoY = (f, t: t) => t->integral->shapeFn(XYShape.XtoY.linear(_, f))
let integralYtoX = (f, t: t) => t->integral->shapeFn(XYShape.YtoX.linear(_, f))
let toContinuous = t => Some(t)
let toDiscrete = _ => None

let normalize = (t: t): t =>
t
->updateIntegralCache(Some(integral(t)), _)
->scaleBy(~scale=1. /. integralEndY(t), _)
->updateIntegralSumCache(Some(1.0), _)
->updateIntegralCache(Some(integral(t)))
->scaleBy(1. /. integralEndY(t))
->updateIntegralSumCache(Some(1.0))

let mean = (t: t) => {
let indefiniteIntegralStepwise = (p, h1) => h1 *. p ** 2.0 /. 2.0
Expand All @@ -270,13 +272,13 @@ module T = Dist({
})

let isNormalized = (t: t): bool => {
let areaUnderIntegral = t->updateIntegralCache(Some(T.integral(t)), _)->T.integralEndY
let areaUnderIntegral = t->updateIntegralCache(Some(T.integral(t)))->T.integralEndY
areaUnderIntegral < 1. +. MagicNumbers.Epsilon.seven &&
areaUnderIntegral > 1. -. MagicNumbers.Epsilon.seven
}

let downsampleEquallyOverX = (length, t): t =>
t->shapeMap(XYShape.XsConversion.proportionEquallyOverX(length))
t->shapeMap(XYShape.XsConversion.proportionEquallyOverX(_, length))

/* This simply creates multiple copies of the continuous distribution, scaled and shifted according to
each discrete data point, and then adds them all together. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ let empty: PointSetTypes.discreteShape = {
integralCache: Some(emptyIntegral),
}

let shapeFn = (fn, t: t) => t->getShape->fn
let shapeFn = (t: t, fn) => t->getShape->fn

let lastY = (t: t) => t->getShape->XYShape.T.lastY

Expand All @@ -53,20 +53,20 @@ let combinePointwise = (
}

let reduce = (
discreteShapes: array<PointSetTypes.discreteShape>,
~integralSumCachesFn=(_, _) => None,
fn: (float, float) => result<float, 'e>,
discreteShapes: array<PointSetTypes.discreteShape>,
): result<t, 'e> => {
let merge = combinePointwise(~integralSumCachesFn, ~fn)
discreteShapes->E.A.R.foldM(merge, empty, _)
}

let updateIntegralSumCache = (integralSumCache, t: t): t => {
let updateIntegralSumCache = (t: t, integralSumCache): t => {
...t,
integralSumCache,
}

let updateIntegralCache = (integralCache, t: t): t => {
let updateIntegralCache = (t: t, integralCache): t => {
...t,
integralCache,
}
Expand Down Expand Up @@ -107,10 +107,12 @@ let combineAlgebraically = (op: Operation.convolutionOperation, t1: t, t2: t): t
let mapYResult = (
~integralSumCacheFn=_ => None,
~integralCacheFn=_ => None,
~fn: float => result<float, 'e>,
t: t,
fn: float => result<float, 'e>,
): result<t, 'e> =>
XYShape.T.mapYResult(fn, getShape(t))->E.R.fmap(x =>
getShape(t)
->XYShape.T.mapYResult(fn)
->E.R.fmap(x =>
make(
~integralSumCache=t.integralSumCache->E.O.bind(integralSumCacheFn),
~integralCache=t.integralCache->E.O.bind(integralCacheFn),
Expand All @@ -121,23 +123,23 @@ let mapYResult = (
let mapY = (
~integralSumCacheFn=_ => None,
~integralCacheFn=_ => None,
~fn: float => float,
t: t,
fn: float => float,
): t =>
make(
~integralSumCache=t.integralSumCache->E.O.bind(integralSumCacheFn),
~integralCache=t.integralCache->E.O.bind(integralCacheFn),
t->getShape->XYShape.T.mapY(fn, _),
t->getShape->XYShape.T.mapY(fn),
)

let scaleBy = (~scale=1.0, t: t): t => {
let scaleBy = (t: t, scale): t => {
let scaledIntegralSumCache = t.integralSumCache->E.O.fmap(\"*."(scale))
let scaledIntegralCache = t.integralCache->E.O.fmap(Continuous.scaleBy(~scale))
let scaledIntegralCache = t.integralCache->E.O.fmap(Continuous.scaleBy(_, scale))

t
->mapY(~fn=(r: float) => r *. scale, _)
->updateIntegralSumCache(scaledIntegralSumCache, _)
->updateIntegralCache(scaledIntegralCache, _)
->mapY((r: float) => r *. scale)
->updateIntegralSumCache(scaledIntegralSumCache)
->updateIntegralCache(scaledIntegralCache)
}

module T = Dist({
Expand All @@ -152,16 +154,15 @@ module T = Dist({
// The first xy of this integral should always be the zero, to ensure nice plotting
let firstX = ts->XYShape.T.minX
let prependedZeroPoint: XYShape.T.t = {xs: [firstX -. epsilon_float], ys: [0.]}
let integralShape =
ts->XYShape.T.concat(prependedZeroPoint, _)->XYShape.T.accumulateYs(\"+.", _)
let integralShape = ts->XYShape.T.concat(prependedZeroPoint, _)->XYShape.T.accumulateYs(\"+.")

Continuous.make(~interpolation=#Stepwise, integralShape)
}

let integralEndY = (t: t) =>
t.integralSumCache->E.O.defaultFn(() => t->integral->Continuous.lastY)
let minX = shapeFn(XYShape.T.minX)
let maxX = shapeFn(XYShape.T.maxX)
let minX = shapeFn(_, XYShape.T.minX)
let maxX = shapeFn(_, XYShape.T.maxX)
let toDiscreteProbabilityMassFraction = _ => 1.0
let mapY = mapY
let mapYResult = mapYResult
Expand All @@ -170,8 +171,7 @@ module T = Dist({
let toContinuous = _ => None
let toDiscrete = t => Some(t)

let normalize = (t: t): t =>
t->scaleBy(~scale=1. /. integralEndY(t), _)->updateIntegralSumCache(Some(1.0), _)
let normalize = (t: t): t => t->scaleBy(1. /. integralEndY(t))->updateIntegralSumCache(Some(1.0))

let downsample = (i, t: t): t => {
// It's not clear how to downsample a set of discrete points in a meaningful way.
Expand All @@ -197,23 +197,21 @@ module T = Dist({
t
->getShape
->XYShape.T.zip
->XYShape.Zipped.filterByX(
x => x >= E.O.default(leftCutoff, neg_infinity) && x <= E.O.default(rightCutoff, infinity),
_,
->XYShape.Zipped.filterByX(x =>
x >= E.O.default(leftCutoff, neg_infinity) && x <= E.O.default(rightCutoff, infinity)
)
->XYShape.T.fromZippedArray
->make

let xToY = (f, t) =>
t
->getShape
->XYShape.XtoY.stepwiseIfAtX(f, _)
->XYShape.XtoY.stepwiseIfAtX(f)
->E.O.default(0.0)
->PointSetTypes.MixedPoint.makeDiscrete

let integralXtoY = (f, t) => t->integral->Continuous.getShape->XYShape.XtoY.linear(f, _)

let integralYtoX = (f, t) => t->integral->Continuous.getShape->XYShape.YtoX.linear(f, _)
let integralXtoY = (f, t) => t->integral->Continuous.getShape->XYShape.XtoY.linear(f)
let integralYtoX = (f, t) => t->integral->Continuous.getShape->XYShape.YtoX.linear(f)

let mean = (t: t): float => {
let s = getShape(t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ module type dist = {
let mapY: (
~integralSumCacheFn: float => option<float>=?,
~integralCacheFn: PointSetTypes.continuousShape => option<PointSetTypes.continuousShape>=?,
~fn: float => float,
t,
float => float,
) => t
let mapYResult: (
~integralSumCacheFn: float => option<float>=?,
~integralCacheFn: PointSetTypes.continuousShape => option<PointSetTypes.continuousShape>=?,
~fn: float => result<float, 'e>,
t,
float => result<float, 'e>,
) => result<t, 'e>
let xToY: (float, t) => PointSetTypes.mixedPoint
let toPointSetDist: t => PointSetTypes.pointSetDist
Expand All @@ -24,7 +24,7 @@ module type dist = {
let downsample: (int, t) => t
let truncate: (option<float>, option<float>, t) => t

let updateIntegralCache: (option<PointSetTypes.continuousShape>, t) => t
let updateIntegralCache: (t, option<PointSetTypes.continuousShape>) => t

let integral: t => integral
let integralEndY: t => float
Expand Down
Loading

0 comments on commit 7e360b2

Please sign in to comment.