Skip to content

Commit

Permalink
support for mask_function for aritary 2D prior exclusion region makin…
Browse files Browse the repository at this point in the history
…g 2D plots
  • Loading branch information
cmbant committed Jan 22, 2025
1 parent 032f06b commit 2a11d67
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 16 deletions.
50 changes: 48 additions & 2 deletions docs/plot_gallery.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion getdist/densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,19 @@ class Density2D(GridDensity):
You can call it like a :class:`~scipy:scipy.interpolate.RectBivariateSpline` object to get interpolated values.
"""

def __init__(self, x, y, P=None, view_ranges=None):
def __init__(self, x, y, P=None, view_ranges=None, mask=None):
"""
:param x: array of x values
:param y: array of y values
:param P: 2D array of density values at x, y
:param view_ranges: optional ranges for viewing density
:param mask: optional 2D boolean array for non-trivial mask
"""
self.x = x
self.y = y
self.axes = [y, x]
self.view_ranges = view_ranges
self.mask = mask
self.spacing = (self.x[1] - self.x[0]) * (self.y[1] - self.y[0])
self.setP(P)

Expand Down
26 changes: 17 additions & 9 deletions getdist/mcsamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,8 @@ def get2DDensity(self, x, y, normalized=False, **kwargs):
return density

# noinspection PyUnboundLocalVariable
def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False, meanlikes=False, **kwargs):
def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False, meanlikes=False,
mask_function: callable = None, **kwargs):
"""
Low-level function to get 2D plot marginalized density and optional additional plot data.
Expand All @@ -1665,6 +1666,9 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
:param get_density: only get the 2D marginalized density, don't calculate confidence level members
:param meanlikes: calculate mean likelihoods as well as marginalized density
(returned as array in density.likes)
:param mask_function: optional function, mask_function(minx, miny, stepx, stepy, mask),
which which sets mask to zero for values of parameters that are excluded by prior. Note this is not
needed for standard min, max bounds aligned with axes, as they are handled by default.
:param kwargs: optional settings to override instance settings of the same name (see `analysis_settings`):
- **fine_bins_2D**
Expand All @@ -1689,7 +1693,7 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
mult_bias_correction_order = kwargs.get('mult_bias_correction_order', self.mult_bias_correction_order)
smooth_scale_2D = float(kwargs.get('smooth_scale_2D', self.smooth_scale_2D))

has_prior = parx.has_limits or pary.has_limits
has_prior = parx.has_limits or pary.has_limits or mask_function

corr = self.getCorrelationMatrix()[j2][j]
actual_corr = corr
Expand Down Expand Up @@ -1761,7 +1765,7 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
logging.debug('time 2D binning and bandwidth: %s ; bins: %s', time.time() - start, fine_bins_2D)
start = time.time()
cache = {}
convolvesize = xsize + 2 * winw + Win.shape[0]
convolvesize = xsize + 2 * winw + Win.shape[0] # larger than needed for selecting fft pixel count
bins2D = convolve2D(histbins, Win, 'same', largest_size=convolvesize, cache=cache)

if meanlikes:
Expand All @@ -1782,12 +1786,18 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
if has_prior and boundary_correction_order >= 0:
# Correct for edge effects
prior_mask = np.ones((ysize + 2 * winw, xsize + 2 * winw))
if mask_function:
mask_function(xbinmin - winw * finewidthx, ybinmin - winw * finewidthy, finewidthx, finewidthy,
prior_mask)
self._setEdgeMask2D(parx, pary, prior_mask, winw)
a00 = convolve2D(prior_mask, Win, 'valid', largest_size=convolvesize, cache=cache)
ix = a00 * bins2D > np.max(bins2D) * 1e-8
a00 = a00[ix]
normed = bins2D[ix] / a00
if boundary_correction_order == 1:
if boundary_correction_order == 0 or mask_function:
# simple boundary correction by normalization
bins2D[ix] = normed
elif boundary_correction_order == 1:
# linear boundary correction
indexes = np.arange(-winw, winw + 1)
y = np.empty(Win.shape)
Expand All @@ -1811,13 +1821,10 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,
Ay = a01 * a20 - a10 * a11
corrected = (bins2D[ix] * A + xP * Ax + yP * Ay) / denom
bins2D[ix] = normed * np.exp(np.minimum(corrected / normed, 4) - 1)
elif boundary_correction_order == 0:
# simple boundary correction by normalization
bins2D[ix] = normed
else:
raise SettingError('unknown boundary_correction_order (expected 0 or 1)')

if mult_bias_correction_order:
if mult_bias_correction_order and not mask_function:
prior_mask = np.ones((ysize + 2 * winw, xsize + 2 * winw))
self._setEdgeMask2D(parx, pary, prior_mask, winw, alledge=True)
a00 = convolve2D(prior_mask, Win, 'valid', largest_size=convolvesize, cache=cache, cache_args=[2])
Expand All @@ -1830,7 +1837,8 @@ def get2DDensityGridData(self, j, j2, num_plot_contours=None, get_density=False,

x = np.linspace(xbinmin, xbinmax, xsize)
y = np.linspace(ybinmin, ybinmax, ysize)
density = Density2D(x, y, bins2D,
density = Density2D(x, y, bins2D, mask=(None if not mask_function
else np.asarray(prior_mask[winw:-winw, winw:-winw] < 1e-8)),
view_ranges=[(parx.range_min, parx.range_max), (pary.range_min, pary.range_max)])
density.normalize('max', in_place=True)
if get_density:
Expand Down
18 changes: 14 additions & 4 deletions getdist/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,8 @@ def _is_color_like(self, color):
return False

def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, cols=None, contour_levels=None,
add_legend_proxy=True, param_pair=None, density=None, alpha=None, ax=None, **kwargs):
add_legend_proxy=True, param_pair=None, density=None, alpha=None, ax=None,
mask_function: callable = None, **kwargs):
"""
Low-level function to add 2D contours to plot for samples with given root name and parameters
Expand All @@ -1043,6 +1044,8 @@ def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, col
:param alpha: alpha for the contours added
:param ax: optional :class:`~matplotlib:matplotlib.axes.Axes` instance (or y,x subplot coordinate)
to add to (defaults to current plot or the first/main plot if none)
:param mask_function: optional function, mask_function(minx, miny, stepx, stepy, mask),
which which sets mask to zero for values of parameter name parx, pary that are excluded by prior.
:param kwargs: optional keyword arguments:
- **filled**: True to make filled contours
Expand All @@ -1055,7 +1058,13 @@ def add_2d_contours(self, root, param1=None, param2=None, plotno=0, of=None, col
if density is None:
param1, param2 = self.get_param_array(root, param_pair or [param1, param2])
ax.getdist_params = (param1, param2)
if isinstance(root, MixtureND):
if mask_function is not None:
samples = self.samples_for_root(root)
density = samples.get2DDensityGridData(param1.name, param2.name,
mask_function = mask_function,
num_plot_contours=self.settings.num_plot_contours,
meanlikes=self.settings.shade_meanlikes)
elif isinstance(root, MixtureND):
density = root.marginalizedMixture(params=[param1, param2]).density2D()
else:
density = self.sample_analyser.get_density_grid(root, param1, param2,
Expand Down Expand Up @@ -1098,13 +1107,14 @@ def clean_args(_args):
else:
cols = color
levels = sorted(np.append([density.P.max() + 1], contour_levels))
cs = ax.contourf(density.x, density.y, density.P, levels, colors=cols, alpha=alpha, **clean_args(kwargs))
z = density.P if density.mask is None else np.ma.masked_where(density.mask, density.P)
cs = ax.contourf(density.x, density.y, z, levels, colors=cols, alpha=alpha, **clean_args(kwargs))

fc = tuple(cs.to_rgba(cs.cvalues[-1], cs.alpha))
if proxy_ix >= 0:
self.contours_added[proxy_ix] = (
matplotlib.patches.Rectangle((0, 0), 1, 1, fc=fc))
ax.contour(density.x, density.y, density.P, levels[:1], colors=(fc,),
ax.contour(density.x, density.y, z, levels[:1], colors=(fc,),
linewidths=self._scaled_linewidth(self.settings.linewidth_contour
if kwargs.get('lw') is None else kwargs['lw']),
linestyles=kwargs.get('ls'),
Expand Down

0 comments on commit 2a11d67

Please sign in to comment.