Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Point stat choice
Browse files Browse the repository at this point in the history
- Added method to calculate the median of an array in JavaScript.
- Updated the Marginal 1D tool to use the new method.
- Updated the Marginal 1D interfaces to include a new widget component.
- Updated the Python side of the tool to render the new widget.
- Updated Python `TypedDict` objects to include `docstrings` for
  inclusion in the documentation in the future.
- Updated the coin flipping tutorial to use the new point statistic
  button.

Resolves #1817
  • Loading branch information
ndmlny-qs committed Nov 7, 2022
1 parent ba5b8ee commit 4dc4905
Show file tree
Hide file tree
Showing 7 changed files with 505 additions and 205 deletions.
107 changes: 77 additions & 30 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
*/

import {Axis} from '@bokehjs/models/axes/axis';
import {cumulativeSum} from '../stats/array';
import {arrayMean, arrayMedian, cumulativeSum} from '../stats/array';
import {scaleToOne} from '../stats/dataTransformation';
import {
interval as hdiInterval,
data as hdiData,
} from '../stats/highestDensityInterval';
import {oneD} from '../stats/marginal';
import {mean as computeMean} from '../stats/pointStatistic';
import {interpolatePoints} from '../stats/utils';
import * as interfaces from './interfaces';

Expand Down Expand Up @@ -46,6 +45,8 @@ export const updateAxisLabel = (axis: Axis, label: string | null): void => {
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the
* random variable.
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable.
* @param {number} activeStatistic - The statistic to show in the tool. 0 is the mean
* and 1 is the median.
* @param {number | null} [hdiProb=null] - The highest density interval probability
* value. If the default value is not overwritten, then the default HDI probability
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this
Expand All @@ -62,6 +63,7 @@ export const computeStats = (
rawData: number[],
marginalX: number[],
marginalY: number[],
activeStatistic: number,
hdiProb: number | null = null,
text_align: string[] = ['right', 'center', 'left'],
x_offset: number[] = [-5, 0, 5],
Expand All @@ -72,24 +74,44 @@ export const computeStats = (

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(rawData);
const mean = arrayMean(rawData);
const median = arrayMedian(rawData);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
const text = [
let x = [hdiBounds.lowerBound, mean, median, hdiBounds.upperBound];
let y = interpolatePoints({x: marginalX, y: marginalY, points: x});
let text = [
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`,
`Mean: ${mean.toFixed(3)}`,
`Median: ${median.toFixed(3)}`,
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`,
];

return {
// We will filter the output based on the active statistic from the tool.
let mask: number[] = [];
if (activeStatistic === 0) {
mask = [0, 1, 3];
} else if (activeStatistic === 1) {
mask = [0, 2, 3];
}
x = mask.map((i) => {
return x[i];
});
y = mask.map((i) => {
return y[i];
});
text = mask.map((i) => {
return text[i];
});

const output = {
x: x,
y: y,
text: text,
text_align: text_align,
x_offset: x_offset,
y_offset: y_offset,
};
return output;
};

/**
Expand All @@ -100,6 +122,8 @@ export const computeStats = (
* calculating the Kernel Density Estimate (KDE).
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @param {number} activeStatistic - The statistic to show in the tool. 0 is the mean
* and 1 is the median.
* @returns {interfaces.Data} The marginal distribution and cumulative
* distribution calculated from the given random variable data. Point statistics are
* also calculated.
Expand All @@ -108,6 +132,7 @@ export const computeData = (
data: number[],
bwFactor: number,
hdiProbability: number,
activeStatistic: number,
): interfaces.Data => {
const output = {} as interfaces.Data;
for (let i = 0; i < figureNames.length; i += 1) {
Expand All @@ -125,7 +150,13 @@ export const computeData = (
}

// Compute the point statistics for the given data.
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability);
const stats = computeStats(
data,
distribution.x,
distribution.y,
activeStatistic,
hdiProbability,
);

output[figureName] = {
distribution: distribution,
Expand All @@ -150,6 +181,7 @@ export const computeData = (
* application.
* @param {interfaces.Figures} figures - Bokeh figures shown in the application.
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs.
* @param {interfaces.Widgets} widgets - Bokeh widget object for the tool.
* @returns {number} We display the value of the bandwidth used for computing the Kernel
* Density Estimate in a div, and must return that value here in order to update the
* value displayed to the user.
Expand All @@ -162,29 +194,44 @@ export const update = (
sources: interfaces.Sources,
figures: interfaces.Figures,
tooltips: interfaces.Tooltips,
widgets: interfaces.Widgets,
): number => {
const computedData = computeData(data, bwFactor, hdiProbability);
for (let i = 0; i < figureNames.length; i += 1) {
// Update all sources with new data calculated above.
const figureName = figureNames[i];
sources[figureName].distribution.data = {
x: computedData[figureName].distribution.x,
y: computedData[figureName].distribution.y,
};
sources[figureName].hdi.data = {
base: computedData[figureName].hdi.base,
lower: computedData[figureName].hdi.lower,
upper: computedData[figureName].hdi.upper,
};
sources[figureName].stats.data = computedData[figureName].stats;
sources[figureName].labels.data = computedData[figureName].labels;
const activeStatistic = widgets.stats_button.active as number;
const computedData = computeData(data, bwFactor, hdiProbability, activeStatistic);

// Update the axes labels.
updateAxisLabel(figures[figureName].below[0], rvName);
// Marginal figure.
// eslint-disable-next-line prefer-destructuring
const bandwidth = computedData.marginal.distribution.bandwidth;
sources.marginal.distribution.data = {
x: computedData.marginal.distribution.x,
y: computedData.marginal.distribution.y,
};
sources.marginal.hdi.data = {
base: computedData.marginal.hdi.base,
lower: computedData.marginal.hdi.lower,
upper: computedData.marginal.hdi.upper,
};
sources.marginal.stats.data = computedData.marginal.stats;
sources.marginal.labels.data = computedData.marginal.labels;
tooltips.marginal.distribution.tooltips = [[rvName, '@x']];
tooltips.marginal.stats.tooltips = [['', '@text']];
updateAxisLabel(figures.marginal.below[0] as Axis, rvName);

// Update the tooltips.
tooltips[figureName].stats.tooltips = [['', '@text']];
tooltips[figureName].distribution.tooltips = [[rvName, '@x']];
}
return computedData.marginal.distribution.bandwidth;
// Cumulative figure.
sources.cumulative.distribution.data = {
x: computedData.cumulative.distribution.x,
y: computedData.cumulative.distribution.y,
};
sources.cumulative.hdi.data = {
base: computedData.cumulative.hdi.base,
lower: computedData.cumulative.hdi.lower,
upper: computedData.cumulative.hdi.upper,
};
sources.cumulative.stats.data = computedData.cumulative.stats;
sources.cumulative.labels.data = computedData.cumulative.labels;
tooltips.cumulative.distribution.tooltips = [[rvName, '@x']];
tooltips.cumulative.stats.tooltips = [['', '@text']];
updateAxisLabel(figures.cumulative.below[0] as Axis, rvName);

return bandwidth;
};
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import {Plot} from '@bokehjs/models/plots/plot';
import {ColumnDataSource} from '@bokehjs/models/sources/column_data_source';
import {HoverTool} from '@bokehjs/models/tools/inspectors/hover_tool';
import {Div} from '@bokehjs/models/widgets/div';
import {RadioButtonGroup} from '@bokehjs/models/widgets/radio_button_group';
import {Select} from '@bokehjs/models/widgets/selectbox';
import {Slider} from '@bokehjs/models/widgets/slider';

// NOTE: In the corresponding Python typing files for the diagnostic tool, we define
// similar types using a TypedDict object. TypeScript allows us to maintain
Expand Down Expand Up @@ -95,3 +99,11 @@ export interface Tooltips {
marginal: Tooltip;
cumulative: Tooltip;
}

/**
 * Bokeh widget models that make up the controls of the Marginal 1D tool. The
 * keys mirror the widget names used on the Python side of the tool.
 */
export interface Widgets {
// Select widget. NOTE(review): presumably picks the random variable to display —
// confirm against the Python tool.
rv_select: Select;
// Slider widget. NOTE(review): presumably the bandwidth factor used for the Kernel
// Density Estimate — confirm against the Python tool.
bw_factor_slider: Slider;
// Div widget. NOTE(review): presumably shows the computed bandwidth value — confirm.
bw_div: Div;
// Slider widget. NOTE(review): presumably the highest density interval probability —
// confirm against the Python tool.
hdi_slider: Slider;
// Radio button group whose `active` index selects the point statistic to show in
// the tool: 0 is the mean and 1 is the median.
stats_button: RadioButtonGroup;
}
43 changes: 43 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,29 @@
* LICENSE file in the root directory of this source tree.
*/

/**
 * Sum an array of numbers.
 *
 * The reduction is seeded with 0 so that an empty array sums to 0. Without a seed,
 * `Array.prototype.reduce` throws a `TypeError` when called on an empty array.
 *
 * @param {number[]} data - The array of data.
 * @returns {number} The sum of the array of data, or 0 when the array is empty.
 */
export const arraySum = (data: number[]): number => {
  return data.reduce((runningTotal, value) => runningTotal + value, 0);
};

/**
 * Arithmetic mean of an array of numbers: the total returned by `arraySum` divided
 * by the number of elements.
 *
 * @param {number[]} data - The array of data.
 * @returns {number} The mean of the given data.
 */
export const arrayMean = (data: number[]): number => arraySum(data) / data.length;

/**
* Cumulative sum of the given data.
*
Expand Down Expand Up @@ -128,3 +151,23 @@ export const valueCounts = (data: number[]): {[key: string]: number} => {
}
return counts;
};

/**
 * Median of an array of numbers.
 *
 * The data is first passed through `numericalSort` (presumably an ascending numeric
 * sort — confirm against its definition). For an odd number of elements the middle
 * element is returned; for an even number the two middle elements are averaged.
 *
 * @param {number[]} data - Numerical array of data.
 * @returns {number} The median value of the given data.
 */
export const arrayMedian = (data: number[]): number => {
  const ordered = numericalSort(data);
  const count = ordered.length;
  // For an even count this is exactly count / 2, i.e. the upper of the two middle
  // elements; for an odd count it is the single middle element.
  const middle = Math.floor(count / 2);
  if (count % 2 !== 0) {
    return ordered[middle];
  }
  return (ordered[middle - 1] + ordered[middle]) / 2;
};
17 changes: 14 additions & 3 deletions src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ class Marginal1d(DiagnosticToolBaseClass):
Attributes:
data (Dict[str, List[List[float]]]): JSON serializable representation of the
given `mcs` object.
given ``mcs`` object.
rv_names (List[str]): The list of random variables string names for the given
model.
num_chains (int): The number of chains of the model.
num_draws (int): The number of draws of the model for each chain.
palette (List[str]): A list of color values used for the glyphs in the figures.
The colors are specifically chosen from the Colorblind palette defined in
Bokeh.
The colors are specifically chosen from the ``Colorblind`` palette defined
in Bokeh.
tool_js (str): The JavaScript callbacks needed to render the Bokeh tool
independently from a Python server.
"""
Expand All @@ -40,6 +40,12 @@ def __init__(self: Marginal1d, mcs: MonteCarloSamples) -> None:
super(Marginal1d, self).__init__(mcs)

def create_document(self: Marginal1d) -> Model:
"""
Create the Bokeh document for the diagnostic tool.
Returns:
Model: A Bokeh Model object.
"""
# Initialize widget values using Python.
rv_name = self.rv_names[0]
bw_factor = 1.0
Expand Down Expand Up @@ -110,6 +116,7 @@ def create_document(self: Marginal1d) -> Model:
sources,
figures,
tooltips,
widgets,
);
}} catch (error) {{
{self.tool_js}
Expand All @@ -121,6 +128,7 @@ def create_document(self: Marginal1d) -> Model:
sources,
figures,
tooltips,
widgets,
);
}}
"""
Expand All @@ -135,6 +143,7 @@ def create_document(self: Marginal1d) -> Model:
"figures": figures,
"tooltips": tooltips,
"toolView": tool_view,
"widgets": widgets,
}

# Each widget requires slightly different JS, except for the sliders.
Expand All @@ -155,10 +164,12 @@ def create_document(self: Marginal1d) -> Model:
"""
rv_select_callback = CustomJS(args=callback_arguments, code=rv_select_js)
slider_callback = CustomJS(args=callback_arguments, code=slider_js)
button_callback = CustomJS(args=callback_arguments, code=slider_js)

# Tell Python to use the JavaScript.
widgets["rv_select"].js_on_change("value", rv_select_callback)
widgets["bw_factor_slider"].js_on_change("value", slider_callback)
widgets["hdi_slider"].js_on_change("value", slider_callback)
widgets["stats_button"].js_on_change("active", button_callback)

return tool_view
Loading

0 comments on commit 4dc4905

Please sign in to comment.