From 864ecb3e906b0ceb3d74281f80672a5cec69971b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 9 Jun 2023 16:17:13 +0200 Subject: [PATCH] Fix bug in switch mixture logp The True and False branches were being mixed up --- pymc/logprob/mixture.py | 2 +- tests/logprob/test_mixture.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pymc/logprob/mixture.py b/pymc/logprob/mixture.py index 8b2491384ee..a84b89a732a 100644 --- a/pymc/logprob/mixture.py +++ b/pymc/logprob/mixture.py @@ -344,7 +344,7 @@ def find_measurable_switch_mixture(fgraph, node): old_mixture_rv.broadcastable, ) new_mixture_rv = mix_op.make_node( - *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components) + *([NoneConst, as_nontensor_scalar(node.inputs[0])] + components[::-1]) ).default_output() if pytensor.config.compute_test_value != "off": diff --git a/tests/logprob/test_mixture.py b/tests/logprob/test_mixture.py index edc51a13817..697eaf9a388 100644 --- a/tests/logprob/test_mixture.py +++ b/tests/logprob/test_mixture.py @@ -917,7 +917,12 @@ def test_switch_mixture(): i_vv = I_rv.clone() i_vv.name = "i" + # When I_rv == True, X_rv flows through otherwise Y_rv does Z1_rv = pt.switch(I_rv, X_rv, Y_rv) + + assert Z1_rv.eval({I_rv: 0}) > 5 + assert Z1_rv.eval({I_rv: 1}) < -5 + z_vv = Z1_rv.clone() z_vv.name = "z1" @@ -935,7 +940,10 @@ def test_switch_mixture(): # building the identical graph but with a stack to check that mixture computations are identical - Z2_rv = pt.stack((X_rv, Y_rv))[I_rv] + Z2_rv = pt.stack((Y_rv, X_rv))[I_rv] + + assert Z2_rv.eval({I_rv: 0}) > 5 + assert Z2_rv.eval({I_rv: 1}) < -5 fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv}) @@ -949,8 +957,8 @@ def test_switch_mixture(): # below should follow immediately from the equal_computations assertion above assert equal_computations([z1_logp_combined], [z2_logp_combined]) - np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 0})) - np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 0})) + np.testing.assert_almost_equal(0.69049938, z1_logp_combined.eval({z_vv: -10, i_vv: 1})) + np.testing.assert_almost_equal(0.69049938, z2_logp_combined.eval({z_vv: -10, i_vv: 1})) def test_ifelse_mixture_one_component():