diff --git a/funsor/optimizer.py b/funsor/optimizer.py index ed597565..1036f30c 100644 --- a/funsor/optimizer.py +++ b/funsor/optimizer.py @@ -46,7 +46,11 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): ) return Contraction(red_op, v.bin_op, reduced_vars, *new_terms) - if red_op in (v.red_op, ops.null) and (v.red_op, bin_op) in DISTRIBUTIVE_OPS: + if ( + red_op in (v.red_op, ops.null) + and (v.red_op, bin_op) in DISTRIBUTIVE_OPS + and v.reduced_vars.isdisjoint(reduced_vars) + ): new_terms = ( terms[:i] + (Contraction(v.red_op, v.bin_op, frozenset(), *v.terms),) @@ -56,7 +60,11 @@ def unfold_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms): red_op, reduced_vars ) - if v.red_op in (red_op, ops.null) and bin_op in (v.bin_op, ops.null): + if ( + v.red_op in (red_op, ops.null) + and bin_op in (v.bin_op, ops.null) + and v.reduced_vars.isdisjoint(reduced_vars) + ): red_op = v.red_op if red_op is ops.null else red_op bin_op = v.bin_op if bin_op is ops.null else bin_op new_terms = terms[:i] + v.terms + terms[i + 1 :] diff --git a/test/test_optimizer.py b/test/test_optimizer.py index 7951cd27..0af39959 100644 --- a/test/test_optimizer.py +++ b/test/test_optimizer.py @@ -6,6 +6,7 @@ import pytest import funsor +from funsor.cnf import Contraction from funsor.domains import Bint from funsor.einsum import ( einsum, @@ -160,3 +161,185 @@ def test_optimized_plated_einsum(equation, plates, backend): for i, output_dim in enumerate(output): assert output_dim in actual.inputs assert actual.inputs[output_dim].dtype == sizes[output_dim] + + +def test_intersecting_contractions(): + import torch + + with funsor.terms.lazy: + term = Contraction( + funsor.terms.ops.logaddexp, + funsor.terms.ops.add, + frozenset( + { + Variable("_drop_0__BOUND_10", Bint[3]), + Variable("_drop_1__BOUND_11", Bint[2]), + } + ), # noqa + ( + Contraction( + funsor.terms.ops.logaddexp, + funsor.terms.ops.add, + frozenset( + { + Variable("_drop_0__BOUND_8", Bint[3]), + Variable("_drop_1__BOUND_9", Bint[2]), + } + ), # noqa + ( + Tensor( + torch.tensor( + [ + [ + [-1.1258398294448853, -1.152360200881958], + [-0.2505785822868347, -0.4338788092136383], + ], + [ + [0.8487103581428528, 0.6920091509819031], + [-0.31601276993751526, -2.1152193546295166], + ], + [ + [0.32227492332458496, -1.2633347511291504], + [0.34998318552970886, 0.30813392996788025], + ], + ], + dtype=torch.float32, + ), # noqa + ( + ( + "_drop_0__BOUND_8", + Bint[3], + ), + ( + "_drop_1__BOUND_9", + Bint[2], + ), + ( + "_PREV_b", + Bint[2], + ), + ), + "real", + ), + Tensor( + torch.tensor( + [ + [ + [0.11984150856733322, 1.237657904624939], + [1.1167771816253662, -0.2472781538963318], + ], + [ + [-1.3526537418365479, -1.6959311962127686], + [0.5666506290435791, 0.7935083508491516], + ], + [ + [0.5988394618034363, -1.5550950765609741], + [-0.3413603901863098, 1.85300612449646], + ], + ], + dtype=torch.float32, + ), # noqa + ( + ( + "_drop_0__BOUND_10", + Bint[3], + ), + ( + "_drop_1__BOUND_11", + Bint[2], + ), + ( + "_drop_1__BOUND_9", + Bint[2], + ), + ), + "real", + ), + ), + ), + Contraction( + funsor.terms.ops.logaddexp, + funsor.terms.ops.add, + frozenset( + { + Variable("_drop_0__BOUND_8", Bint[3]), + Variable("_drop_1__BOUND_9", Bint[2]), + } + ), # noqa + ( + Tensor( + torch.tensor( + [ + [ + [0.750189483165741, -0.5854975581169128], + [-0.1733967512845993, 0.18347793817520142], + ], + [ + [1.3893661499023438, 1.586334228515625], + [0.946298360824585, -0.843676745891571], + ], + [ + [-0.6135830879211426, 0.03159274160861969], + [-0.4926769733428955, 0.2484147548675537], + ], + ], + dtype=torch.float32, + ), # noqa + ( + ( + "_drop_0__BOUND_8", + Bint[3], + ), + ( + "_drop_1__BOUND_9", + Bint[2], + ), + ( + "_drop_1__BOUND_11", + Bint[2], + ), + ), + "real", + ), + Tensor( + torch.tensor( + [ + [ + [0.4396958351135254, 0.11241118609905243], + [0.6407923698425293, 0.441156268119812], + ], + [ + [-0.10230965167284012, 0.7924439907073975], + [-0.28966769576072693, 0.05250748619437218], + ], + [ + [0.5228604674339294, 2.3022053241729736], + [-1.4688938856124878, -1.586688756942749], + ], + ], + dtype=torch.float32, + ), # noqa + ( + ( + "a", + Bint[3], + ), + ( + "b", + Bint[2], + ), + ( + "_drop_1__BOUND_9", + Bint[2], + ), + ), + "real", + ), + ), + ), + ), + ) + expected = reinterpret(term) + actual = apply_optimizer(term) + expected = expected.align(tuple(actual.inputs.keys())) + assert_close(actual, expected)