From ec6fe585b764fc6ab2ec73f97b4570cd200482b9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 16 Jul 2021 01:32:33 -0400 Subject: [PATCH 01/22] metadatatensor --- funsor/torch/__init__.py | 1 + funsor/torch/metadata.py | 24 ++++++++++++++++++++++++ test/test_metadata.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) create mode 100644 funsor/torch/metadata.py create mode 100644 test/test_metadata.py diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index 9f69ab983..1d0501c83 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -10,6 +10,7 @@ from . import distributions as _ from . import ops as _ +from .metadata import MetadataTensor del _ # flake8 diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py new file mode 100644 index 000000000..af2a914ef --- /dev/null +++ b/funsor/torch/metadata.py @@ -0,0 +1,24 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class MetadataTensor(object): + def __init__(self, data, metadata=frozenset(), **kwargs): + assert isinstance(metadata, frozenset) + self._t = torch.as_tensor(data, **kwargs) + self._metadata = metadata + + def __repr__(self): + return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t) + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + meta = frozenset().union( + *tuple(a._metadata for a in args if hasattr(a, "_metadata")) + ) + args = [a._t if hasattr(a, "_t") else a for a in args] + ret = func(*args, **kwargs) + return MetadataTensor(ret, metadata=meta) diff --git a/test/test_metadata.py b/test/test_metadata.py new file mode 100644 index 000000000..645eda00a --- /dev/null +++ b/test/test_metadata.py @@ -0,0 +1,32 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from funsor.torch import MetadataTensor + + +@pytest.mark.parametrize( + "data1,metadata1", + [ + (torch.tensor([1]), frozenset({"a"})), + ], +) +@pytest.mark.parametrize( + "data2,metadata2", + [ + (torch.tensor([2]), frozenset({"b"})), + (torch.tensor([2]), None), + (2, None), + ], +) +def test_metadata(data1, metadata1, data2, metadata2): + if metadata1 is not None: + data1 = MetadataTensor(data1, metadata1) + if metadata2 is not None: + data2 = MetadataTensor(data2, metadata2) + + expected = frozenset.union(*[m for m in (metadata1, metadata2) if m is not None]) + actual = torch.add(data1, data2)._metadata + assert actual == expected From de39abfa3f842545f9825f2828f7e30f5a117b1a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Fri, 16 Jul 2021 01:49:34 -0400 Subject: [PATCH 02/22] lint --- funsor/torch/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index 1d0501c83..c782c0829 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -29,3 +29,6 @@ def _quote(x, indent, out): @dispatch(torch.Tensor, torch.Tensor, [float]) def allclose(a, b, rtol=1e-05, atol=1e-08): return torch.allclose(a, b, rtol=rtol, atol=atol) + + +__all__ = ["MetadataTensor"] From ac14606f2659c127c6f6958504f71f50d5a60e53 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 21 Jul 2021 01:24:00 -0400 Subject: [PATCH 03/22] Constant funsor --- funsor/constant.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 funsor/constant.py diff --git a/funsor/constant.py b/funsor/constant.py new file mode 100644 index 000000000..d76a7d4d8 --- /dev/null +++ b/funsor/constant.py @@ -0,0 +1,76 @@ +from collections import OrderedDict + +from funsor.terms import Funsor, Variable, Binary, eager, Number, Unary +from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp +from funsor.tensor import Tensor +from funsor.delta import Delta +from funsor.distribution import Distribution + +class Constant(Funsor): + def __init__(self, const_vars, arg): + assert isinstance(arg, Funsor) + assert isinstance(const_vars, frozenset) + assert all(isinstance(v, Variable) for v in const_vars) + assert all(v not in arg.inputs for v in const_vars) + # const_names = frozenset(v.name for v in cont_vars) + inputs = OrderedDict( + (v.name, v.output) for v in const_vars + ) + inputs.update(arg.inputs) + output = arg.output + fresh = const_vars + bound = {} + super(Constant, self).__init__(inputs, output, fresh, bound) + self.arg = arg + self.const_vars = const_vars + + def eager_subs(self, subs): + assert isinstance(subs, tuple) + const_vars = self.const_vars + for name, value in subs: + if isinstance(value, Variable): + breakpoint() + continue + + breakpoint() + if isinstance(value, (Number, Tensor)): + const_vars = const_vars - value + + return Constant(const_vars, self.arg) + + def eager_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) + const_vars = frozenset({v for v in self.const_vars if v.name not in reduced_vars}) + reduced_vars = reduced_vars - frozenset({v.name for v in self.const_vars}) + if not const_vars: + return self.arg.reduce(op, reduced_vars) + return Constant(const_vars, self.arg.reduce(op, reduced_vars)) + + +@eager.register(Binary, BinaryOp, Constant, Constant) +def eager_binary_constant_constant(op, lhs, rhs): + const_vars = lhs.const_vars | rhs.const_vars - 
lhs.input_vars - rhs.input_vars + if not const_vars: + return op(lhs.arg, rhs.arg) + return Constant(const_vars, op(lhs.arg, rhs.arg)) + + +@eager.register(Binary, BinaryOp, Constant, (Number, Tensor)) +def eager_binary_constant_tensor(op, lhs, rhs): + const_vars = lhs.const_vars - rhs.input_vars + if not const_vars: + return op(lhs.arg, rhs) + return Constant(const_vars, op(lhs.arg, rhs)) + + +@eager.register(Binary, BinaryOp, (Number, Tensor), Constant) +def eager_binary_tensor_constant(op, lhs, rhs): + const_vars = rhs.const_vars - lhs.input_vars + if not const_vars: + return op(lhs, rhs.arg) + return Constant(const_vars, op(lhs, rhs.arg)) + + +@eager.register(Unary, UnaryOp, Constant) +def eager_binary_tensor_constant(op, arg): + return Constant(arg.const_vars, op(arg.arg)) From 58ae920d5a0a9333b80bc9ccdaf0196af6a7e6c4 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 26 Jul 2021 10:25:25 -0400 Subject: [PATCH 04/22] subclass torch.Tensor --- funsor/torch/metadata.py | 27 +++++++++++++++++---------- test/test_metadata.py | 1 + 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py index af2a914ef..ce4cdd3b0 100644 --- a/funsor/torch/metadata.py +++ b/funsor/torch/metadata.py @@ -4,21 +4,28 @@ import torch -class MetadataTensor(object): - def __init__(self, data, metadata=frozenset(), **kwargs): +class MetadataTensor(torch.Tensor): + def __new__(cls, data, metadata=frozenset(), **kwargs): assert isinstance(metadata, frozenset) - self._t = torch.as_tensor(data, **kwargs) - self._metadata = metadata + t = torch.Tensor._make_subclass(cls, data) + t._metadata = metadata + return t def __repr__(self): - return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t) + return "Metadata:\n{}\ndata:\n{}".format( + self._metadata, torch.Tensor._make_subclass(torch.Tensor, self) + ) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - meta = frozenset().union( - *tuple(a._metadata for a in args if hasattr(a, "_metadata")) - ) - args = [a._t if hasattr(a, "_t") else a for a in args] - ret = func(*args, **kwargs) + meta = frozenset() + _args = [] + for arg in args: + if isinstance(arg, MetadataTensor): + meta |= arg._metadata + _args.append(torch.Tensor._make_subclass(torch.Tensor, arg)) + else: + _args.append(arg) + ret = func(*_args, **kwargs) return MetadataTensor(ret, metadata=meta) diff --git a/test/test_metadata.py b/test/test_metadata.py index 645eda00a..3415e7cef 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -4,6 +4,7 @@ import pytest import torch +import funsor from funsor.torch import MetadataTensor From 5eec3e054fbbf4ad97e5001168a86aaaf030c155 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 26 Jul 2021 11:22:38 -0400 Subject: [PATCH 05/22] save --- funsor/adjoint.py | 3 +- funsor/constant.py | 59 +++++++++++++++++++++++++++++++++++----- funsor/torch/metadata.py | 14 ++++++++-- test/test_metadata.py | 2 ++ 4 files changed, 67 insertions(+), 11 deletions(-) diff --git a/funsor/adjoint.py b/funsor/adjoint.py index 65b1565a5..bdae045af 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -58,7 +58,8 @@ def interpret(self, cls, *args): ) for arg in args ] - self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) + with self._old_interpretation: + self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) return result def __enter__(self): diff --git a/funsor/constant.py b/funsor/constant.py index d76a7d4d8..a9ef53489 100644 
--- a/funsor/constant.py +++ b/funsor/constant.py @@ -1,10 +1,16 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + from collections import OrderedDict -from funsor.terms import Funsor, Variable, Binary, eager, Number, Unary -from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp -from funsor.tensor import Tensor from funsor.delta import Delta from funsor.distribution import Distribution +from funsor.tensor import Tensor +from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager, to_funsor +from funsor.torch import MetadataTensor + +from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp + class Constant(Funsor): def __init__(self, const_vars, arg): @@ -13,9 +19,7 @@ def __init__(self, const_vars, arg): assert all(isinstance(v, Variable) for v in const_vars) assert all(v not in arg.inputs for v in const_vars) # const_names = frozenset(v.name for v in cont_vars) - inputs = OrderedDict( - (v.name, v.output) for v in const_vars - ) + inputs = OrderedDict((v.name, v.output) for v in const_vars) inputs.update(arg.inputs) output = arg.output fresh = const_vars @@ -40,7 +44,9 @@ def eager_subs(self, subs): def eager_reduce(self, op, reduced_vars): assert reduced_vars.issubset(self.inputs) - const_vars = frozenset({v for v in self.const_vars if v.name not in reduced_vars}) + const_vars = frozenset( + {v for v in self.const_vars if v.name not in reduced_vars} + ) reduced_vars = reduced_vars - frozenset({v.name for v in self.const_vars}) if not const_vars: return self.arg.reduce(op, reduced_vars) @@ -74,3 +80,42 @@ def eager_binary_tensor_constant(op, lhs, rhs): @eager.register(Unary, UnaryOp, Constant) def eager_binary_tensor_constant(op, arg): return Constant(arg.const_vars, op(arg.arg)) + + +@to_funsor.register(MetadataTensor) +def tensor_to_funsor(x, output=None, dim_to_name=None): + breakpoint() + if not dim_to_name: + output = output if output is not None else Reals[x.shape] + result = Tensor(x, dtype=output.dtype) + if result.output != output: + raise ValueError( + "Invalid shape: expected {}, actual {}".format( + output.shape, result.output.shape + ) + ) + return result + else: + assert all( + isinstance(k, int) and k < 0 and isinstance(v, str) + for k, v in dim_to_name.items() + ) + + if output is None: + # Assume the leftmost dim_to_name key refers to the leftmost dim of x + # when there is ambiguity about event shape + batch_ndims = min(-min(dim_to_name.keys()), len(x.shape)) + output = Reals[x.shape[batch_ndims:]] + + # logic very similar to pyro.ops.packed.pack + # this should not touch memory, only reshape + # pack the tensor according to the dim => name mapping in inputs + packed_inputs = OrderedDict() + for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape): + name = dim_to_name.get(dim + len(output.shape) - len(x.shape), None) + if name is not None and size != 1: + packed_inputs[name] = Bint[size] + shape = tuple(d.size for d in packed_inputs.values()) + output.shape + if x.shape != shape: + x = x.reshape(shape) + return Tensor(x, packed_inputs, dtype=output.dtype) diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py index ce4cdd3b0..be33abcdb 100644 --- a/funsor/torch/metadata.py +++ b/funsor/torch/metadata.py @@ -7,9 +7,17 @@ class MetadataTensor(torch.Tensor): def __new__(cls, data, metadata=frozenset(), **kwargs): assert isinstance(metadata, frozenset) - t = torch.Tensor._make_subclass(cls, data) - t._metadata = metadata - return t + if 
isinstance(data, torch.Tensor): + t = torch.Tensor._make_subclass(cls, data) + t._metadata = metadata + return t + else: + return data + # breakpoint() + # pass + # if isinstance(data, torch.Size): + # # Is this correct? + # return data def __repr__(self): return "Metadata:\n{}\ndata:\n{}".format( diff --git a/test/test_metadata.py b/test/test_metadata.py index 3415e7cef..e4b327f83 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -27,6 +27,8 @@ def test_metadata(data1, metadata1, data2, metadata2): data1 = MetadataTensor(data1, metadata1) if metadata2 is not None: data2 = MetadataTensor(data2, metadata2) + t = funsor.to_funsor(data1) + breakpoint() expected = frozenset.union(*[m for m in (metadata1, metadata2) if m is not None]) actual = torch.add(data1, data2)._metadata From 1b754caf542446b87091b19e900840522c619f1a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 26 Jul 2021 11:24:54 -0400 Subject: [PATCH 06/22] subclass torch.Tensor --- funsor/torch/metadata.py | 30 ++++++++++++++++++++---------- test/test_metadata.py | 1 + 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py index af2a914ef..399aea79b 100644 --- a/funsor/torch/metadata.py +++ b/funsor/torch/metadata.py @@ -4,21 +4,31 @@ import torch -class MetadataTensor(object): - def __init__(self, data, metadata=frozenset(), **kwargs): +class MetadataTensor(torch.Tensor): + def __new__(cls, data, metadata=frozenset(), **kwargs): assert isinstance(metadata, frozenset) - self._t = torch.as_tensor(data, **kwargs) - self._metadata = metadata + if isinstance(data, torch.Tensor): + t = torch.Tensor._make_subclass(cls, data) + t._metadata = metadata + return t + else: + return data def __repr__(self): - return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t) + return "Metadata:\n{}\ndata:\n{}".format( + self._metadata, torch.Tensor._make_subclass(torch.Tensor, self) + ) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - meta = frozenset().union( - *tuple(a._metadata for a in args if hasattr(a, "_metadata")) - ) - args = [a._t if hasattr(a, "_t") else a for a in args] - ret = func(*args, **kwargs) + meta = frozenset() + _args = [] + for arg in args: + if isinstance(arg, MetadataTensor): + meta |= arg._metadata + _args.append(torch.Tensor._make_subclass(torch.Tensor, arg)) + else: + _args.append(arg) + ret = func(*_args, **kwargs) return MetadataTensor(ret, metadata=meta) diff --git a/test/test_metadata.py b/test/test_metadata.py index 645eda00a..3415e7cef 100644 --- a/test/test_metadata.py +++ b/test/test_metadata.py @@ -4,6 +4,7 @@ import pytest import torch +import funsor from funsor.torch import MetadataTensor From fb5d30a2f7e4effd36c47edf4a40fba9d33a666c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 26 Jul 2021 11:53:33 -0400 Subject: [PATCH 07/22] rename to ProvenanceTensor --- funsor/torch/__init__.py | 4 ---- funsor/torch/metadata.py | 34 ---------------------------------- funsor/torch/provenance.py | 33 +++++++++++++++++++++++++++++++++ test/test_metadata.py | 33 --------------------------------- test/test_provenance.py | 34 ++++++++++++++++++++++++++++++++++ 5 files changed, 67 insertions(+), 71 deletions(-) delete mode 100644 funsor/torch/metadata.py create mode 100644 funsor/torch/provenance.py delete mode 100644 test/test_metadata.py create mode 100644 test/test_provenance.py diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index c782c0829..9f69ab983 
100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -10,7 +10,6 @@ from . import distributions as _ from . import ops as _ -from .metadata import MetadataTensor del _ # flake8 @@ -29,6 +28,3 @@ def _quote(x, indent, out): @dispatch(torch.Tensor, torch.Tensor, [float]) def allclose(a, b, rtol=1e-05, atol=1e-08): return torch.allclose(a, b, rtol=rtol, atol=atol) - - -__all__ = ["MetadataTensor"] diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py deleted file mode 100644 index 399aea79b..000000000 --- a/funsor/torch/metadata.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -class MetadataTensor(torch.Tensor): - def __new__(cls, data, metadata=frozenset(), **kwargs): - assert isinstance(metadata, frozenset) - if isinstance(data, torch.Tensor): - t = torch.Tensor._make_subclass(cls, data) - t._metadata = metadata - return t - else: - return data - - def __repr__(self): - return "Metadata:\n{}\ndata:\n{}".format( - self._metadata, torch.Tensor._make_subclass(torch.Tensor, self) - ) - - def __torch_function__(self, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - meta = frozenset() - _args = [] - for arg in args: - if isinstance(arg, MetadataTensor): - meta |= arg._metadata - _args.append(torch.Tensor._make_subclass(torch.Tensor, arg)) - else: - _args.append(arg) - ret = func(*_args, **kwargs) - return MetadataTensor(ret, metadata=meta) diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py new file mode 100644 index 000000000..0d8b0b6cf --- /dev/null +++ b/funsor/torch/provenance.py @@ -0,0 +1,33 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +class ProvenanceTensor(torch.Tensor): + def __new__(cls, data, provenance=frozenset(), **kwargs): + assert isinstance(provenance, frozenset) + t = torch.Tensor._make_subclass(cls, data) + t._provenance = provenance + return t + + def __repr__(self): + return "Provenance:\n{}\nTensor:\n{}".format( + self._provenance, torch.Tensor._make_subclass(torch.Tensor, self) + ) + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + provenance = frozenset() + _args = [] + for arg in args: + if isinstance(arg, ProvenanceTensor): + provenance |= arg._provenance + _args.append(torch.Tensor._make_subclass(torch.Tensor, arg)) + else: + _args.append(arg) + ret = func(*_args, **kwargs) + if isinstance(ret, torch.Tensor): + return ProvenanceTensor(ret, provenance=provenance) + return ret diff --git a/test/test_metadata.py b/test/test_metadata.py deleted file mode 100644 index 3415e7cef..000000000 --- a/test/test_metadata.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright Contributors to the Pyro project. 
-# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch - -import funsor -from funsor.torch import MetadataTensor - - -@pytest.mark.parametrize( - "data1,metadata1", - [ - (torch.tensor([1]), frozenset({"a"})), - ], -) -@pytest.mark.parametrize( - "data2,metadata2", - [ - (torch.tensor([2]), frozenset({"b"})), - (torch.tensor([2]), None), - (2, None), - ], -) -def test_metadata(data1, metadata1, data2, metadata2): - if metadata1 is not None: - data1 = MetadataTensor(data1, metadata1) - if metadata2 is not None: - data2 = MetadataTensor(data2, metadata2) - - expected = frozenset.union(*[m for m in (metadata1, metadata2) if m is not None]) - actual = torch.add(data1, data2)._metadata - assert actual == expected diff --git a/test/test_provenance.py b/test/test_provenance.py new file mode 100644 index 000000000..5aa36c134 --- /dev/null +++ b/test/test_provenance.py @@ -0,0 +1,34 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from funsor.torch.provenance import ProvenanceTensor + + +@pytest.mark.parametrize( + "data1,provenance1", + [ + (torch.tensor([1]), frozenset({"a"})), + ], +) +@pytest.mark.parametrize( + "data2,provenance2", + [ + (torch.tensor([2]), frozenset({"b"})), + (torch.tensor([2]), None), + (2, None), + ], +) +def test_provenance(data1, provenance1, data2, provenance2): + if provenance1 is not None: + data1 = ProvenanceTensor(data1, provenance1) + if provenance2 is not None: + data2 = ProvenanceTensor(data2, provenance2) + + expected = frozenset.union( + *[m for m in (provenance1, provenance2) if m is not None] + ) + actual = torch.add(data1, data2)._provenance + assert actual == expected From b3bb934204126606a8ed2c6d7e8cc16fe480fb5a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 28 Jul 2021 03:35:45 -0400 Subject: [PATCH 08/22] working version --- funsor/adjoint.py | 60 +++++++++--------- funsor/constant.py | 126 ++++++++++++++++++------------------- funsor/torch/__init__.py | 7 +++ funsor/torch/provenance.py | 40 ++++++++++-- test/test_metadata.py | 35 ----------- test/test_provenance.py | 7 ++- 6 files changed, 139 insertions(+), 136 deletions(-) delete mode 100644 test/test_metadata.py diff --git a/funsor/adjoint.py b/funsor/adjoint.py index bdae045af..596034120 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -58,8 +58,8 @@ def interpret(self, cls, *args): ) for arg in args ] - with self._old_interpretation: - self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) + # with self._old_interpretation: + # self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) return result def __enter__(self): @@ -84,34 +84,34 @@ def adjoint(self, sum_op, bin_op, root, targets=None): continue # reverse the effects of alpha-renaming - with reflect: - - lazy_output = self._eager_to_lazy[output] - lazy_fn = type(lazy_output) - lazy_inputs = lazy_output._ast_values - # TODO abstract this into a helper function - # FIXME make lazy_output linear instead of quadratic in the size of the tape - lazy_other_subs = tuple( - (name, to_funsor(name.split("__BOUND")[0], domain)) - for name, domain in lazy_output.inputs.items() - if "__BOUND" in name - ) - lazy_inputs = _alpha_unmangle( - substitute(lazy_fn(*lazy_inputs), lazy_other_subs) - ) - lazy_output = type(lazy_output)( - *_alpha_unmangle(substitute(lazy_output, lazy_other_subs)) - ) - - other_subs = tuple( - (name, to_funsor(name.split("__BOUND")[0], domain)) - for name, domain in output.inputs.items() - if "__BOUND" in 
name - ) - inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs)) - output = type(output)(*_alpha_unmangle(substitute(output, other_subs))) - - self._eager_to_lazy[output] = lazy_output + # with reflect: + # + # lazy_output = self._eager_to_lazy[output] + # lazy_fn = type(lazy_output) + # lazy_inputs = lazy_output._ast_values + # # TODO abstract this into a helper function + # # FIXME make lazy_output linear instead of quadratic in the size of the tape + # lazy_other_subs = tuple( + # (name, to_funsor(name.split("__BOUND")[0], domain)) + # for name, domain in lazy_output.inputs.items() + # if "__BOUND" in name + # ) + # lazy_inputs = _alpha_unmangle( + # substitute(lazy_fn(*lazy_inputs), lazy_other_subs) + # ) + # lazy_output = type(lazy_output)( + # *_alpha_unmangle(substitute(lazy_output, lazy_other_subs)) + # ) + # + # other_subs = tuple( + # (name, to_funsor(name.split("__BOUND")[0], domain)) + # for name, domain in output.inputs.items() + # if "__BOUND" in name + # ) + # inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs)) + # output = type(output)(*_alpha_unmangle(substitute(output, other_subs))) + # + # self._eager_to_lazy[output] = lazy_output in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs) for v, adjv in in_adjs: diff --git a/funsor/constant.py b/funsor/constant.py index a9ef53489..12b8d9d50 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -6,41 +6,45 @@ from funsor.delta import Delta from funsor.distribution import Distribution from funsor.tensor import Tensor -from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager, to_funsor -from funsor.torch import MetadataTensor +from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager, to_data +from funsor.torch.provenance import ProvenanceTensor -from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp +from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp, AddOp class Constant(Funsor): - def __init__(self, const_vars, arg): + def __init__(self, const_inputs, arg): assert isinstance(arg, Funsor) - assert isinstance(const_vars, frozenset) - assert all(isinstance(v, Variable) for v in const_vars) - assert all(v not in arg.inputs for v in const_vars) + assert isinstance(const_inputs, tuple) + assert set(const_inputs).isdisjoint(arg.inputs) + # assert all(v not in arg.inputs for v in const_inputs) # const_names = frozenset(v.name for v in cont_vars) - inputs = OrderedDict((v.name, v.output) for v in const_vars) + const_inputs = OrderedDict(const_inputs) + inputs = const_inputs.copy() inputs.update(arg.inputs) output = arg.output - fresh = const_vars + fresh = frozenset(const_inputs.keys()) bound = {} super(Constant, self).__init__(inputs, output, fresh, bound) self.arg = arg - self.const_vars = const_vars + self.const_vars = frozenset(Variable(k, v) for k, v in const_inputs.items()) + self.const_inputs = const_inputs def eager_subs(self, subs): assert isinstance(subs, tuple) - const_vars = self.const_vars - for name, value in subs: - if isinstance(value, Variable): - breakpoint() - continue - - breakpoint() - if isinstance(value, (Number, Tensor)): - const_vars = const_vars - value - - return Constant(const_vars, self.arg) + subs = OrderedDict((k, v) for k, v in subs) + const_inputs = OrderedDict() + for k, d in self.const_inputs.items(): + # handle when subs is in self.arg.inputs + if k in subs: + v = subs[k] + if isinstance(v, Variable): + del subs[k] + k = v.name + const_inputs[k] = d + if const_inputs: 
+ return Constant(tuple(const_inputs.items()), self.arg) + return self.arg def eager_reduce(self, op, reduced_vars): assert reduced_vars.issubset(self.inputs) @@ -50,7 +54,8 @@ def eager_reduce(self, op, reduced_vars): reduced_vars = reduced_vars - frozenset({v.name for v in self.const_vars}) if not const_vars: return self.arg.reduce(op, reduced_vars) - return Constant(const_vars, self.arg.reduce(op, reduced_vars)) + const_inputs = tuple((v.name, v.output) for v in const_vars) + return Constant(const_inputs, self.arg.reduce(op, reduced_vars)) @eager.register(Binary, BinaryOp, Constant, Constant) @@ -58,7 +63,8 @@ def eager_binary_constant_constant(op, lhs, rhs): const_vars = lhs.const_vars | rhs.const_vars - lhs.input_vars - rhs.input_vars if not const_vars: return op(lhs.arg, rhs.arg) - return Constant(const_vars, op(lhs.arg, rhs.arg)) + const_inputs = tuple((v.name, v.output) for v in const_vars) + return Constant(const_inputs, op(lhs.arg, rhs.arg)) @eager.register(Binary, BinaryOp, Constant, (Number, Tensor)) @@ -66,7 +72,8 @@ def eager_binary_constant_tensor(op, lhs, rhs): const_vars = lhs.const_vars - rhs.input_vars if not const_vars: return op(lhs.arg, rhs) - return Constant(const_vars, op(lhs.arg, rhs)) + const_inputs = tuple((v.name, v.output) for v in const_vars) + return Constant(const_inputs, op(lhs.arg, rhs)) @eager.register(Binary, BinaryOp, (Number, Tensor), Constant) @@ -74,48 +81,37 @@ def eager_binary_tensor_constant(op, lhs, rhs): const_vars = rhs.const_vars - lhs.input_vars if not const_vars: return op(lhs, rhs.arg) - return Constant(const_vars, op(lhs, rhs.arg)) + const_inputs = tuple((v.name, v.output) for v in const_vars) + return Constant(const_inputs, op(lhs, rhs.arg)) @eager.register(Unary, UnaryOp, Constant) def eager_binary_tensor_constant(op, arg): - return Constant(arg.const_vars, op(arg.arg)) - - -@to_funsor.register(MetadataTensor) -def tensor_to_funsor(x, output=None, dim_to_name=None): - breakpoint() - if not dim_to_name: - output = output if output is not None else Reals[x.shape] - result = Tensor(x, dtype=output.dtype) - if result.output != output: - raise ValueError( - "Invalid shape: expected {}, actual {}".format( - output.shape, result.output.shape - ) - ) - return result - else: - assert all( - isinstance(k, int) and k < 0 and isinstance(v, str) - for k, v in dim_to_name.items() - ) - - if output is None: - # Assume the leftmost dim_to_name key refers to the leftmost dim of x - # when there is ambiguity about event shape - batch_ndims = min(-min(dim_to_name.keys()), len(x.shape)) - output = Reals[x.shape[batch_ndims:]] - - # logic very similar to pyro.ops.packed.pack - # this should not touch memory, only reshape - # pack the tensor according to the dim => name mapping in inputs - packed_inputs = OrderedDict() - for dim, size in zip(range(len(x.shape) - len(output.shape)), x.shape): - name = dim_to_name.get(dim + len(output.shape) - len(x.shape), None) - if name is not None and size != 1: - packed_inputs[name] = Bint[size] - shape = tuple(d.size for d in packed_inputs.values()) + output.shape - if x.shape != shape: - x = x.reshape(shape) - return Tensor(x, packed_inputs, dtype=output.dtype) + const_inputs = tuple((v.name, v.output) for v in arg.const_vars) + return Constant(const_inputs, op(arg.arg)) + + +# @eager.register(Binary, AddOp, Constant, Delta) +# def eager_binary_constant_tensor(op, lhs, rhs): +# const_vars = lhs.const_vars - rhs.input_vars +# breakpoint() +# if not const_vars: +# return op(lhs.arg, rhs) +# const_inputs = tuple((v.name, 
v.output) for v in const_vars) +# return Constant(const_inputs, op(lhs.arg, rhs)) +# +# +# @eager.register(Binary, AddOp, Delta, Constant) +# def eager_binary_tensor_constant(op, lhs, rhs): +# const_vars = rhs.const_vars - lhs.input_vars +# breakpoint() +# if not const_vars: +# return op(lhs, rhs.arg) +# const_inputs = tuple((v.name, v.output) for v in const_vars) +# return Constant(const_inputs, op(lhs, rhs.arg)) + + +@to_data.register(Constant) +def constant_to_data(x, name_to_dim=None): + data = to_data(x.arg, name_to_dim=name_to_dim) + return ProvenanceTensor(data, provenance=frozenset((v.name, v.output) for v in x.const_vars)) diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index 9f69ab983..3bfca043f 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -7,6 +7,8 @@ from funsor.tensor import tensor_to_funsor from funsor.terms import to_funsor from funsor.util import quote +from funsor.torch.provenance import ProvenanceTensor +from funsor.constant import Constant from . import distributions as _ from . import ops as _ @@ -21,6 +23,11 @@ def _quote(x, indent, out): """ out.append((indent, "torch.tensor({}, dtype={})".format(repr(x.tolist()), x.dtype))) +@to_funsor.register(ProvenanceTensor) +def provenance_to_funsor(x, output=None, dim_to_name=None): + if isinstance(x, ProvenanceTensor): + ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) + return Constant(tuple(x._provenance), ret) to_funsor.register(torch.Tensor)(tensor_to_funsor) diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py index 0d8b0b6cf..49811bb15 100644 --- a/funsor/torch/provenance.py +++ b/funsor/torch/provenance.py @@ -6,14 +6,27 @@ class ProvenanceTensor(torch.Tensor): def __new__(cls, data, provenance=frozenset(), **kwargs): + # assert isinstance(provenance, frozenset) + # t = torch.Tensor._make_subclass(cls, data) + # t._provenance = provenance + # return data + if not provenance: + return data + instance = torch.Tensor.__new__(cls) + instance.__init__(data, provenance) + return instance + # return super(object).__new__(cls, data, provenance) + + def __init__(self, data, provenance=frozenset()): assert isinstance(provenance, frozenset) - t = torch.Tensor._make_subclass(cls, data) - t._provenance = provenance - return t + # t = torch.Tensor._make_subclass(cls, data) + self._t = data + self._provenance = provenance def __repr__(self): return "Provenance:\n{}\nTensor:\n{}".format( - self._provenance, torch.Tensor._make_subclass(torch.Tensor, self) + self._provenance, self._t + # self._provenance, torch.Tensor._make_subclass(torch.Tensor, self) ) def __torch_function__(self, func, types, args=(), kwargs=None): @@ -24,10 +37,27 @@ def __torch_function__(self, func, types, args=(), kwargs=None): for arg in args: if isinstance(arg, ProvenanceTensor): provenance |= arg._provenance - _args.append(torch.Tensor._make_subclass(torch.Tensor, arg)) + _args.append(arg._t) else: _args.append(arg) ret = func(*_args, **kwargs) if isinstance(ret, torch.Tensor): return ProvenanceTensor(ret, provenance=provenance) + if isinstance(ret, tuple): + _ret = [] + for r in ret: + if isinstance(r, torch.Tensor): + _ret.append(ProvenanceTensor(r, provenance=provenance)) + else: + _ret.append(r) + return tuple(_ret) return ret + +class MyObject(torch.Tensor): + @staticmethod + def __new__(cls, x, extra_data, *args, **kwargs): + return super().__new__(cls, x, *args, **kwargs) + + def __init__(self, x, extra_data): + #super().__init__() # optional + self.extra_data = extra_data diff 
--git a/test/test_metadata.py b/test/test_metadata.py deleted file mode 100644 index e4b327f83..000000000 --- a/test/test_metadata.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch - -import funsor -from funsor.torch import MetadataTensor - - -@pytest.mark.parametrize( - "data1,metadata1", - [ - (torch.tensor([1]), frozenset({"a"})), - ], -) -@pytest.mark.parametrize( - "data2,metadata2", - [ - (torch.tensor([2]), frozenset({"b"})), - (torch.tensor([2]), None), - (2, None), - ], -) -def test_metadata(data1, metadata1, data2, metadata2): - if metadata1 is not None: - data1 = MetadataTensor(data1, metadata1) - if metadata2 is not None: - data2 = MetadataTensor(data2, metadata2) - t = funsor.to_funsor(data1) - breakpoint() - - expected = frozenset.union(*[m for m in (metadata1, metadata2) if m is not None]) - actual = torch.add(data1, data2)._metadata - assert actual == expected diff --git a/test/test_provenance.py b/test/test_provenance.py index 5aa36c134..cb7f2bb39 100644 --- a/test/test_provenance.py +++ b/test/test_provenance.py @@ -4,7 +4,8 @@ import pytest import torch -from funsor.torch.provenance import ProvenanceTensor +from funsor.terms import to_funsor +from funsor.torch.provenance import ProvenanceTensor, MyObject @pytest.mark.parametrize( @@ -22,10 +23,14 @@ ], ) def test_provenance(data1, provenance1, data2, provenance2): + # breakpoint() + # mo = MyObject(data1, extra_data=provenance1) if provenance1 is not None: data1 = ProvenanceTensor(data1, provenance1) if provenance2 is not None: data2 = ProvenanceTensor(data2, provenance2) + breakpoint() + to_funsor(data1) expected = frozenset.union( *[m for m in (provenance1, provenance2) if m is not None] From cc86ee358745653441212b89107afbc507ef2ed6 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 28 Jul 2021 19:55:23 -0400 Subject: [PATCH 09/22] pass second test --- funsor/cnf.py | 19 ++++++++++++------- funsor/constant.py | 4 ++-- funsor/delta.py | 2 +- funsor/distribution.py | 23 ++++++++++++++++++----- funsor/montecarlo.py | 5 +++-- funsor/sum_product.py | 2 +- funsor/tensor.py | 2 +- funsor/terms.py | 7 ++++--- funsor/torch/provenance.py | 14 +++----------- 9 files changed, 45 insertions(+), 33 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 39e82e1eb..12203ff76 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -102,7 +102,7 @@ def __str__(self): ) return super().__str__() - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: return self @@ -120,7 +120,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): # binary choices symbolic. terms = [ term.unscaled_sample( - sampled_vars.intersection(term.inputs), sample_inputs + sampled_vars.intersection(term.inputs), sample_inputs, raw_value=raw_value ) for term, rng_key in zip(self.terms, rng_keys) ] @@ -136,18 +136,23 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): # Sample variables greedily in order of the terms in which they appear. 
for term in self.terms: + # breakpoint() greedy_vars = sampled_vars.intersection(term.inputs) if greedy_vars: break greedy_terms, terms = [], [] for term in self.terms: + if isinstance(term, funsor.torch.distributions.Poisson): + dd = {term.value.name} + else: + dd = term.inputs ( - terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms + terms if greedy_vars.isdisjoint(dd) else greedy_terms ).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]) + term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -162,7 +167,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): terms.append(gaussian) terms.append(-gaussian.log_normalizer) terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]) + term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -174,7 +179,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ): sampled_terms = [ term.unscaled_sample( - greedy_vars.intersection(term.value.inputs), sample_inputs + greedy_vars.intersection(term.value.inputs), sample_inputs, raw_value=raw_value ) for term in greedy_terms if isinstance(term, funsor.distribution.Distribution) @@ -193,7 +198,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ) ) return result.unscaled_sample( - sampled_vars - greedy_vars, sample_inputs, rng_keys[1] + sampled_vars - greedy_vars, sample_inputs, rng_keys[1], raw_value=raw_value ) raise TypeError( diff --git a/funsor/constant.py b/funsor/constant.py index 12b8d9d50..18cad324a 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -67,7 +67,7 @@ def eager_binary_constant_constant(op, lhs, rhs): return Constant(const_inputs, op(lhs.arg, rhs.arg)) -@eager.register(Binary, BinaryOp, Constant, (Number, Tensor)) +@eager.register(Binary, BinaryOp, Constant, (Number, Tensor, Distribution)) def eager_binary_constant_tensor(op, lhs, rhs): const_vars = lhs.const_vars - rhs.input_vars if not const_vars: @@ -76,7 +76,7 @@ def eager_binary_constant_tensor(op, lhs, rhs): return Constant(const_inputs, op(lhs.arg, rhs)) -@eager.register(Binary, BinaryOp, (Number, Tensor), Constant) +@eager.register(Binary, BinaryOp, (Number, Tensor, Distribution), Constant) def eager_binary_tensor_constant(op, lhs, rhs): const_vars = rhs.const_vars - lhs.input_vars if not const_vars: diff --git a/funsor/delta.py b/funsor/delta.py index ddc3d2960..e97f34c0c 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -200,7 +200,7 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): return self diff --git a/funsor/distribution.py b/funsor/distribution.py index 6a1e19397..137051fb2 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -12,6 +12,7 @@ import makefun +import funsor import funsor.delta import funsor.ops as ops from funsor.affine import is_affine @@ -145,13 +146,15 @@ def __repr__(self): ) def eager_reduce(self, op, reduced_vars): + # breakpoint() assert reduced_vars.issubset(self.inputs) if ( op is ops.logaddexp and isinstance(self.value, Variable) and self.value.name in reduced_vars ): - return 
Number(0.0) # distributions are normalized + const_inputs = tuple((k, v) for k, v in self.inputs.items() if k not in reduced_vars) + return funsor.constant.Constant(const_inputs, Number(0.0)) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) def _get_raw_dist(self): @@ -206,7 +209,7 @@ def eager_log_prob(cls, *params): inputs.update(x.inputs) return log_prob.align(tuple(inputs)) - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): # note this should handle transforms correctly via distribution_to_data raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist() @@ -220,10 +223,19 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): sample_args = ( (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) ) - if raw_dist.has_rsample: - raw_value = raw_dist.rsample(*sample_args) + if raw_value is None: + # fix this + raw_value = {} + raw_value = {var: raw_value[var] for var in sampled_vars if var in raw_value} + if not raw_value: + if raw_dist.has_rsample: + raw_value = raw_dist.rsample(*sample_args) + else: + raw_value = ops.detach(raw_dist.sample(*sample_args)) else: - raw_value = ops.detach(raw_dist.sample(*sample_args)) + raw_value = raw_value[value_name] + # if "data" in dim_to_name.values(): + # raw_value = raw_value.unsqueeze(-1) funsor_value = to_funsor( raw_value, output=value_output, dim_to_name=dim_to_name @@ -232,6 +244,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs) ) + result = funsor.delta.Delta(value_name, funsor_value) if not raw_dist.has_rsample: # scaling of dice_factor by num samples should already be handled by Funsor.sample diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index e0669640e..6fcc418bd 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -19,9 +19,10 @@ class MonteCarlo(StatefulInterpretation): :param rng_key: """ - def __init__(self, *, rng_key=None, **sample_inputs): + def __init__(self, *, rng_key=None, raw_value=None, **sample_inputs): super().__init__("monte_carlo") self.rng_key = rng_key + self.raw_value = raw_value self.sample_inputs = OrderedDict(sample_inputs) @@ -33,7 +34,7 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) - sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options) + sample = log_measure.sample(reduced_vars, state.sample_inputs, raw_value=state.raw_value, **sample_options) if sample is log_measure: return None # cannot progress reduced_vars |= frozenset( diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 3c22775e4..f5814e5a7 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -246,7 +246,7 @@ def partial_sum_product( ) if new_plates == leaf: raise ValueError("intractable!") - f = f.reduce(prod_op, leaf - new_plates) + f = f.reduce(prod_op, (leaf & eliminate) - new_plates) ordinal_to_factors[new_plates].append(f) return results diff --git a/funsor/tensor.py b/funsor/tensor.py index 1c5853d35..ce8ec4a50 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -329,7 +329,7 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) - def unscaled_sample(self, sampled_vars, 
sample_inputs, rng_key=None): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): assert self.output == Real sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: diff --git a/funsor/terms.py b/funsor/terms.py index 1041a56f3..5ad1384eb 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -420,7 +420,7 @@ def approximate(self, op, guide, approx_vars=None): return self # exact return Approximate(op, self, guide, approx_vars) - def sample(self, sampled_vars, sample_inputs=None, rng_key=None): + def sample(self, sampled_vars, sample_inputs=None, rng_key=None, raw_value=None): """ Create a Monte Carlo approximation to this funsor by replacing functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s. @@ -457,8 +457,9 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None): if sampled_vars.isdisjoint(self.inputs): return self + # breakpoint() result = instrument.debug_logged(self.unscaled_sample)( - sampled_vars, sample_inputs, rng_key + sampled_vars, sample_inputs, rng_key, raw_value ) if sample_inputs is not None: log_scale = 0 @@ -469,7 +470,7 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None): result += log_scale return result - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): """ Internal method to draw an unscaled sample. This should be overridden by subclasses. diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py index 49811bb15..6b397ddef 100644 --- a/funsor/torch/provenance.py +++ b/funsor/torch/provenance.py @@ -15,18 +15,19 @@ def __new__(cls, data, provenance=frozenset(), **kwargs): instance = torch.Tensor.__new__(cls) instance.__init__(data, provenance) return instance - # return super(object).__new__(cls, data, provenance) def __init__(self, data, provenance=frozenset()): assert isinstance(provenance, frozenset) # t = torch.Tensor._make_subclass(cls, data) + if isinstance(data, ProvenanceTensor): + provenance |= data._provenance + data = data._t self._t = data self._provenance = provenance def __repr__(self): return "Provenance:\n{}\nTensor:\n{}".format( self._provenance, self._t - # self._provenance, torch.Tensor._make_subclass(torch.Tensor, self) ) def __torch_function__(self, func, types, args=(), kwargs=None): @@ -52,12 +53,3 @@ def __torch_function__(self, func, types, args=(), kwargs=None): _ret.append(r) return tuple(_ret) return ret - -class MyObject(torch.Tensor): - @staticmethod - def __new__(cls, x, extra_data, *args, **kwargs): - return super().__new__(cls, x, *args, **kwargs) - - def __init__(self, x, extra_data): - #super().__init__() # optional - self.extra_data = extra_data From 59a1baa88b231bbaa60339c002b4cb845e9e6667 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 29 Jul 2021 00:16:49 -0400 Subject: [PATCH 10/22] clean up constant.py --- funsor/cnf.py | 29 +++++++--- funsor/constant.py | 109 +++++++++++++++++++------------------ funsor/delta.py | 4 +- funsor/distribution.py | 12 +++- funsor/montecarlo.py | 4 +- funsor/sum_product.py | 2 +- funsor/tensor.py | 4 +- funsor/terms.py | 4 +- funsor/torch/__init__.py | 10 +++- funsor/torch/provenance.py | 4 +- test/test_provenance.py | 2 +- 11 files changed, 107 insertions(+), 77 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index 12203ff76..e46670ec6 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -102,7 +102,9 @@ def __str__(self): ) return super().__str__() - def 
unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): + def unscaled_sample( + self, sampled_vars, sample_inputs, rng_key=None, raw_value=None + ): sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: return self @@ -120,7 +122,9 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=N # binary choices symbolic. terms = [ term.unscaled_sample( - sampled_vars.intersection(term.inputs), sample_inputs, raw_value=raw_value + sampled_vars.intersection(term.inputs), + sample_inputs, + raw_value=raw_value, ) for term, rng_key in zip(self.terms, rng_keys) ] @@ -146,13 +150,13 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=N dd = {term.value.name} else: dd = term.inputs - ( - terms if greedy_vars.isdisjoint(dd) else greedy_terms - ).append(term) + (terms if greedy_vars.isdisjoint(dd) else greedy_terms).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value) + term.unscaled_sample( + greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value + ) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -167,7 +171,9 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=N terms.append(gaussian) terms.append(-gaussian.log_normalizer) terms.append( - term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value) + term.unscaled_sample( + greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value + ) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -179,7 +185,9 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=N ): sampled_terms = [ term.unscaled_sample( - greedy_vars.intersection(term.value.inputs), sample_inputs, raw_value=raw_value + greedy_vars.intersection(term.value.inputs), + sample_inputs, + raw_value=raw_value, ) for term in greedy_terms if isinstance(term, funsor.distribution.Distribution) @@ -198,7 +206,10 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=N ) ) return result.unscaled_sample( - sampled_vars - greedy_vars, sample_inputs, rng_keys[1], raw_value=raw_value + sampled_vars - greedy_vars, + sample_inputs, + rng_keys[1], + raw_value=raw_value, ) raise TypeError( diff --git a/funsor/constant.py b/funsor/constant.py index 18cad324a..9ea4c7c69 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -3,27 +3,45 @@ from collections import OrderedDict -from funsor.delta import Delta from funsor.distribution import Distribution from funsor.tensor import Tensor -from funsor.terms import Binary, Funsor, Number, Unary, Variable, eager, to_data +from funsor.terms import ( + Binary, + Funsor, + FunsorMeta, + Number, + Unary, + Variable, + eager, + to_data, +) from funsor.torch.provenance import ProvenanceTensor -from .ops import BinaryOp, FinitaryOp, GetitemOp, MatmulOp, Op, ReshapeOp, UnaryOp, AddOp +from .ops import BinaryOp, UnaryOp -class Constant(Funsor): +class ConstantMeta(FunsorMeta): + """ + Wrapper to convert ``const_inputs`` to a tuple. 
+ """ + + def __call__(cls, const_inputs, arg): + if isinstance(const_inputs, dict): + const_inputs = tuple(const_inputs.items()) + + return super(ConstantMeta, cls).__call__(const_inputs, arg) + + +class Constant(Funsor, metaclass=ConstantMeta): def __init__(self, const_inputs, arg): assert isinstance(arg, Funsor) assert isinstance(const_inputs, tuple) assert set(const_inputs).isdisjoint(arg.inputs) - # assert all(v not in arg.inputs for v in const_inputs) - # const_names = frozenset(v.name for v in cont_vars) const_inputs = OrderedDict(const_inputs) inputs = const_inputs.copy() inputs.update(arg.inputs) output = arg.output - fresh = frozenset(const_inputs.keys()) + fresh = frozenset(const_inputs) bound = {} super(Constant, self).__init__(inputs, output, fresh, bound) self.arg = arg @@ -43,75 +61,60 @@ def eager_subs(self, subs): k = v.name const_inputs[k] = d if const_inputs: - return Constant(tuple(const_inputs.items()), self.arg) + return Constant(const_inputs, self.arg) return self.arg def eager_reduce(self, op, reduced_vars): assert reduced_vars.issubset(self.inputs) - const_vars = frozenset( - {v for v in self.const_vars if v.name not in reduced_vars} + const_inputs = OrderedDict( + (k, v) for k, v in self.const_inputs.items() if k not in reduced_vars ) - reduced_vars = reduced_vars - frozenset({v.name for v in self.const_vars}) - if not const_vars: - return self.arg.reduce(op, reduced_vars) - const_inputs = tuple((v.name, v.output) for v in const_vars) - return Constant(const_inputs, self.arg.reduce(op, reduced_vars)) + reduced_vars = reduced_vars - frozenset(self.const_inputs) + reduced_arg = self.arg.reduce(op, reduced_vars) + if const_inputs: + return Constant(const_inputs, reduced_arg) + return reduced_arg @eager.register(Binary, BinaryOp, Constant, Constant) def eager_binary_constant_constant(op, lhs, rhs): - const_vars = lhs.const_vars | rhs.const_vars - lhs.input_vars - rhs.input_vars - if not const_vars: - return op(lhs.arg, rhs.arg) - const_inputs = tuple((v.name, v.output) for v in const_vars) - return Constant(const_inputs, op(lhs.arg, rhs.arg)) + const_inputs = OrderedDict( + (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.const_inputs + ) + const_inputs.update( + (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.const_inputs + ) + if const_inputs: + return Constant(const_inputs, op(lhs.arg, rhs.arg)) + return op(lhs.arg, rhs.arg) @eager.register(Binary, BinaryOp, Constant, (Number, Tensor, Distribution)) def eager_binary_constant_tensor(op, lhs, rhs): - const_vars = lhs.const_vars - rhs.input_vars - if not const_vars: - return op(lhs.arg, rhs) - const_inputs = tuple((v.name, v.output) for v in const_vars) - return Constant(const_inputs, op(lhs.arg, rhs)) + const_inputs = OrderedDict( + (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs + ) + if const_inputs: + return Constant(const_inputs, op(lhs.arg, rhs)) + return op(lhs.arg, rhs) @eager.register(Binary, BinaryOp, (Number, Tensor, Distribution), Constant) def eager_binary_tensor_constant(op, lhs, rhs): - const_vars = rhs.const_vars - lhs.input_vars - if not const_vars: - return op(lhs, rhs.arg) - const_inputs = tuple((v.name, v.output) for v in const_vars) - return Constant(const_inputs, op(lhs, rhs.arg)) + const_inputs = OrderedDict( + (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs + ) + if const_inputs: + return Constant(const_inputs, op(lhs, rhs.arg)) + return op(lhs, rhs.arg) @eager.register(Unary, UnaryOp, Constant) def eager_binary_tensor_constant(op, 
arg): - const_inputs = tuple((v.name, v.output) for v in arg.const_vars) - return Constant(const_inputs, op(arg.arg)) - - -# @eager.register(Binary, AddOp, Constant, Delta) -# def eager_binary_constant_tensor(op, lhs, rhs): -# const_vars = lhs.const_vars - rhs.input_vars -# breakpoint() -# if not const_vars: -# return op(lhs.arg, rhs) -# const_inputs = tuple((v.name, v.output) for v in const_vars) -# return Constant(const_inputs, op(lhs.arg, rhs)) -# -# -# @eager.register(Binary, AddOp, Delta, Constant) -# def eager_binary_tensor_constant(op, lhs, rhs): -# const_vars = rhs.const_vars - lhs.input_vars -# breakpoint() -# if not const_vars: -# return op(lhs, rhs.arg) -# const_inputs = tuple((v.name, v.output) for v in const_vars) -# return Constant(const_inputs, op(lhs, rhs.arg)) + return Constant(arg.const_inputs, op(arg.arg)) @to_data.register(Constant) def constant_to_data(x, name_to_dim=None): data = to_data(x.arg, name_to_dim=name_to_dim) - return ProvenanceTensor(data, provenance=frozenset((v.name, v.output) for v in x.const_vars)) + return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items())) diff --git a/funsor/delta.py b/funsor/delta.py index e97f34c0c..f028d7d0a 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -200,7 +200,9 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): + def unscaled_sample( + self, sampled_vars, sample_inputs, rng_key=None, raw_value=None + ): return self diff --git a/funsor/distribution.py b/funsor/distribution.py index 137051fb2..1ad011bc2 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -153,8 +153,12 @@ def eager_reduce(self, op, reduced_vars): and isinstance(self.value, Variable) and self.value.name in reduced_vars ): - const_inputs = tuple((k, v) for k, v in self.inputs.items() if k not in reduced_vars) - return funsor.constant.Constant(const_inputs, Number(0.0)) # distributions are normalized + const_inputs = OrderedDict( + (k, v) for k, v in self.inputs.items() if k not in reduced_vars + ) + return funsor.constant.Constant( + const_inputs, Number(0.0) + ) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) def _get_raw_dist(self): @@ -209,7 +213,9 @@ def eager_log_prob(cls, *params): inputs.update(x.inputs) return log_prob.align(tuple(inputs)) - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): + def unscaled_sample( + self, sampled_vars, sample_inputs, rng_key=None, raw_value=None + ): # note this should handle transforms correctly via distribution_to_data raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist() diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 6fcc418bd..cb624d394 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -34,7 +34,9 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) - sample = log_measure.sample(reduced_vars, state.sample_inputs, raw_value=state.raw_value, **sample_options) + sample = log_measure.sample( + reduced_vars, state.sample_inputs, raw_value=state.raw_value, **sample_options + ) if sample is log_measure: return None # cannot progress reduced_vars |= frozenset( diff --git a/funsor/sum_product.py b/funsor/sum_product.py index f5814e5a7..3c22775e4 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -246,7 +246,7 @@ def 
partial_sum_product( ) if new_plates == leaf: raise ValueError("intractable!") - f = f.reduce(prod_op, (leaf & eliminate) - new_plates) + f = f.reduce(prod_op, leaf - new_plates) ordinal_to_factors[new_plates].append(f) return results diff --git a/funsor/tensor.py b/funsor/tensor.py index ce8ec4a50..0b0f4b703 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -329,7 +329,9 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): + def unscaled_sample( + self, sampled_vars, sample_inputs, rng_key=None, raw_value=None + ): assert self.output == Real sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: diff --git a/funsor/terms.py b/funsor/terms.py index 5ad1384eb..f0a58ff0b 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -470,7 +470,9 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None, raw_value=None) result += log_scale return result - def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None, raw_value=None): + def unscaled_sample( + self, sampled_vars, sample_inputs, rng_key=None, raw_value=None + ): """ Internal method to draw an unscaled sample. This should be overridden by subclasses. diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index 3bfca043f..b6439989f 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -1,14 +1,16 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from collections import OrderedDict + import torch from multipledispatch import dispatch +from funsor.constant import Constant from funsor.tensor import tensor_to_funsor from funsor.terms import to_funsor -from funsor.util import quote from funsor.torch.provenance import ProvenanceTensor -from funsor.constant import Constant +from funsor.util import quote from . import distributions as _ from . 
import ops as _ @@ -23,11 +25,13 @@ def _quote(x, indent, out): """ out.append((indent, "torch.tensor({}, dtype={})".format(repr(x.tolist()), x.dtype))) + @to_funsor.register(ProvenanceTensor) def provenance_to_funsor(x, output=None, dim_to_name=None): if isinstance(x, ProvenanceTensor): ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) - return Constant(tuple(x._provenance), ret) + return Constant(OrderedDict(x._provenance), ret) + to_funsor.register(torch.Tensor)(tensor_to_funsor) diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py index 6b397ddef..3e2697b0b 100644 --- a/funsor/torch/provenance.py +++ b/funsor/torch/provenance.py @@ -26,9 +26,7 @@ def __init__(self, data, provenance=frozenset()): self._provenance = provenance def __repr__(self): - return "Provenance:\n{}\nTensor:\n{}".format( - self._provenance, self._t - ) + return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: diff --git a/test/test_provenance.py b/test/test_provenance.py index cb7f2bb39..641aebd4d 100644 --- a/test/test_provenance.py +++ b/test/test_provenance.py @@ -5,7 +5,7 @@ import torch from funsor.terms import to_funsor -from funsor.torch.provenance import ProvenanceTensor, MyObject +from funsor.torch.provenance import ProvenanceTensor @pytest.mark.parametrize( From 9c2fcb258495d1e9f359fac9c4c2c3e7d5e75040 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 29 Jul 2021 02:05:02 -0400 Subject: [PATCH 11/22] more clean up --- funsor/cnf.py | 9 ++++----- funsor/distribution.py | 22 +++++++++------------- test/test_provenance.py | 5 ----- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index e46670ec6..ec377484c 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -140,17 +140,16 @@ def unscaled_sample( # Sample variables greedily in order of the terms in which they appear. 
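        # The first term that mentions any of the sampled variables claims that
        # greedy set; any remaining sampled variables are handled by the recursive
        # unscaled_sample call at the end of this method.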
for term in self.terms: - # breakpoint() greedy_vars = sampled_vars.intersection(term.inputs) if greedy_vars: break greedy_terms, terms = [], [] for term in self.terms: - if isinstance(term, funsor.torch.distributions.Poisson): - dd = {term.value.name} + if isinstance(term, funsor.distribution.Distribution): + term_var = {term.value.name} else: - dd = term.inputs - (terms if greedy_vars.isdisjoint(dd) else greedy_terms).append(term) + term_var = term.inputs + (terms if greedy_vars.isdisjoint(term_var) else greedy_terms).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append( diff --git a/funsor/distribution.py b/funsor/distribution.py index 1ad011bc2..180f067af 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -146,7 +146,6 @@ def __repr__(self): ) def eager_reduce(self, op, reduced_vars): - # breakpoint() assert reduced_vars.issubset(self.inputs) if ( op is ops.logaddexp @@ -156,9 +155,11 @@ def eager_reduce(self, op, reduced_vars): const_inputs = OrderedDict( (k, v) for k, v in self.inputs.items() if k not in reduced_vars ) - return funsor.constant.Constant( - const_inputs, Number(0.0) - ) # distributions are normalized + if const_inputs: + return funsor.constant.Constant( + const_inputs, Number(0.0) + ) + return Number(0.0) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) def _get_raw_dist(self): @@ -229,19 +230,14 @@ def unscaled_sample( sample_args = ( (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) ) - if raw_value is None: - # fix this - raw_value = {} - raw_value = {var: raw_value[var] for var in sampled_vars if var in raw_value} - if not raw_value: + + if raw_value is not None and value_name in raw_value: + raw_value = raw_value[value_name] + else: if raw_dist.has_rsample: raw_value = raw_dist.rsample(*sample_args) else: raw_value = ops.detach(raw_dist.sample(*sample_args)) - else: - raw_value = raw_value[value_name] - # if "data" in dim_to_name.values(): - # raw_value = raw_value.unsqueeze(-1) funsor_value = to_funsor( raw_value, output=value_output, dim_to_name=dim_to_name diff --git a/test/test_provenance.py b/test/test_provenance.py index 641aebd4d..5aa36c134 100644 --- a/test/test_provenance.py +++ b/test/test_provenance.py @@ -4,7 +4,6 @@ import pytest import torch -from funsor.terms import to_funsor from funsor.torch.provenance import ProvenanceTensor @@ -23,14 +22,10 @@ ], ) def test_provenance(data1, provenance1, data2, provenance2): - # breakpoint() - # mo = MyObject(data1, extra_data=provenance1) if provenance1 is not None: data1 = ProvenanceTensor(data1, provenance1) if provenance2 is not None: data2 = ProvenanceTensor(data2, provenance2) - breakpoint() - to_funsor(data1) expected = frozenset.union( *[m for m in (provenance1, provenance2) if m is not None] From 9d15024e4edbe9d077ba3ccbbe573a17fa685ba7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Thu, 29 Jul 2021 02:13:01 -0400 Subject: [PATCH 12/22] rm metadata file --- funsor/cnf.py | 4 +++- funsor/constant.py | 8 ++++++++ funsor/distribution.py | 4 +--- funsor/torch/__init__.py | 11 ----------- funsor/torch/metadata.py | 39 -------------------------------------- funsor/torch/provenance.py | 5 ----- 6 files changed, 12 insertions(+), 59 deletions(-) delete mode 100644 funsor/torch/metadata.py diff --git a/funsor/cnf.py b/funsor/cnf.py index ec377484c..1c9941d01 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -149,7 +149,9 @@ def unscaled_sample( term_var = {term.value.name} else: term_var = 
term.inputs - (terms if greedy_vars.isdisjoint(term_var) else greedy_terms).append(term) + ( + terms if greedy_vars.isdisjoint(term_var) else greedy_terms + ).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append( diff --git a/funsor/constant.py b/funsor/constant.py index 9ea4c7c69..6e51ea840 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -14,6 +14,7 @@ Variable, eager, to_data, + to_funsor, ) from funsor.torch.provenance import ProvenanceTensor @@ -118,3 +119,10 @@ def eager_binary_tensor_constant(op, arg): def constant_to_data(x, name_to_dim=None): data = to_data(x.arg, name_to_dim=name_to_dim) return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items())) + + +@to_funsor.register(ProvenanceTensor) +def provenance_to_funsor(x, output=None, dim_to_name=None): + if isinstance(x, ProvenanceTensor): + ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) + return Constant(OrderedDict(x._provenance), ret) diff --git a/funsor/distribution.py b/funsor/distribution.py index 180f067af..96876e343 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -156,9 +156,7 @@ def eager_reduce(self, op, reduced_vars): (k, v) for k, v in self.inputs.items() if k not in reduced_vars ) if const_inputs: - return funsor.constant.Constant( - const_inputs, Number(0.0) - ) + return funsor.constant.Constant(const_inputs, Number(0.0)) return Number(0.0) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) diff --git a/funsor/torch/__init__.py b/funsor/torch/__init__.py index b6439989f..9f69ab983 100644 --- a/funsor/torch/__init__.py +++ b/funsor/torch/__init__.py @@ -1,15 +1,11 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict - import torch from multipledispatch import dispatch -from funsor.constant import Constant from funsor.tensor import tensor_to_funsor from funsor.terms import to_funsor -from funsor.torch.provenance import ProvenanceTensor from funsor.util import quote from . import distributions as _ @@ -26,13 +22,6 @@ def _quote(x, indent, out): out.append((indent, "torch.tensor({}, dtype={})".format(repr(x.tolist()), x.dtype))) -@to_funsor.register(ProvenanceTensor) -def provenance_to_funsor(x, output=None, dim_to_name=None): - if isinstance(x, ProvenanceTensor): - ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) - return Constant(OrderedDict(x._provenance), ret) - - to_funsor.register(torch.Tensor)(tensor_to_funsor) diff --git a/funsor/torch/metadata.py b/funsor/torch/metadata.py deleted file mode 100644 index be33abcdb..000000000 --- a/funsor/torch/metadata.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import torch - - -class MetadataTensor(torch.Tensor): - def __new__(cls, data, metadata=frozenset(), **kwargs): - assert isinstance(metadata, frozenset) - if isinstance(data, torch.Tensor): - t = torch.Tensor._make_subclass(cls, data) - t._metadata = metadata - return t - else: - return data - # breakpoint() - # pass - # if isinstance(data, torch.Size): - # # Is this correct? 
-            return data
-
-    def __repr__(self):
-        return "Metadata:\n{}\ndata:\n{}".format(
-            self._metadata, torch.Tensor._make_subclass(torch.Tensor, self)
-        )
-
-    def __torch_function__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        meta = frozenset()
-        _args = []
-        for arg in args:
-            if isinstance(arg, MetadataTensor):
-                meta |= arg._metadata
-                _args.append(torch.Tensor._make_subclass(torch.Tensor, arg))
-            else:
-                _args.append(arg)
-        ret = func(*_args, **kwargs)
-        return MetadataTensor(ret, metadata=meta)
diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py
index 3e2697b0b..c7541ebdf 100644
--- a/funsor/torch/provenance.py
+++ b/funsor/torch/provenance.py
@@ -6,10 +6,6 @@ class ProvenanceTensor(torch.Tensor):
     def __new__(cls, data, provenance=frozenset(), **kwargs):
-        # assert isinstance(provenance, frozenset)
-        # t = torch.Tensor._make_subclass(cls, data)
-        # t._provenance = provenance
-        # return data
         if not provenance:
             return data
         instance = torch.Tensor.__new__(cls)
@@ -18,7 +14,6 @@ def __new__(cls, data, provenance=frozenset(), **kwargs):
     def __init__(self, data, provenance=frozenset()):
         assert isinstance(provenance, frozenset)
-        # t = torch.Tensor._make_subclass(cls, data)
         if isinstance(data, ProvenanceTensor):
             provenance |= data._provenance
             data = data._t
From 94810771e1de49f427f12ace60442980cccb7775 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Tue, 3 Aug 2021 20:16:21 -0400
Subject: [PATCH 13/22] add Constant docstring

---
 funsor/constant.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/funsor/constant.py b/funsor/constant.py
index 6e51ea840..6fa7c9356 100644
--- a/funsor/constant.py
+++ b/funsor/constant.py
@@ -34,6 +34,27 @@ def __call__(cls, const_inputs, arg):
 
 
 class Constant(Funsor, metaclass=ConstantMeta):
+    """
+    Constant funsor with respect to multiple variables (``const_inputs``).
+
+    This can be used for provenance tracking.
+
+    ``const_inputs`` are ignored (removed) under
+    substitution/reduction/binary operations::
+
+        a = Constant(OrderedDict(x=Real, y=Bint[3]), arg)
+        assert a.reduce(ops.add, "x") is Constant(OrderedDict(y=Bint[3]), arg)
+        assert a(y=1) is Constant(OrderedDict(x=Real), arg)
+
+        c = Normal(0, 1, value="x")
+        assert (a + c) is Constant(OrderedDict(y=Bint[3]), arg + c)
+
+        d = Tensor(torch.tensor([1, 2, 3]))["y"]
+        assert (a + d) is Constant(OrderedDict(x=Real), arg + d)
+
+    :param dict const_inputs: A mapping from input name (str) to datatype (``funsor.domain.Domain``).
+    :param funsor arg: A funsor that is constant with respect to ``const_inputs``.
+ """ def __init__(self, const_inputs, arg): assert isinstance(arg, Funsor) assert isinstance(const_inputs, tuple) @@ -61,6 +82,8 @@ def eager_subs(self, subs): del subs[k] k = v.name const_inputs[k] = d + else: + const_inputs[k] = d if const_inputs: return Constant(const_inputs, self.arg) return self.arg From 00c50514bf845708c6eed5176f26d69741a0cb8b Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 3 Aug 2021 20:19:02 -0400 Subject: [PATCH 14/22] add one more example to docstring --- funsor/constant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/funsor/constant.py b/funsor/constant.py index 6fa7c9356..56fdc73be 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -45,6 +45,7 @@ class Constant(Funsor, metaclass=ConstantMeta): a = Constant(OrderedDict(x=Real, y=Bint[3]), arg) assert a.reduce(ops.add, "x") is Constant(OrderedDict(y=Bint[3]), arg) assert a(y=1) is Constant(OrderedDict(x=Real), arg) + assert a(x=0, y=1) is arg c = Normal(0, 1, value="x") assert (a + c) is Constant(OrderedDict(y=Bint[3]), arg + c) From 4d4d90c62bf24e2b4abdd405518bc370fbd797a0 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 7 Aug 2021 23:44:12 -0400 Subject: [PATCH 15/22] dice factors as importance weights --- funsor/__init__.py | 2 ++ funsor/cnf.py | 21 ++++----------------- funsor/constant.py | 18 +++++++++++------- funsor/delta.py | 15 +++++++-------- funsor/distribution.py | 24 +++++++----------------- funsor/integrate.py | 20 ++++++++++++++++++++ funsor/montecarlo.py | 7 ++----- funsor/sum_product.py | 6 ++++-- funsor/tensor.py | 4 +--- funsor/terms.py | 9 +++------ test/test_distribution.py | 2 +- 11 files changed, 62 insertions(+), 66 deletions(-) diff --git a/funsor/__init__.py b/funsor/__init__.py index 6bb086507..f145d7a02 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -30,6 +30,7 @@ affine, approximations, cnf, + constant, delta, distribution, domains, @@ -74,6 +75,7 @@ "backward", "bint", "cnf", + "constant", "delta", "distribution", "domains", diff --git a/funsor/cnf.py b/funsor/cnf.py index 1c9941d01..d6b216e14 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -102,9 +102,7 @@ def __str__(self): ) return super().__str__() - def unscaled_sample( - self, sampled_vars, sample_inputs, rng_key=None, raw_value=None - ): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): sampled_vars = sampled_vars.intersection(self.inputs) if not sampled_vars: return self @@ -124,7 +122,6 @@ def unscaled_sample( term.unscaled_sample( sampled_vars.intersection(term.inputs), sample_inputs, - raw_value=raw_value, ) for term, rng_key in zip(self.terms, rng_keys) ] @@ -145,19 +142,13 @@ def unscaled_sample( break greedy_terms, terms = [], [] for term in self.terms: - if isinstance(term, funsor.distribution.Distribution): - term_var = {term.value.name} - else: - term_var = term.inputs ( - terms if greedy_vars.isdisjoint(term_var) else greedy_terms + terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms ).append(term) if len(greedy_terms) == 1: term = greedy_terms[0] terms.append( - term.unscaled_sample( - greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value - ) + term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -172,9 +163,7 @@ def unscaled_sample( terms.append(gaussian) terms.append(-gaussian.log_normalizer) terms.append( - term.unscaled_sample( - greedy_vars, sample_inputs, rng_keys[0], raw_value=raw_value - ) + term.unscaled_sample(greedy_vars, 
sample_inputs, rng_keys[0]) ) result = Contraction( self.red_op, self.bin_op, self.reduced_vars, *terms @@ -188,7 +177,6 @@ def unscaled_sample( term.unscaled_sample( greedy_vars.intersection(term.value.inputs), sample_inputs, - raw_value=raw_value, ) for term in greedy_terms if isinstance(term, funsor.distribution.Distribution) @@ -210,7 +198,6 @@ def unscaled_sample( sampled_vars - greedy_vars, sample_inputs, rng_keys[1], - raw_value=raw_value, ) raise TypeError( diff --git a/funsor/constant.py b/funsor/constant.py index 56fdc73be..bfdf0a054 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -3,7 +3,6 @@ from collections import OrderedDict -from funsor.distribution import Distribution from funsor.tensor import Tensor from funsor.terms import ( Binary, @@ -56,11 +55,12 @@ class Constant(Funsor, metaclass=ConstantMeta): :param dict const_inputs: A mapping from input name (str) to datatype (``funsor.domain.Domain``). :param funsor arg: A funsor that is constant wrt to const_inputs. """ + def __init__(self, const_inputs, arg): assert isinstance(arg, Funsor) assert isinstance(const_inputs, tuple) - assert set(const_inputs).isdisjoint(arg.inputs) const_inputs = OrderedDict(const_inputs) + assert set(const_inputs).isdisjoint(arg.inputs) inputs = const_inputs.copy() inputs.update(arg.inputs) output = arg.output @@ -104,17 +104,21 @@ def eager_reduce(self, op, reduced_vars): @eager.register(Binary, BinaryOp, Constant, Constant) def eager_binary_constant_constant(op, lhs, rhs): const_inputs = OrderedDict( - (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.const_inputs + (k, v) + for k, v in lhs.const_inputs.items() + if k not in frozenset(rhs.const_inputs) - frozenset(lhs.const_inputs) ) const_inputs.update( - (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.const_inputs + (k, v) + for k, v in rhs.const_inputs.items() + if k not in frozenset(lhs.const_inputs) - frozenset(rhs.const_inputs) ) if const_inputs: return Constant(const_inputs, op(lhs.arg, rhs.arg)) return op(lhs.arg, rhs.arg) -@eager.register(Binary, BinaryOp, Constant, (Number, Tensor, Distribution)) +@eager.register(Binary, BinaryOp, Constant, (Number, Tensor)) def eager_binary_constant_tensor(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs @@ -124,7 +128,7 @@ def eager_binary_constant_tensor(op, lhs, rhs): return op(lhs.arg, rhs) -@eager.register(Binary, BinaryOp, (Number, Tensor, Distribution), Constant) +@eager.register(Binary, BinaryOp, (Number, Tensor), Constant) def eager_binary_tensor_constant(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs @@ -135,7 +139,7 @@ def eager_binary_tensor_constant(op, lhs, rhs): @eager.register(Unary, UnaryOp, Constant) -def eager_binary_tensor_constant(op, arg): +def eager_unary(op, arg): return Constant(arg.const_inputs, op(arg.arg)) diff --git a/funsor/delta.py b/funsor/delta.py index f028d7d0a..62e35ff48 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -3,6 +3,7 @@ from collections import OrderedDict +import funsor from funsor.domains import Domain, Real from funsor.instrument import debug_logged from funsor.ops import AddOp, SubOp, TransformOp @@ -140,11 +141,7 @@ def eager_subs(self, subs): new_terms[value.name] = new_terms.pop(name) continue - if not any( - d.dtype == "real" - for side in (value, terms[name][0]) - for d in side.inputs.values() - ): + if value.input_vars == terms[name][0].input_vars: point, point_log_density = 
new_terms.pop(name) log_density += (value == point).all().log() + point_log_density continue @@ -200,9 +197,7 @@ def eager_reduce(self, op, reduced_vars): return None # defer to default implementation - def unscaled_sample( - self, sampled_vars, sample_inputs, rng_key=None, raw_value=None - ): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): return self @@ -220,6 +215,8 @@ def eager_add_multidelta(op, lhs, rhs): @eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align)) def eager_add_delta_funsor(op, lhs, rhs): if lhs.fresh.intersection(rhs.inputs): + if isinstance(rhs, funsor.constant.Constant): + return funsor.constant.eager_binary_tensor_constant(op, lhs, rhs) rhs = rhs( **{ name: point @@ -235,6 +232,8 @@ def eager_add_delta_funsor(op, lhs, rhs): @eager.register(Binary, AddOp, (Funsor, Align), Delta) def eager_add_funsor_delta(op, lhs, rhs): if rhs.fresh.intersection(lhs.inputs): + if isinstance(lhs, funsor.constant.Constant): + return funsor.constant.eager_binary_constant_tensor(op, lhs, rhs) lhs = lhs( **{ name: point diff --git a/funsor/distribution.py b/funsor/distribution.py index 96876e343..c82564e13 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -152,11 +152,6 @@ def eager_reduce(self, op, reduced_vars): and isinstance(self.value, Variable) and self.value.name in reduced_vars ): - const_inputs = OrderedDict( - (k, v) for k, v in self.inputs.items() if k not in reduced_vars - ) - if const_inputs: - return funsor.constant.Constant(const_inputs, Number(0.0)) return Number(0.0) # distributions are normalized return super(Distribution, self).eager_reduce(op, reduced_vars) @@ -212,9 +207,7 @@ def eager_log_prob(cls, *params): inputs.update(x.inputs) return log_prob.align(tuple(inputs)) - def unscaled_sample( - self, sampled_vars, sample_inputs, rng_key=None, raw_value=None - ): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): # note this should handle transforms correctly via distribution_to_data raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist() @@ -229,13 +222,10 @@ def unscaled_sample( (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) ) - if raw_value is not None and value_name in raw_value: - raw_value = raw_value[value_name] + if raw_dist.has_rsample: + raw_value = raw_dist.rsample(*sample_args) else: - if raw_dist.has_rsample: - raw_value = raw_dist.rsample(*sample_args) - else: - raw_value = ops.detach(raw_dist.sample(*sample_args)) + raw_value = ops.detach(raw_dist.sample(*sample_args)) funsor_value = to_funsor( raw_value, output=value_output, dim_to_name=dim_to_name @@ -244,8 +234,6 @@ def unscaled_sample( tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs) ) - - result = funsor.delta.Delta(value_name, funsor_value) if not raw_dist.has_rsample: # scaling of dice_factor by num samples should already be handled by Funsor.sample raw_log_prob = raw_dist.log_prob(raw_value) @@ -254,7 +242,9 @@ def unscaled_sample( output=self.output, dim_to_name=dim_to_name, ) - result = result + dice_factor + result = funsor.delta.Delta(value_name, funsor_value, dice_factor) + else: + result = funsor.delta.Delta(value_name, funsor_value) return result def enumerate_support(self, expand=False): diff --git a/funsor/integrate.py b/funsor/integrate.py index f20879a04..677f2eb68 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -4,6 +4,7 @@ from collections import OrderedDict from typing import Union +import funsor import funsor.ops as ops 
from funsor.cnf import Contraction, GaussianMixture from funsor.delta import Delta @@ -102,6 +103,7 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): and t.fresh.intersection(reduced_names, integrand.inputs) ] for delta in delta_terms: + integrand_inputs = integrand.inputs integrand = integrand( **{ name: point @@ -109,6 +111,23 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): if name in reduced_names.intersection(integrand.inputs) } ) + const_inputs = OrderedDict( + { + name: point.output + for name, (point, log_density) in delta.terms + if name in integrand_inputs + } + ) + log_measure = funsor.constant.Constant( + const_inputs, + log_measure( + **{ + name: point + for name, (point, log_density) in delta.terms + if name in integrand_inputs + } + ), + ) return normalize_integrate(log_measure, integrand, reduced_vars) @@ -155,6 +174,7 @@ def eager_integrate(delta, integrand, reduced_vars): delta_fresh = frozenset(Variable(k, delta.inputs[k]) for k in delta.fresh) if reduced_vars.isdisjoint(delta_fresh): return None + breakpoint() reduced_names = frozenset(v.name for v in reduced_vars) subs = tuple( (name, point) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index cb624d394..e0669640e 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -19,10 +19,9 @@ class MonteCarlo(StatefulInterpretation): :param rng_key: """ - def __init__(self, *, rng_key=None, raw_value=None, **sample_inputs): + def __init__(self, *, rng_key=None, **sample_inputs): super().__init__("monte_carlo") self.rng_key = rng_key - self.raw_value = raw_value self.sample_inputs = OrderedDict(sample_inputs) @@ -34,9 +33,7 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) - sample = log_measure.sample( - reduced_vars, state.sample_inputs, raw_value=state.raw_value, **sample_options - ) + sample = log_measure.sample(reduced_vars, state.sample_inputs, **sample_options) if sample is log_measure: return None # cannot progress reduced_vars |= frozenset( diff --git a/funsor/sum_product.py b/funsor/sum_product.py index 3c22775e4..c9c618e6f 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -236,7 +236,8 @@ def partial_sum_product( leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars): - f = reduce(prod_op, group_factors).reduce(sum_op, group_vars) + f = reduce(prod_op, group_factors) + f = f.reduce(sum_op, group_vars) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: results.append(f.reduce(prod_op, leaf & eliminate)) @@ -963,4 +964,5 @@ def eager_markov_product(sum_op, prod_op, trans, time, step, step_names): else: raise NotImplementedError("https://github.com/pyro-ppl/funsor/issues/233") - return Subs(result, step_names) + with funsor.terms.eager: + return Subs(result, step_names) diff --git a/funsor/tensor.py b/funsor/tensor.py index 0b0f4b703..1c5853d35 100644 --- a/funsor/tensor.py +++ b/funsor/tensor.py @@ -329,9 +329,7 @@ def eager_reduce(self, op, reduced_vars): return Tensor(data, inputs, dtype) return super(Tensor, self).eager_reduce(op, reduced_vars) - def unscaled_sample( - self, sampled_vars, sample_inputs, rng_key=None, raw_value=None - ): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): assert self.output == Real sampled_vars = sampled_vars.intersection(self.inputs) if not 
sampled_vars: diff --git a/funsor/terms.py b/funsor/terms.py index f0a58ff0b..1041a56f3 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -420,7 +420,7 @@ def approximate(self, op, guide, approx_vars=None): return self # exact return Approximate(op, self, guide, approx_vars) - def sample(self, sampled_vars, sample_inputs=None, rng_key=None, raw_value=None): + def sample(self, sampled_vars, sample_inputs=None, rng_key=None): """ Create a Monte Carlo approximation to this funsor by replacing functions of ``sampled_vars`` with :class:`~funsor.delta.Delta` s. @@ -457,9 +457,8 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None, raw_value=None) if sampled_vars.isdisjoint(self.inputs): return self - # breakpoint() result = instrument.debug_logged(self.unscaled_sample)( - sampled_vars, sample_inputs, rng_key, raw_value + sampled_vars, sample_inputs, rng_key ) if sample_inputs is not None: log_scale = 0 @@ -470,9 +469,7 @@ def sample(self, sampled_vars, sample_inputs=None, rng_key=None, raw_value=None) result += log_scale return result - def unscaled_sample( - self, sampled_vars, sample_inputs, rng_key=None, raw_value=None - ): + def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): """ Internal method to draw an unscaled sample. This should be overridden by subclasses. diff --git a/test/test_distribution.py b/test/test_distribution.py index 5fa75b463..085933ec2 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1427,7 +1427,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape): name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)} rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0] + data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0] actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( From b1f3488423513a5860d77dc04d3f723131bbc9d7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 7 Aug 2021 23:51:51 -0400 Subject: [PATCH 16/22] revert old changes --- funsor/cnf.py | 10 +++------- funsor/distribution.py | 1 - funsor/integrate.py | 5 ++--- funsor/sum_product.py | 6 ++---- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/funsor/cnf.py b/funsor/cnf.py index d6b216e14..39e82e1eb 100644 --- a/funsor/cnf.py +++ b/funsor/cnf.py @@ -120,8 +120,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): # binary choices symbolic. 
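        # Sample each term separately over only the sampled variables it mentions.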
terms = [ term.unscaled_sample( - sampled_vars.intersection(term.inputs), - sample_inputs, + sampled_vars.intersection(term.inputs), sample_inputs ) for term, rng_key in zip(self.terms, rng_keys) ] @@ -175,8 +174,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ): sampled_terms = [ term.unscaled_sample( - greedy_vars.intersection(term.value.inputs), - sample_inputs, + greedy_vars.intersection(term.value.inputs), sample_inputs ) for term in greedy_terms if isinstance(term, funsor.distribution.Distribution) @@ -195,9 +193,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): ) ) return result.unscaled_sample( - sampled_vars - greedy_vars, - sample_inputs, - rng_keys[1], + sampled_vars - greedy_vars, sample_inputs, rng_keys[1] ) raise TypeError( diff --git a/funsor/distribution.py b/funsor/distribution.py index c82564e13..5d6d0b6c4 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -221,7 +221,6 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): sample_args = ( (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) ) - if raw_dist.has_rsample: raw_value = raw_dist.rsample(*sample_args) else: diff --git a/funsor/integrate.py b/funsor/integrate.py index 677f2eb68..4d4dfd483 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -4,9 +4,9 @@ from collections import OrderedDict from typing import Union -import funsor import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture +from funsor.constant import Constant from funsor.delta import Delta from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian from funsor.interpretations import eager, normalize @@ -118,7 +118,7 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): if name in integrand_inputs } ) - log_measure = funsor.constant.Constant( + log_measure = Constant( const_inputs, log_measure( **{ @@ -174,7 +174,6 @@ def eager_integrate(delta, integrand, reduced_vars): delta_fresh = frozenset(Variable(k, delta.inputs[k]) for k in delta.fresh) if reduced_vars.isdisjoint(delta_fresh): return None - breakpoint() reduced_names = frozenset(v.name for v in reduced_vars) subs = tuple( (name, point) diff --git a/funsor/sum_product.py b/funsor/sum_product.py index c9c618e6f..3c22775e4 100644 --- a/funsor/sum_product.py +++ b/funsor/sum_product.py @@ -236,8 +236,7 @@ def partial_sum_product( leaf_factors = ordinal_to_factors.pop(leaf) leaf_reduce_vars = ordinal_to_vars[leaf] for (group_factors, group_vars) in _partition(leaf_factors, leaf_reduce_vars): - f = reduce(prod_op, group_factors) - f = f.reduce(sum_op, group_vars) + f = reduce(prod_op, group_factors).reduce(sum_op, group_vars) remaining_sum_vars = sum_vars.intersection(f.inputs) if not remaining_sum_vars: results.append(f.reduce(prod_op, leaf & eliminate)) @@ -964,5 +963,4 @@ def eager_markov_product(sum_op, prod_op, trans, time, step, step_names): else: raise NotImplementedError("https://github.com/pyro-ppl/funsor/issues/233") - with funsor.terms.eager: - return Subs(result, step_names) + return Subs(result, step_names) From b7680f9a6d4798c94f5f6300c1535ce0bca8707e Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 17 Aug 2021 16:51:33 -0400 Subject: [PATCH 17/22] fixes --- funsor/__init__.py | 2 ++ funsor/adjoint.py | 60 +++++++++++++++++++------------------- funsor/constant.py | 12 ++------ funsor/delta.py | 17 +++++++---- funsor/torch/provenance.py | 15 ++++++++-- 5 files changed, 59 insertions(+), 47 deletions(-) diff 
--git a/funsor/__init__.py b/funsor/__init__.py index f145d7a02..f0c4496f4 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from funsor.constant import Constant from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals from funsor.factory import make_funsor from funsor.integrate import Integrate @@ -55,6 +56,7 @@ "Array", "Bint", "Cat", + "Constant", "Domain", "Funsor", "Independent", diff --git a/funsor/adjoint.py b/funsor/adjoint.py index 596034120..bdae045af 100644 --- a/funsor/adjoint.py +++ b/funsor/adjoint.py @@ -58,8 +58,8 @@ def interpret(self, cls, *args): ) for arg in args ] - # with self._old_interpretation: - # self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) + with self._old_interpretation: + self._eager_to_lazy[result] = reflect.interpret(cls, *lazy_args) return result def __enter__(self): @@ -84,34 +84,34 @@ def adjoint(self, sum_op, bin_op, root, targets=None): continue # reverse the effects of alpha-renaming - # with reflect: - # - # lazy_output = self._eager_to_lazy[output] - # lazy_fn = type(lazy_output) - # lazy_inputs = lazy_output._ast_values - # # TODO abstract this into a helper function - # # FIXME make lazy_output linear instead of quadratic in the size of the tape - # lazy_other_subs = tuple( - # (name, to_funsor(name.split("__BOUND")[0], domain)) - # for name, domain in lazy_output.inputs.items() - # if "__BOUND" in name - # ) - # lazy_inputs = _alpha_unmangle( - # substitute(lazy_fn(*lazy_inputs), lazy_other_subs) - # ) - # lazy_output = type(lazy_output)( - # *_alpha_unmangle(substitute(lazy_output, lazy_other_subs)) - # ) - # - # other_subs = tuple( - # (name, to_funsor(name.split("__BOUND")[0], domain)) - # for name, domain in output.inputs.items() - # if "__BOUND" in name - # ) - # inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs)) - # output = type(output)(*_alpha_unmangle(substitute(output, other_subs))) - # - # self._eager_to_lazy[output] = lazy_output + with reflect: + + lazy_output = self._eager_to_lazy[output] + lazy_fn = type(lazy_output) + lazy_inputs = lazy_output._ast_values + # TODO abstract this into a helper function + # FIXME make lazy_output linear instead of quadratic in the size of the tape + lazy_other_subs = tuple( + (name, to_funsor(name.split("__BOUND")[0], domain)) + for name, domain in lazy_output.inputs.items() + if "__BOUND" in name + ) + lazy_inputs = _alpha_unmangle( + substitute(lazy_fn(*lazy_inputs), lazy_other_subs) + ) + lazy_output = type(lazy_output)( + *_alpha_unmangle(substitute(lazy_output, lazy_other_subs)) + ) + + other_subs = tuple( + (name, to_funsor(name.split("__BOUND")[0], domain)) + for name, domain in output.inputs.items() + if "__BOUND" in name + ) + inputs = _alpha_unmangle(substitute(fn(*inputs), other_subs)) + output = type(output)(*_alpha_unmangle(substitute(output, other_subs))) + + self._eager_to_lazy[output] = lazy_output in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs) for v, adjv in in_adjs: diff --git a/funsor/constant.py b/funsor/constant.py index bfdf0a054..10a418b15 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -103,16 +103,10 @@ def eager_reduce(self, op, reduced_vars): @eager.register(Binary, BinaryOp, Constant, Constant) def eager_binary_constant_constant(op, lhs, rhs): - const_inputs = OrderedDict( - (k, v) - for k, v in lhs.const_inputs.items() - if k not in frozenset(rhs.const_inputs) - 
frozenset(lhs.const_inputs) - ) - const_inputs.update( - (k, v) - for k, v in rhs.const_inputs.items() - if k not in frozenset(lhs.const_inputs) - frozenset(rhs.const_inputs) + const_vars = ( + (lhs.const_vars | rhs.const_vars) - lhs.arg.input_vars - rhs.arg.input_vars ) + const_inputs = OrderedDict((v.name, v.output) for v in const_vars) if const_inputs: return Constant(const_inputs, op(lhs.arg, rhs.arg)) return op(lhs.arg, rhs.arg) diff --git a/funsor/delta.py b/funsor/delta.py index 62e35ff48..8f93fa904 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -141,7 +141,8 @@ def eager_subs(self, subs): new_terms[value.name] = new_terms.pop(name) continue - if value.input_vars == terms[name][0].input_vars: + var_diff = value.input_vars ^ terms[name][0].input_vars + if not any(d.output.dtype == "real" for d in var_diff): point, point_log_density = new_terms.pop(name) log_density += (value == point).all().log() + point_log_density continue @@ -154,9 +155,13 @@ def eager_subs(self, subs): old_point, old_point_density = new_terms.pop(name) new_terms[new_name] = (new_point, old_point_density + point_log_density) - return ( - Delta(tuple(new_terms.items())) + log_density if new_terms else log_density - ) + if new_terms: + return ( + Delta(tuple(new_terms.items())) + log_density + if log_density is not Number(0) + else Delta(tuple(new_terms.items())) + ) + return log_density def eager_reduce(self, op, reduced_vars): assert reduced_vars.issubset(self.inputs) @@ -215,7 +220,7 @@ def eager_add_multidelta(op, lhs, rhs): @eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align)) def eager_add_delta_funsor(op, lhs, rhs): if lhs.fresh.intersection(rhs.inputs): - if isinstance(rhs, funsor.constant.Constant): + if isinstance(rhs, funsor.Constant): return funsor.constant.eager_binary_tensor_constant(op, lhs, rhs) rhs = rhs( **{ @@ -232,7 +237,7 @@ def eager_add_delta_funsor(op, lhs, rhs): @eager.register(Binary, AddOp, (Funsor, Align), Delta) def eager_add_funsor_delta(op, lhs, rhs): if rhs.fresh.intersection(lhs.inputs): - if isinstance(lhs, funsor.constant.Constant): + if isinstance(lhs, funsor.Constant): return funsor.constant.eager_binary_constant_tensor(op, lhs, rhs) lhs = lhs( **{ diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py index c7541ebdf..56b922acf 100644 --- a/funsor/torch/provenance.py +++ b/funsor/torch/provenance.py @@ -32,15 +32,26 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if isinstance(arg, ProvenanceTensor): provenance |= arg._provenance _args.append(arg._t) + elif isinstance(arg, tuple): + _arg = [] + for a in arg: + if isinstance(a, ProvenanceTensor): + provenance |= a._provenance + _arg.append(a._t) + else: + _arg.append(a) + _args.append(tuple(_arg)) else: _args.append(arg) ret = func(*_args, **kwargs) if isinstance(ret, torch.Tensor): - return ProvenanceTensor(ret, provenance=provenance) + if provenance: + return ProvenanceTensor(ret, provenance=provenance) + return ret if isinstance(ret, tuple): _ret = [] for r in ret: - if isinstance(r, torch.Tensor): + if isinstance(r, torch.Tensor) and provenance: _ret.append(ProvenanceTensor(r, provenance=provenance)) else: _ret.append(r) From 2acb8b37d8eaa4c081c2f2018960bc5dcfbb736c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 17 Aug 2021 17:00:41 -0400 Subject: [PATCH 18/22] revert extra changes --- funsor/__init__.py | 4 - funsor/constant.py | 150 -------------------------------------- funsor/delta.py | 22 ++---- funsor/distribution.py | 6 +- funsor/integrate.py | 19 ----- 
 test/test_distribution.py |   2 +-
 6 files changed, 11 insertions(+), 192 deletions(-)
 delete mode 100644 funsor/constant.py

diff --git a/funsor/__init__.py b/funsor/__init__.py
index f0c4496f4..6bb086507 100644
--- a/funsor/__init__.py
+++ b/funsor/__init__.py
@@ -1,7 +1,6 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
-from funsor.constant import Constant
 from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals
 from funsor.factory import make_funsor
 from funsor.integrate import Integrate
@@ -31,7 +30,6 @@
     affine,
     approximations,
     cnf,
-    constant,
     delta,
     distribution,
     domains,
@@ -56,7 +54,6 @@
     "Array",
     "Bint",
     "Cat",
-    "Constant",
     "Domain",
     "Funsor",
     "Independent",
@@ -77,7 +74,6 @@
     "backward",
     "bint",
     "cnf",
-    "constant",
     "delta",
     "distribution",
     "domains",
diff --git a/funsor/constant.py b/funsor/constant.py
deleted file mode 100644
index 10a418b15..000000000
--- a/funsor/constant.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright Contributors to the Pyro project.
-# SPDX-License-Identifier: Apache-2.0
-
-from collections import OrderedDict
-
-from funsor.tensor import Tensor
-from funsor.terms import (
-    Binary,
-    Funsor,
-    FunsorMeta,
-    Number,
-    Unary,
-    Variable,
-    eager,
-    to_data,
-    to_funsor,
-)
-from funsor.torch.provenance import ProvenanceTensor
-
-from .ops import BinaryOp, UnaryOp
-
-
-class ConstantMeta(FunsorMeta):
-    """
-    Wrapper to convert ``const_inputs`` to a tuple.
-    """
-
-    def __call__(cls, const_inputs, arg):
-        if isinstance(const_inputs, dict):
-            const_inputs = tuple(const_inputs.items())
-
-        return super(ConstantMeta, cls).__call__(const_inputs, arg)
-
-
-class Constant(Funsor, metaclass=ConstantMeta):
-    """
-    Constant funsor with respect to multiple variables (``const_inputs``).
-
-    This can be used for provenance tracking.
-
-    ``const_inputs`` are ignored (removed) under
-    substitution/reduction/binary operations::
-
-        a = Constant(OrderedDict(x=Real, y=Bint[3]), arg)
-        assert a.reduce(ops.add, "x") is Constant(OrderedDict(y=Bint[3]), arg)
-        assert a(y=1) is Constant(OrderedDict(x=Real), arg)
-        assert a(x=0, y=1) is arg
-
-        c = Normal(0, 1, value="x")
-        assert (a + c) is Constant(OrderedDict(y=Bint[3]), arg + c)
-
-        d = Tensor(torch.tensor([1, 2, 3]))["y"]
-        assert (a + d) is Constant(OrderedDict(x=Real), arg + d)
-
-    :param dict const_inputs: A mapping from input name (str) to datatype (``funsor.domain.Domain``).
-    :param funsor arg: A funsor that is constant with respect to ``const_inputs``.
- """ - - def __init__(self, const_inputs, arg): - assert isinstance(arg, Funsor) - assert isinstance(const_inputs, tuple) - const_inputs = OrderedDict(const_inputs) - assert set(const_inputs).isdisjoint(arg.inputs) - inputs = const_inputs.copy() - inputs.update(arg.inputs) - output = arg.output - fresh = frozenset(const_inputs) - bound = {} - super(Constant, self).__init__(inputs, output, fresh, bound) - self.arg = arg - self.const_vars = frozenset(Variable(k, v) for k, v in const_inputs.items()) - self.const_inputs = const_inputs - - def eager_subs(self, subs): - assert isinstance(subs, tuple) - subs = OrderedDict((k, v) for k, v in subs) - const_inputs = OrderedDict() - for k, d in self.const_inputs.items(): - # handle when subs is in self.arg.inputs - if k in subs: - v = subs[k] - if isinstance(v, Variable): - del subs[k] - k = v.name - const_inputs[k] = d - else: - const_inputs[k] = d - if const_inputs: - return Constant(const_inputs, self.arg) - return self.arg - - def eager_reduce(self, op, reduced_vars): - assert reduced_vars.issubset(self.inputs) - const_inputs = OrderedDict( - (k, v) for k, v in self.const_inputs.items() if k not in reduced_vars - ) - reduced_vars = reduced_vars - frozenset(self.const_inputs) - reduced_arg = self.arg.reduce(op, reduced_vars) - if const_inputs: - return Constant(const_inputs, reduced_arg) - return reduced_arg - - -@eager.register(Binary, BinaryOp, Constant, Constant) -def eager_binary_constant_constant(op, lhs, rhs): - const_vars = ( - (lhs.const_vars | rhs.const_vars) - lhs.arg.input_vars - rhs.arg.input_vars - ) - const_inputs = OrderedDict((v.name, v.output) for v in const_vars) - if const_inputs: - return Constant(const_inputs, op(lhs.arg, rhs.arg)) - return op(lhs.arg, rhs.arg) - - -@eager.register(Binary, BinaryOp, Constant, (Number, Tensor)) -def eager_binary_constant_tensor(op, lhs, rhs): - const_inputs = OrderedDict( - (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs - ) - if const_inputs: - return Constant(const_inputs, op(lhs.arg, rhs)) - return op(lhs.arg, rhs) - - -@eager.register(Binary, BinaryOp, (Number, Tensor), Constant) -def eager_binary_tensor_constant(op, lhs, rhs): - const_inputs = OrderedDict( - (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs - ) - if const_inputs: - return Constant(const_inputs, op(lhs, rhs.arg)) - return op(lhs, rhs.arg) - - -@eager.register(Unary, UnaryOp, Constant) -def eager_unary(op, arg): - return Constant(arg.const_inputs, op(arg.arg)) - - -@to_data.register(Constant) -def constant_to_data(x, name_to_dim=None): - data = to_data(x.arg, name_to_dim=name_to_dim) - return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items())) - - -@to_funsor.register(ProvenanceTensor) -def provenance_to_funsor(x, output=None, dim_to_name=None): - if isinstance(x, ProvenanceTensor): - ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) - return Constant(OrderedDict(x._provenance), ret) diff --git a/funsor/delta.py b/funsor/delta.py index 8f93fa904..ddc3d2960 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -3,7 +3,6 @@ from collections import OrderedDict -import funsor from funsor.domains import Domain, Real from funsor.instrument import debug_logged from funsor.ops import AddOp, SubOp, TransformOp @@ -141,8 +140,11 @@ def eager_subs(self, subs): new_terms[value.name] = new_terms.pop(name) continue - var_diff = value.input_vars ^ terms[name][0].input_vars - if not any(d.output.dtype == "real" for d in var_diff): + if not any( + d.dtype == "real" + for side 
in (value, terms[name][0]) + for d in side.inputs.values() + ): point, point_log_density = new_terms.pop(name) log_density += (value == point).all().log() + point_log_density continue @@ -155,13 +157,9 @@ def eager_subs(self, subs): old_point, old_point_density = new_terms.pop(name) new_terms[new_name] = (new_point, old_point_density + point_log_density) - if new_terms: - return ( - Delta(tuple(new_terms.items())) + log_density - if log_density is not Number(0) - else Delta(tuple(new_terms.items())) - ) - return log_density + return ( + Delta(tuple(new_terms.items())) + log_density if new_terms else log_density + ) def eager_reduce(self, op, reduced_vars): assert reduced_vars.issubset(self.inputs) @@ -220,8 +218,6 @@ def eager_add_multidelta(op, lhs, rhs): @eager.register(Binary, (AddOp, SubOp), Delta, (Funsor, Align)) def eager_add_delta_funsor(op, lhs, rhs): if lhs.fresh.intersection(rhs.inputs): - if isinstance(rhs, funsor.Constant): - return funsor.constant.eager_binary_tensor_constant(op, lhs, rhs) rhs = rhs( **{ name: point @@ -237,8 +233,6 @@ def eager_add_delta_funsor(op, lhs, rhs): @eager.register(Binary, AddOp, (Funsor, Align), Delta) def eager_add_funsor_delta(op, lhs, rhs): if rhs.fresh.intersection(lhs.inputs): - if isinstance(lhs, funsor.Constant): - return funsor.constant.eager_binary_constant_tensor(op, lhs, rhs) lhs = lhs( **{ name: point diff --git a/funsor/distribution.py b/funsor/distribution.py index 5d6d0b6c4..6a1e19397 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -12,7 +12,6 @@ import makefun -import funsor import funsor.delta import funsor.ops as ops from funsor.affine import is_affine @@ -233,6 +232,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs) ) + result = funsor.delta.Delta(value_name, funsor_value) if not raw_dist.has_rsample: # scaling of dice_factor by num samples should already be handled by Funsor.sample raw_log_prob = raw_dist.log_prob(raw_value) @@ -241,9 +241,7 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): output=self.output, dim_to_name=dim_to_name, ) - result = funsor.delta.Delta(value_name, funsor_value, dice_factor) - else: - result = funsor.delta.Delta(value_name, funsor_value) + result = result + dice_factor return result def enumerate_support(self, expand=False): diff --git a/funsor/integrate.py b/funsor/integrate.py index 4d4dfd483..f20879a04 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -6,7 +6,6 @@ import funsor.ops as ops from funsor.cnf import Contraction, GaussianMixture -from funsor.constant import Constant from funsor.delta import Delta from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian from funsor.interpretations import eager, normalize @@ -103,7 +102,6 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): and t.fresh.intersection(reduced_names, integrand.inputs) ] for delta in delta_terms: - integrand_inputs = integrand.inputs integrand = integrand( **{ name: point @@ -111,23 +109,6 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars): if name in reduced_names.intersection(integrand.inputs) } ) - const_inputs = OrderedDict( - { - name: point.output - for name, (point, log_density) in delta.terms - if name in integrand_inputs - } - ) - log_measure = Constant( - const_inputs, - log_measure( - **{ - name: point - for name, (point, log_density) in delta.terms - if name in integrand_inputs - } - 
), - ) return normalize_integrate(log_measure, integrand, reduced_vars) diff --git a/test/test_distribution.py b/test/test_distribution.py index 085933ec2..5fa75b463 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -1427,7 +1427,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape): name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)} rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0] + data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0] actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( From a6dc56e6d301924e2ba819c4355c953285fcca47 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 17 Aug 2021 17:23:29 -0400 Subject: [PATCH 19/22] add indexing tests --- test/test_provenance.py | 90 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/test/test_provenance.py b/test/test_provenance.py index 5aa36c134..07619fcbf 100644 --- a/test/test_provenance.py +++ b/test/test_provenance.py @@ -3,10 +3,31 @@ import pytest import torch +from pyro.ops.indexing import Vindex from funsor.torch.provenance import ProvenanceTensor +@pytest.mark.parametrize( + "op", + ["log", "exp", "long"], +) +@pytest.mark.parametrize( + "data,provenance", + [ + (torch.tensor([1]), frozenset({"a", "b"})), + (torch.tensor([1]), frozenset({"a"})), + ], +) +def test_unary(op, data, provenance): + if provenance is not None: + data = ProvenanceTensor(data, provenance) + + expected = provenance + actual = getattr(data, op)()._provenance + assert actual == expected + + @pytest.mark.parametrize( "data1,provenance1", [ @@ -21,7 +42,7 @@ (2, None), ], ) -def test_provenance(data1, provenance1, data2, provenance2): +def test_binary_add(data1, provenance1, data2, provenance2): if provenance1 is not None: data1 = ProvenanceTensor(data1, provenance1) if provenance2 is not None: @@ -30,5 +51,70 @@ def test_provenance(data1, provenance1, data2, provenance2): expected = frozenset.union( *[m for m in (provenance1, provenance2) if m is not None] ) - actual = torch.add(data1, data2)._provenance + actual = (data1 + data2)._provenance + assert actual == expected + + +@pytest.mark.parametrize( + "data1,provenance1", + [ + (torch.tensor([0, 1]), frozenset({"a"})), + (torch.tensor([0, 1]), None), + ], +) +@pytest.mark.parametrize( + "data2,provenance2", + [ + (torch.tensor([0]), frozenset({"b"})), + (torch.tensor([1]), None), + ], +) +def test_indexing(data1, provenance1, data2, provenance2): + if provenance1 is not None: + data1 = ProvenanceTensor(data1, provenance1) + if provenance2 is not None: + data2 = ProvenanceTensor(data2, provenance2) + + expected = frozenset().union( + *[m for m in (provenance1, provenance2) if m is not None] + ) + actual = getattr(data1[data2], "_provenance", frozenset()) + assert actual == expected + + +@pytest.mark.parametrize( + "data1,provenance1", + [ + (torch.tensor([[0, 1], [2, 3]]), frozenset({"a"})), + (torch.tensor([[0, 1], [2, 3]]), None), + ], +) +@pytest.mark.parametrize( + "data2,provenance2", + [ + (torch.tensor([0.0, 1.0]), frozenset({"b"})), + (torch.tensor([0.0, 1.0]), None), + ], +) +@pytest.mark.parametrize( + "data3,provenance3", + [ + (torch.tensor([0, 1]), frozenset({"c"})), + (torch.tensor([0, 1]), None), + ], +) +def test_vindex(data1, 
provenance1, data2, provenance2, data3, provenance3): + if provenance1 is not None: + data1 = ProvenanceTensor(data1, provenance1) + if provenance2 is not None: + data2 = ProvenanceTensor(data2, provenance2) + if provenance3 is not None: + data3 = ProvenanceTensor(data3, provenance3) + + expected = frozenset().union( + *[m for m in (provenance1, provenance2, provenance3) if m is not None] + ) + actual = getattr( + Vindex(data1)[data2.long().unsqueeze(-1), data3], "_provenance", frozenset() + ) assert actual == expected From aceb7d2a16a2fcdb60b759f527197c714eb5cbe6 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 17 Aug 2021 17:50:24 -0400 Subject: [PATCH 20/22] remove Vindex test --- funsor/torch/provenance.py | 29 +++++++++++++++++---------- test/test_provenance.py | 41 +------------------------------------- 2 files changed, 19 insertions(+), 51 deletions(-) diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py index 56b922acf..5b877d77e 100644 --- a/funsor/torch/provenance.py +++ b/funsor/torch/provenance.py @@ -5,6 +5,12 @@ class ProvenanceTensor(torch.Tensor): + """ + Provenance tracking implementation in Pytorch. + + Provenance of the output tensor is the union of provenances of input tensors. + """ + def __new__(cls, data, provenance=frozenset(), **kwargs): if not provenance: return data @@ -26,7 +32,9 @@ def __repr__(self): def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + # collect provenance information provenance = frozenset() + # extract tensor data ._t from ProvenanceTensor args _args = [] for arg in args: if isinstance(arg, ProvenanceTensor): @@ -44,16 +52,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None): else: _args.append(arg) ret = func(*_args, **kwargs) - if isinstance(ret, torch.Tensor): - if provenance: + if provenance: + if isinstance(ret, torch.Tensor): return ProvenanceTensor(ret, provenance=provenance) - return ret - if isinstance(ret, tuple): - _ret = [] - for r in ret: - if isinstance(r, torch.Tensor) and provenance: - _ret.append(ProvenanceTensor(r, provenance=provenance)) - else: - _ret.append(r) - return tuple(_ret) + if isinstance(ret, tuple): + _ret = [] + for r in ret: + if isinstance(r, torch.Tensor): + _ret.append(ProvenanceTensor(r, provenance=provenance)) + else: + _ret.append(r) + return tuple(_ret) return ret diff --git a/test/test_provenance.py b/test/test_provenance.py index 07619fcbf..86f46a7d5 100644 --- a/test/test_provenance.py +++ b/test/test_provenance.py @@ -3,7 +3,6 @@ import pytest import torch -from pyro.ops.indexing import Vindex from funsor.torch.provenance import ProvenanceTensor @@ -48,7 +47,7 @@ def test_binary_add(data1, provenance1, data2, provenance2): if provenance2 is not None: data2 = ProvenanceTensor(data2, provenance2) - expected = frozenset.union( + expected = frozenset().union( *[m for m in (provenance1, provenance2) if m is not None] ) actual = (data1 + data2)._provenance @@ -80,41 +79,3 @@ def test_indexing(data1, provenance1, data2, provenance2): ) actual = getattr(data1[data2], "_provenance", frozenset()) assert actual == expected - - -@pytest.mark.parametrize( - "data1,provenance1", - [ - (torch.tensor([[0, 1], [2, 3]]), frozenset({"a"})), - (torch.tensor([[0, 1], [2, 3]]), None), - ], -) -@pytest.mark.parametrize( - "data2,provenance2", - [ - (torch.tensor([0.0, 1.0]), frozenset({"b"})), - (torch.tensor([0.0, 1.0]), None), - ], -) -@pytest.mark.parametrize( - "data3,provenance3", - [ - (torch.tensor([0, 1]), frozenset({"c"})), - 
From 87eacbecafac542e8f14a042f2db5ffff83e7e2b Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 19 Aug 2021 22:19:59 -0400
Subject: [PATCH 21/22] move to test/torch

---
 Makefile                      |   7 ++-
 test/test_provenance.py       |  81 --------------------------
 test/torch/test_provenance.py | 103 ++++++++++++++++++++++++++++++++++
 3 files changed, 108 insertions(+), 83 deletions(-)
 delete mode 100644 test/test_provenance.py
 create mode 100644 test/torch/test_provenance.py

diff --git a/Makefile b/Makefile
index 76cefda72..d4b862cd4 100644
--- a/Makefile
+++ b/Makefile
@@ -53,13 +53,16 @@ ifeq (${FUNSOR_BACKEND}, torch)
 	python examples/adam.py --num-steps=21
 	@echo PASS
 else ifeq (${FUNSOR_BACKEND}, jax)
-	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
+	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi \
+		--ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py \
+		--ignore=test/torch
 	pytest -v -n auto test/test_distribution.py
 	pytest -v -n auto test/test_distribution_generic.py
 	@echo PASS
 else
 	# default backend
-	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
+	pytest -v -n auto --ignore=test/examples --ignore=test/pyro \
+		--ignore=test/pyroapi --ignore=test/torch
 	@echo PASS
 endif
diff --git a/test/test_provenance.py b/test/test_provenance.py
deleted file mode 100644
index 86f46a7d5..000000000
--- a/test/test_provenance.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Copyright Contributors to the Pyro project.
-# SPDX-License-Identifier: Apache-2.0
-
-import pytest
-import torch
-
-from funsor.torch.provenance import ProvenanceTensor
-
-
-@pytest.mark.parametrize(
-    "op",
-    ["log", "exp", "long"],
-)
-@pytest.mark.parametrize(
-    "data,provenance",
-    [
-        (torch.tensor([1]), frozenset({"a", "b"})),
-        (torch.tensor([1]), frozenset({"a"})),
-    ],
-)
-def test_unary(op, data, provenance):
-    if provenance is not None:
-        data = ProvenanceTensor(data, provenance)
-
-    expected = provenance
-    actual = getattr(data, op)()._provenance
-    assert actual == expected
-
-
-@pytest.mark.parametrize(
-    "data1,provenance1",
-    [
-        (torch.tensor([1]), frozenset({"a"})),
-    ],
-)
-@pytest.mark.parametrize(
-    "data2,provenance2",
-    [
-        (torch.tensor([2]), frozenset({"b"})),
-        (torch.tensor([2]), None),
-        (2, None),
-    ],
-)
-def test_binary_add(data1, provenance1, data2, provenance2):
-    if provenance1 is not None:
-        data1 = ProvenanceTensor(data1, provenance1)
-    if provenance2 is not None:
-        data2 = ProvenanceTensor(data2, provenance2)
-
-    expected = frozenset().union(
-        *[m for m in (provenance1, provenance2) if m is not None]
-    )
-    actual = (data1 + data2)._provenance
-    assert actual == expected
-
-
-@pytest.mark.parametrize(
-    "data1,provenance1",
-    [
-        (torch.tensor([0, 1]), frozenset({"a"})),
-        (torch.tensor([0, 1]), None),
-    ],
-)
-@pytest.mark.parametrize(
-    "data2,provenance2",
-    [
-        (torch.tensor([0]), frozenset({"b"})),
-        (torch.tensor([1]), None),
-    ],
-)
-def test_indexing(data1, provenance1, data2, provenance2):
-    if provenance1 is not None:
-        data1 = ProvenanceTensor(data1, provenance1)
-    if provenance2 is not None:
-        data2 = ProvenanceTensor(data2, provenance2)
-
-    expected = frozenset().union(
-        *[m for m in (provenance1, provenance2) if m is not None]
-    )
-    actual = getattr(data1[data2], "_provenance", frozenset())
-    assert actual == expected
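
A note on the relocated tests that follow: provenance sets are abbreviated to
bare strings such as "ab". This leans only on standard Python semantics:
frozenset iterates a string character by character, string concatenation then
models set union, and the empty string is falsy, which is what the
`if provenance:` guards below rely on. Concretely:

    assert frozenset("ab") == frozenset({"a", "b"})
    assert frozenset("a" + "b") == frozenset({"a"}) | frozenset({"b"})
    assert frozenset("") == frozenset() and not ""
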
diff --git a/test/torch/test_provenance.py b/test/torch/test_provenance.py
new file mode 100644
index 000000000..b97a5380d
--- /dev/null
+++ b/test/torch/test_provenance.py
@@ -0,0 +1,103 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+from pyro.ops.indexing import Vindex
+
+from funsor.torch.provenance import ProvenanceTensor
+
+
+@pytest.mark.parametrize("op", ["log", "exp", "long"])
+@pytest.mark.parametrize(
+    "data,provenance",
+    [
+        (torch.tensor([1]), "ab"),
+        (torch.tensor([1]), "a"),
+    ],
+)
+def test_unary(op, data, provenance):
+    data = ProvenanceTensor(data, frozenset(provenance))
+
+    expected = frozenset(provenance)
+    actual = getattr(data, op)()._provenance
+    assert actual == expected
+
+
+@pytest.mark.parametrize("data1,provenance1", [(torch.tensor([1]), "a")])
+@pytest.mark.parametrize(
+    "data2,provenance2",
+    [
+        (torch.tensor([2]), "b"),
+        (torch.tensor([2]), ""),
+        (2, ""),
+    ],
+)
+def test_binary_add(data1, provenance1, data2, provenance2):
+    data1 = ProvenanceTensor(data1, frozenset(provenance1))
+    if provenance2:
+        data2 = ProvenanceTensor(data2, frozenset(provenance2))
+
+    expected = frozenset(provenance1 + provenance2)
+    actual = torch.add(data1, data2)._provenance
+    assert actual == expected
+
+
+@pytest.mark.parametrize(
+    "data1,provenance1",
+    [
+        (torch.tensor([0, 1]), "a"),
+        (torch.tensor([0, 1]), ""),
+    ],
+)
+@pytest.mark.parametrize(
+    "data2,provenance2",
+    [
+        (torch.tensor([0]), "b"),
+        (torch.tensor([1]), ""),
+    ],
+)
+def test_indexing(data1, provenance1, data2, provenance2):
+    if provenance1:
+        data1 = ProvenanceTensor(data1, frozenset(provenance1))
+    if provenance2:
+        data2 = ProvenanceTensor(data2, frozenset(provenance2))
+
+    expected = frozenset(provenance1 + provenance2)
+    actual = getattr(data1[data2], "_provenance", frozenset())
+    assert actual == expected
+
+
+@pytest.mark.parametrize(
+    "data1,provenance1",
+    [
+        (torch.tensor([[0, 1], [2, 3]]), "a"),
+        (torch.tensor([[0, 1], [2, 3]]), ""),
+    ],
+)
+@pytest.mark.parametrize(
+    "data2,provenance2",
+    [
+        (torch.tensor([0.0, 1.0]), "b"),
+        (torch.tensor([0.0, 1.0]), ""),
+    ],
+)
+@pytest.mark.parametrize(
+    "data3,provenance3",
+    [
+        (torch.tensor([0, 1]), "c"),
+        (torch.tensor([0, 1]), ""),
+    ],
+)
+def test_vindex(data1, provenance1, data2, provenance2, data3, provenance3):
+    if provenance1:
+        data1 = ProvenanceTensor(data1, frozenset(provenance1))
+    if provenance2:
+        data2 = ProvenanceTensor(data2, frozenset(provenance2))
+    if provenance3:
+        data3 = ProvenanceTensor(data3, frozenset(provenance3))
+
+    expected = frozenset(provenance1 + provenance2 + provenance3)
+    result = Vindex(data1)[data2.long().unsqueeze(-1), data3]
+    actual = getattr(result, "_provenance", frozenset())
+    assert actual == expected
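
The restored test_vindex above exercises pyro.ops.indexing.Vindex, Pyro's
helper for broadcasted advanced indexing. For provenance tracking the only
point is that every tracked input to the indexing op contributes to the
output's provenance; a condensed version of the test's fully-tracked case
(the same calls the test makes, nothing new):

    import torch
    from pyro.ops.indexing import Vindex

    from funsor.torch.provenance import ProvenanceTensor

    x = ProvenanceTensor(torch.tensor([[0, 1], [2, 3]]), frozenset("a"))
    i = ProvenanceTensor(torch.tensor([0.0, 1.0]), frozenset("b"))
    j = ProvenanceTensor(torch.tensor([0, 1]), frozenset("c"))
    out = Vindex(x)[i.long().unsqueeze(-1), j]
    assert out._provenance == frozenset("abc")
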
From 447c4892658fc20d178837ae30dd5cf1a87f0105 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 19 Aug 2021 22:42:14 -0400
Subject: [PATCH 22/22] simplify logic

---
 funsor/torch/provenance.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/funsor/torch/provenance.py b/funsor/torch/provenance.py
index 5b877d77e..345c0d0e5 100644
--- a/funsor/torch/provenance.py
+++ b/funsor/torch/provenance.py
@@ -32,9 +32,9 @@ def __repr__(self):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}
-        # collect provenance information
+        # collect provenance information from args
         provenance = frozenset()
-        # extract tensor data ._t from ProvenanceTensor args
+        # extract ProvenanceTensor._t data from args
         _args = []
         for arg in args:
             if isinstance(arg, ProvenanceTensor):
@@ -52,15 +52,14 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
             else:
                 _args.append(arg)
         ret = func(*_args, **kwargs)
-        if provenance:
-            if isinstance(ret, torch.Tensor):
-                return ProvenanceTensor(ret, provenance=provenance)
-            if isinstance(ret, tuple):
-                _ret = []
-                for r in ret:
-                    if isinstance(r, torch.Tensor):
-                        _ret.append(ProvenanceTensor(r, provenance=provenance))
-                    else:
-                        _ret.append(r)
-                return tuple(_ret)
+        if isinstance(ret, torch.Tensor):
+            return ProvenanceTensor(ret, provenance=provenance)
+        if isinstance(ret, tuple):
+            _ret = []
+            for r in ret:
+                if isinstance(r, torch.Tensor):
+                    _ret.append(ProvenanceTensor(r, provenance=provenance))
+                else:
+                    _ret.append(r)
+            return tuple(_ret)
         return ret
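
Why dropping the outer `if provenance:` guard is safe: as the __new__ shown
in the earlier hunks suggests, constructing a ProvenanceTensor with empty
provenance short-circuits and hands back the plain tensor, so unconditional
rewrapping is a no-op when nothing was tracked. A minimal sketch, assuming
the patched module:

    import torch
    from funsor.torch.provenance import ProvenanceTensor

    t = torch.tensor([1.0])
    # __new__ returns `data` unchanged when provenance is empty ...
    wrapped = ProvenanceTensor(t, frozenset())
    assert wrapped is t
    # ... so no ProvenanceTensor is created and no _provenance is attached.
    assert not isinstance(wrapped, ProvenanceTensor)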