From 403a6bdd9495e1d4d467308534712722d227ff4f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Dec 2024 14:57:37 +0100 Subject: [PATCH] Add Beta-Binomial conjugacy optimization --- .../model/marginal/distributions.py | 14 +- pymc_experimental/sampling/__init__.py | 1 + pymc_experimental/sampling/mcmc.py | 12 +- .../sampling/optimizations/conjugacy.py | 156 ++++++++++++++++++ .../optimizations/conjugate_sampler.py | 106 ++++++++++++ pymc_experimental/utils/ofg.py | 16 ++ pyproject.toml | 2 +- tests/sampling/mcmc/test_mcmc.py | 35 +++- .../sampling/optimizations/test_conjugacy.py | 46 ++++++ 9 files changed, 370 insertions(+), 18 deletions(-) create mode 100644 pymc_experimental/sampling/optimizations/conjugacy.py create mode 100644 pymc_experimental/sampling/optimizations/conjugate_sampler.py create mode 100644 pymc_experimental/utils/ofg.py create mode 100644 tests/sampling/optimizations/test_conjugacy.py diff --git a/pymc_experimental/model/marginal/distributions.py b/pymc_experimental/model/marginal/distributions.py index 661665e9..287e9065 100644 --- a/pymc_experimental/model/marginal/distributions.py +++ b/pymc_experimental/model/marginal/distributions.py @@ -7,7 +7,6 @@ from pymc.logprob.abstract import MeasurableOp, _logprob from pymc.logprob.basic import conditional_logp, logp from pymc.pytensorf import constant_fold -from pytensor import Variable from pytensor.compile.builders import OpFromGraph from pytensor.compile.mode import Mode from pytensor.graph import Op, vectorize_graph @@ -17,6 +16,7 @@ from pytensor.tensor import TensorVariable from pymc_experimental.distributions import DiscreteMarkovChain +from pymc_experimental.utils.ofg import inline_ofg_outputs class MarginalRV(OpFromGraph, MeasurableOp): @@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens return logp.transpose(*dims_alignment) -def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: - """Inline the inner graph (outputs) of an OpFromGraph Op. - - Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" - the inner graph. - """ - return clone_replace( - op.inner_outputs, - replace=tuple(zip(op.inner_inputs, inputs)), - ) - - DUMMY_ZERO = pt.constant(0, name="dummy_zero") diff --git a/pymc_experimental/sampling/__init__.py b/pymc_experimental/sampling/__init__.py index fb5b6980..a4384c03 100644 --- a/pymc_experimental/sampling/__init__.py +++ b/pymc_experimental/sampling/__init__.py @@ -1,2 +1,3 @@ # Add rewrites to the optimization DBs +import pymc_experimental.sampling.optimizations.conjugacy import pymc_experimental.sampling.optimizations.summary_stats \ No newline at end of file diff --git a/pymc_experimental/sampling/mcmc.py b/pymc_experimental/sampling/mcmc.py index 62b8d895..45c4e504 100644 --- a/pymc_experimental/sampling/mcmc.py +++ b/pymc_experimental/sampling/mcmc.py @@ -20,6 +20,9 @@ def opt_sample( ): """Sample from a model after applying optimizations. + .. warning:: There is no guarantee that the optimizations will improve the sampling performance. For instance, conjugacy optimizations can lead to less efficient sampling for the remaining variables (if any), due to imposing a Gibbs sampling scheme. + + Parameters ---------- model : Model (optional) @@ -43,17 +46,20 @@ def opt_sample( import pymc_experimental as pmx with pm.Model() as m: - p = pm.Beta("p", 1, 1) - y = pm.Binomial("y", n=10, p=p, observed=5) + p = pm.Beta("p", 1, 1, shape=(1000,)) + y = pm.Binomial("y", n=100, p=p, observed=[1, 50, 99, 50]*250) idata = pmx.opt_sample(verbose=True) + + # Applied optimization: beta_binomial_conjugacy 1x + # ConjugateRVSampler: [p] """ model = modelcontext(model) fgraph, _ = fgraph_from_model(model) if rewriter is None: - rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"])) + rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats", "conjugacy"])) _, _, rewrite_counters, *_ = rewriter.rewrite(fgraph) if verbose: diff --git a/pymc_experimental/sampling/optimizations/conjugacy.py b/pymc_experimental/sampling/optimizations/conjugacy.py new file mode 100644 index 00000000..16b5fd4e --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugacy.py @@ -0,0 +1,156 @@ +from typing import Sequence + +from pymc import STEP_METHODS +from pytensor.tensor.random.type import RandomGeneratorType + +from pytensor.compile.builders import OpFromGraph + +from pymc_experimental.sampling.mcmc import posterior_optimization_db +from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV, ConjugateRVSampler + +STEP_METHODS.append(ConjugateRVSampler) + +from pytensor.graph.fg import Output +from pytensor.tensor.elemwise import DimShuffle +from pymc.model.fgraph import model_free_rv, ModelValuedVar + + +from pytensor.graph.basic import Variable +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import node_rewriter +from pymc.model.fgraph import ModelFreeRV +from pymc.distributions import Beta, Binomial +from pymc.pytensorf import collect_default_updates + + +def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable: + """Return the Model dummy var that wraps the RV""" + for client, _ in fgraph.clients[rv]: + if isinstance(client.op, ModelValuedVar): + return client.outputs[0] + + +def get_dist_params(rv: Variable) -> tuple[Variable]: + return rv.owner.op.dist_params(rv.owner) + + +def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]: + """Return the RVs that use `rv` as an argument in an operation of type `used_by_type`. + + RV may be used directly or broadcasted before being used. + + Parameters + ---------- + fgraph : FunctionGraph + The function graph containing the RVs + rv : Variable + The RV to check for uses. + used_by_type : type + The type of operation that may use the RV. + used_as_arg_idx : int | Sequence[int] + The index of the RV in the operation's inputs. + strict : bool, default=True + If True, return no results when the RV is used in an unrecognized way. + + """ + if isinstance(used_as_arg_idx, int): + used_as_arg_idx = (used_as_arg_idx,) + + clients = fgraph.clients + used_by : list[Variable] = [] + for client, inp_idx in clients[rv]: + if isinstance(client.op, Output): + continue + + if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx: + # RV is directly used by the RV type + used_by.append(client.default_output()) + + elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims: + for sub_client, sub_inp_idx in clients[client.outputs[0]]: + if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx: + # RV is broadcasted and then used by the RV type + used_by.append(sub_client.default_output()) + elif strict: + # Some other unrecognized use, bail out + return [] + elif strict: + # Some other unrecognized use, bail out + return [] + + return used_by + + +def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable: + """Wrap the RV and its conjugate posterior RV in a ConjugateRV node. + + Also takes care of handling the random number generators used in the conjugate posterior. + """ + rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items()) + for rng in rngs: + if rng not in fgraph.inputs: + fgraph.add_input(rng) + conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs]) + return conjugate_op(rv, *inputs, *rngs)[0] + + +def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable: + """Create a model FreeRV without transform.""" + transform = None + value = rv.type(name=name) + fgraph.add_input(value) + free_rv = model_free_rv(rv, value, transform, *dims) + free_rv.name = name + return free_rv + + +@node_rewriter(tracks=[ModelFreeRV]) +def beta_binomial_conjugacy(fgraph: FunctionGraph, node): + """This applies the equivalence (up to a normalizing constant) described in: + + https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics + """ + [beta_free_rv] = node.outputs + beta_rv, beta_value, *beta_dims = node.inputs + + if not isinstance(beta_rv.owner.op, Beta): + return None + + p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p) + binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx) + + if len(binomial_rvs) != 1: + # Question: Can we apply conjugacy when RV is used by more than one binomial? + return None + + [binomial_rv] = binomial_rvs + + binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv) + if binomial_model_var is None: + return None + + # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv) + a, b = get_dist_params(beta_rv) + n, _ = get_dist_params(binomial_rv) + + # Use value of y in new graph to avoid circularity + y = binomial_model_var.owner.inputs[1] + + conjugate_a = a + y + conjugate_b = b + (n - y) + extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim) + if extra_dims: + conjugate_a = conjugate_a.sum(extra_dims) + conjugate_b = conjugate_b.sum(extra_dims) + conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b) + + new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y]) + new_beta_free_rv = create_untransformed_free_rv(fgraph, new_beta_rv, beta_free_rv.name, beta_dims) + return [new_beta_free_rv] + + +posterior_optimization_db.register( + beta_binomial_conjugacy.__name__, + beta_binomial_conjugacy, + "conjugacy" +) \ No newline at end of file diff --git a/pymc_experimental/sampling/optimizations/conjugate_sampler.py b/pymc_experimental/sampling/optimizations/conjugate_sampler.py new file mode 100644 index 00000000..d8d71a73 --- /dev/null +++ b/pymc_experimental/sampling/optimizations/conjugate_sampler.py @@ -0,0 +1,106 @@ +import numpy as np + +from pymc_experimental.utils.ofg import inline_ofg_outputs +from pytensor.compile.builders import OpFromGraph +from pymc.logprob.abstract import MeasurableOp, _logprob +from pymc.distributions.distribution import _support_point +from pymc.step_methods.compound import BlockedStep, StepMethodState, Competence +from pymc.model.core import modelcontext +from pymc.util import get_value_vars_from_user_vars +from pymc.pytensorf import compile_pymc +from pytensor import shared +from pytensor.tensor.random.type import RandomGeneratorType +from pytensor.link.jax.linker import JAXLinker +from pymc.initial_point import PointType + +class ConjugateRV(OpFromGraph, MeasurableOp): + """Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression. + + For partial step samplers to work, the logp and initial point correspond to the original RV + while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the + conjugate posterior expression (i.e., taking forward random draws). + """ + + +@_logprob.register(ConjugateRV) +def conjugate_rv_logp(op, values, rv, *params, **kwargs): + # Logp is the same as the original RV + return _logprob(rv.owner.op, values, *rv.owner.inputs) + + +@_support_point.register(ConjugateRV) +def conjugate_rv_support_point(op, conjugate_rv, rv, *params): + # Support point is the same as the original RV + return _support_point(rv.owner.op, rv, *rv.owner.inputs) + + +class ConjugateRVSampler(BlockedStep): + name = "conjugate_rv_sampler" + _state_class = StepMethodState + + def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs): + if len(vars) != 1: + raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time") + + model = modelcontext(model) + [value] = get_value_vars_from_user_vars(vars, model=model) + rv = model.values_to_rvs[value] + self.vars = (value,) + self.rv_name = value.name + + if model.rvs_to_transforms[rv] is not None: + raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed") + + rv_and_posterior_rv_node = rv.owner + op = rv_and_posterior_rv_node.op + if not isinstance(op, ConjugateRV): + raise ValueError("Variable must be a ConjugateRV") + + # Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables + value_inputs = model.replace_rvs_by_values( + [rv_and_posterior_rv_node.outputs[1]], + )[0].owner.inputs + # Inline the ConjugateRV graph to only compile `posterior_rv` + _, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs) + + if compile_kwargs is None: + compile_kwargs = {} + self.posterior_fn = compile_pymc( + model.value_vars, + posterior_rv, + random_seed=rng, + on_unused_input="ignore", + **compile_kwargs, + ) + self.posterior_fn.trust_input = True + if isinstance(self.posterior_fn.maker.linker, JAXLinker): + # Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables + # used internally are not the ones that `function.get_shared()` returns. + raise ValueError("ConjugateRVSampler is not compatible with JAX backend") + + def set_rng(self, rng: np.random.Generator): + # Copy the function and replace any shared RNGs + # This is needed so that it can work correctly with multiple traces + # This will be costly if set_rng is called too often! + shared_rngs = [ + var for var in self.posterior_fn.get_shared() if isinstance(var.type, RandomGeneratorType) + ] + n_shared_rngs = len(shared_rngs) + swap = { + old_shared_rng: shared(rng, borrow=True) + for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True) + } + self.posterior_fn = self.posterior_fn.copy(swap=swap) + + def step(self, point: PointType) -> tuple[PointType, list]: + new_point = point.copy() + new_point[self.rv_name] = self.posterior_fn(**point) + return new_point, [] + + @staticmethod + def competence(var, has_grad): + """BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2.""" + if isinstance(var.owner.op, ConjugateRV): + return Competence.IDEAL + + return Competence.INCOMPATIBLE diff --git a/pymc_experimental/utils/ofg.py b/pymc_experimental/utils/ofg.py new file mode 100644 index 00000000..909d5ad4 --- /dev/null +++ b/pymc_experimental/utils/ofg.py @@ -0,0 +1,16 @@ +from pytensor.graph.basic import Variable +from pytensor.graph.replace import clone_replace +from pytensor.compile.builders import OpFromGraph +from typing import Sequence + + +def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]: + """Inline the inner graph (outputs) of an OpFromGraph Op. + + Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps" + the inner graph. + """ + return clone_replace( + op.inner_outputs, + replace=tuple(zip(op.inner_inputs, inputs)), + ) diff --git a/pyproject.toml b/pyproject.toml index 1f80feac..682d55d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ addopts = [ ] filterwarnings =[ - "error", +# "error", # Raised by arviz when the model_builder class adds non-standard group names to InferenceData "ignore::UserWarning:arviz.data.inference_data", diff --git a/tests/sampling/mcmc/test_mcmc.py b/tests/sampling/mcmc/test_mcmc.py index 19616752..f4e9e2b1 100644 --- a/tests/sampling/mcmc/test_mcmc.py +++ b/tests/sampling/mcmc/test_mcmc.py @@ -1,6 +1,8 @@ +import logging + import numpy as np from pymc.model.core import Model -from pymc.distributions import Normal, HalfNormal +from pymc.distributions import Normal, HalfNormal, Beta, Binomial from pymc.sampling.mcmc import sample from pymc_experimental import opt_sample @@ -27,3 +29,34 @@ def test_sample_opt_summary_stats(capsys): np.testing.assert_allclose(idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3) np.testing.assert_allclose(idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2) assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time + + +def test_sample_opt_conjugate(caplog, capsys): + caplog.set_level(logging.INFO, logger="pymc") + + with Model() as m: + p = Beta("p", 1, 1) + y = Binomial("y", n=100, p=p, observed=99) + + idata = opt_sample(tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False, random_seed=0, verbose=True) + + captured_out = capsys.readouterr().out + assert "Applied optimization: beta_binomial_conjugacy 1x" in captured_out + + # Test it used ConjugateRVSampler + assert "ConjugateRVSampler: [p]" in caplog.text + + np.testing.assert_allclose(idata.posterior["p"].mean(), 100/102, atol=1e-3) + np.testing.assert_allclose(idata.posterior["p"].std(), np.sqrt(100*2/(102**2 * 103)), atol=1e-3) + + # Draws are different across chains + assert (np.diff(idata.posterior["p"].isel(draw=0).values) > 0).all() + + # Check draws respect random_seed + with m: + new_idata = opt_sample(tune=0, chains=4, draws=1, progressbar=False, compute_convergence_checks=False, random_seed=0) + np.testing.assert_allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0)) + + with m: + new_idata = opt_sample(tune=0, chains=4, draws=1, progressbar=False, compute_convergence_checks=False, random_seed=1) + assert not np.allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0)) \ No newline at end of file diff --git a/tests/sampling/optimizations/test_conjugacy.py b/tests/sampling/optimizations/test_conjugacy.py new file mode 100644 index 00000000..c4002ab4 --- /dev/null +++ b/tests/sampling/optimizations/test_conjugacy.py @@ -0,0 +1,46 @@ +import numpy as np +from pymc.model.core import Model +from pymc.distributions import Beta, Binomial +from pymc.model.fgraph import fgraph_from_model, model_from_fgraph +from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV +from pymc_experimental.sampling.optimizations.conjugacy import beta_binomial_conjugacy +from pymc.sampling import draw +from pytensor.graph.rewriting.basic import out2in +from pymc.model.transform.conditioning import remove_value_transforms + + +def test_beta_binomial_conjugacy(): + with Model() as m: + p = Beta("p", 1, 1) + y = Binomial("y", n=100, p=p, observed=99) + + assert m.rvs_to_transforms[p] is not None + assert isinstance(p.owner.op, Beta) + + fgraph, _ = fgraph_from_model(m) + beta_binomial_rewrite = out2in(beta_binomial_conjugacy) + _ = beta_binomial_rewrite.apply(fgraph) + new_m = model_from_fgraph(fgraph) + + new_p = new_m["p"] + assert new_m.rvs_to_transforms[new_p] is None + assert isinstance(new_p.owner.op, ConjugateRV) + beta_rv, conjugate_beta_rv, *_ = new_p.owner.outputs + + # Check it behaves like a beta and its conjugate + beta_draws, conjugate_beta_draws = draw([beta_rv, conjugate_beta_rv], draws=1000, random_seed=25) + np.testing.assert_allclose(beta_draws.mean(), 1/2, atol=1e-2) + np.testing.assert_allclose(conjugate_beta_draws.mean(), 100/102, atol=1e-3) + np.testing.assert_allclose(beta_draws.std(), np.sqrt(1/12), atol=1e-2) + np.testing.assert_allclose(conjugate_beta_draws.std(), np.sqrt(100*2/(102**2 * 103)), atol=1e-3) + + # Check if support point and logp is the same as the original model without transforms + untransformed_m = remove_value_transforms(m) + new_m_ip = new_m.initial_point() + for key, value in untransformed_m.initial_point().items(): + np.testing.assert_allclose(new_m_ip[key], value) + + new_m_logp = new_m.compile_logp()(new_m_ip) + untransformed_m_logp = untransformed_m.compile_logp()(new_m_ip) + np.testing.assert_allclose(new_m_logp, untransformed_m_logp) +