Skip to content

Commit

Permalink
Make logprob inference for binary ops independent of order of inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 authored Apr 26, 2023
1 parent a59c9cd commit 55d915c
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 23 deletions.
44 changes: 34 additions & 10 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GE, GT, LE, LT
from pytensor.tensor import TensorVariable
from pytensor.tensor.math import ge, gt, le, lt

from pymc.logprob.abstract import (
Expand Down Expand Up @@ -50,26 +51,49 @@ def find_measurable_comparisons(
if isinstance(node.op, MeasurableComparison):
return None # pragma: no cover

(compared_var,) = node.outputs
base_var, const = node.inputs
measurable_inputs = [
(inp, idx)
for idx, inp in enumerate(node.inputs)
if inp.owner
and isinstance(inp.owner.op, MeasurableVariable)
and inp not in rv_map_feature.rv_values
]

if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
):
if len(measurable_inputs) != 1:
return None

# Make the measurable base_var always be the first input to the MeasurableComparison node
base_var: TensorVariable = measurable_inputs[0][0]

# Check that the other input is not potentially measurable, in which case this rewrite
# would be invalid
const = tuple(inp for inp in node.inputs if inp is not base_var)

# check for potential measurability of const
if not check_potential_measurability((const,), rv_map_feature):
if not check_potential_measurability(const, rv_map_feature):
return None

const = const[0]

# Make base_var unmeasurable
unmeasurable_base_var = ignore_logprob(base_var)

compared_op = MeasurableComparison(node.op.scalar_op)
node_scalar_op = node.op.scalar_op

# Change the Op if the base_var is the second input in node.inputs. e.g. pt.lt(const, dist) -> pt.gt(dist, const)
if measurable_inputs[0][1] == 1:
if isinstance(node_scalar_op, LT):
node_scalar_op = GT()
elif isinstance(node_scalar_op, GT):
node_scalar_op = LT()
elif isinstance(node_scalar_op, GE):
node_scalar_op = LE()
elif isinstance(node_scalar_op, LE):
node_scalar_op = GE()

compared_op = MeasurableComparison(node_scalar_op)
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
compared_rv.name = compared_var.name
compared_rv.name = node.outputs[0].name
return [compared_rv]


Expand Down
39 changes: 26 additions & 13 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
"comparison_op, exp_logp_true, exp_logp_false, inputs",
[
((pt.lt, pt.le), "logcdf", "logsf"),
((pt.gt, pt.ge), "logsf", "logcdf"),
((pt.lt, pt.le), "logcdf", "logsf", (pt.random.normal(0, 1), 0.5)),
((pt.gt, pt.ge), "logsf", "logcdf", (pt.random.normal(0, 1), 0.5)),
((pt.lt, pt.le), "logsf", "logcdf", (0.5, pt.random.normal(0, 1))),
((pt.gt, pt.ge), "logcdf", "logsf", (0.5, pt.random.normal(0, 1))),
],
)
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.normal(0, 1)
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false, inputs):
for op in comparison_op:
comp_x_rv = op(x_rv, 0.5)
comp_x_rv = op(*inputs)

comp_x_vv = comp_x_rv.clone()

Expand All @@ -49,33 +50,45 @@ def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
"comparison_op, exp_logp_true, exp_logp_false, inputs",
[
(
pt.lt,
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
(pt.random.poisson(2), 3),
),
(
pt.ge,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
lambda x: st.poisson(2).logcdf(x - 1),
(pt.random.poisson(2), 3),
),
(pt.gt, st.poisson(2).logsf, st.poisson(2).logcdf, (pt.random.poisson(2), 3)),
(pt.le, st.poisson(2).logcdf, st.poisson(2).logsf, (pt.random.poisson(2), 3)),
(
pt.gt,
pt.lt,
st.poisson(2).logsf,
st.poisson(2).logcdf,
(3, pt.random.poisson(2)),
),
(pt.ge, st.poisson(2).logcdf, st.poisson(2).logsf, (3, pt.random.poisson(2))),
(
pt.gt,
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
(3, pt.random.poisson(2)),
),
(
pt.le,
st.poisson(2).logcdf,
st.poisson(2).logsf,
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
lambda x: st.poisson(2).logcdf(x - 1),
(3, pt.random.poisson(2)),
),
],
)
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.poisson(2)
cens_x_rv = comparison_op(x_rv, 3)
def test_discrete_rv_comparison(inputs, comparison_op, exp_logp_true, exp_logp_false):
cens_x_rv = comparison_op(*inputs)

cens_x_vv = cens_x_rv.clone()

Expand Down

0 comments on commit 55d915c

Please sign in to comment.