-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ProvenanceTensor #543
ProvenanceTensor #543
Changes from all commits
ec6fe58
de39abf
ac14606
4a480d8
58ae920
5eec3e0
1b754ca
fb5d30a
43c17fb
b3bb934
cc86ee3
59a1baa
9c2fcb2
9d15024
9481077
00c5051
4d4d90c
b1f3488
b7680f9
85917b5
2acb8b3
a6dc56e
aceb7d2
87eacbe
447c489
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
|
||
class ProvenanceTensor(torch.Tensor): | ||
""" | ||
Provenance tracking implementation in Pytorch. | ||
|
||
Provenance of the output tensor is the union of provenances of input tensors. | ||
""" | ||
|
||
def __new__(cls, data, provenance=frozenset(), **kwargs): | ||
if not provenance: | ||
return data | ||
instance = torch.Tensor.__new__(cls) | ||
instance.__init__(data, provenance) | ||
return instance | ||
|
||
def __init__(self, data, provenance=frozenset()): | ||
assert isinstance(provenance, frozenset) | ||
if isinstance(data, ProvenanceTensor): | ||
provenance |= data._provenance | ||
data = data._t | ||
self._t = data | ||
self._provenance = provenance | ||
|
||
def __repr__(self): | ||
return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t) | ||
|
||
def __torch_function__(self, func, types, args=(), kwargs=None): | ||
if kwargs is None: | ||
kwargs = {} | ||
# collect provenance information from args | ||
provenance = frozenset() | ||
# extract ProvenanceTensor._t data from args | ||
_args = [] | ||
for arg in args: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic is a bit convoluted. Maybe it could be simplified with some of the helpers in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That might be useful, I look more into |
||
if isinstance(arg, ProvenanceTensor): | ||
provenance |= arg._provenance | ||
_args.append(arg._t) | ||
elif isinstance(arg, tuple): | ||
_arg = [] | ||
for a in arg: | ||
if isinstance(a, ProvenanceTensor): | ||
provenance |= a._provenance | ||
_arg.append(a._t) | ||
else: | ||
_arg.append(a) | ||
_args.append(tuple(_arg)) | ||
else: | ||
_args.append(arg) | ||
ret = func(*_args, **kwargs) | ||
if isinstance(ret, torch.Tensor): | ||
return ProvenanceTensor(ret, provenance=provenance) | ||
if isinstance(ret, tuple): | ||
_ret = [] | ||
for r in ret: | ||
if isinstance(r, torch.Tensor): | ||
_ret.append(ProvenanceTensor(r, provenance=provenance)) | ||
else: | ||
_ret.append(r) | ||
return tuple(_ret) | ||
return ret |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
from pyro.ops.indexing import Vindex | ||
|
||
from funsor.torch.provenance import ProvenanceTensor | ||
|
||
|
||
@pytest.mark.parametrize("op", ["log", "exp", "long"]) | ||
@pytest.mark.parametrize( | ||
"data,provenance", | ||
[ | ||
(torch.tensor([1]), "ab"), | ||
(torch.tensor([1]), "a"), | ||
], | ||
) | ||
def test_unary(op, data, provenance): | ||
data = ProvenanceTensor(data, frozenset(provenance)) | ||
|
||
expected = frozenset(provenance) | ||
actual = getattr(data, op)()._provenance | ||
assert actual == expected | ||
|
||
|
||
@pytest.mark.parametrize("data1,provenance1", [(torch.tensor([1]), "a")]) | ||
@pytest.mark.parametrize( | ||
"data2,provenance2", | ||
[ | ||
(torch.tensor([2]), "b"), | ||
(torch.tensor([2]), ""), | ||
(2, ""), | ||
], | ||
) | ||
def test_binary_add(data1, provenance1, data2, provenance2): | ||
data1 = ProvenanceTensor(data1, frozenset(provenance1)) | ||
if provenance2: | ||
data2 = ProvenanceTensor(data2, frozenset(provenance2)) | ||
|
||
expected = frozenset(provenance1 + provenance2) | ||
actual = torch.add(data1, data2)._provenance | ||
assert actual == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data1,provenance1", | ||
[ | ||
(torch.tensor([0, 1]), "a"), | ||
(torch.tensor([0, 1]), ""), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"data2,provenance2", | ||
[ | ||
(torch.tensor([0]), "b"), | ||
(torch.tensor([1]), ""), | ||
], | ||
) | ||
def test_indexing(data1, provenance1, data2, provenance2): | ||
if provenance1: | ||
data1 = ProvenanceTensor(data1, frozenset(provenance1)) | ||
if provenance2: | ||
data2 = ProvenanceTensor(data2, frozenset(provenance2)) | ||
|
||
expected = frozenset(provenance1 + provenance2) | ||
actual = getattr(data1[data2], "_provenance", frozenset()) | ||
assert actual == expected | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data1,provenance1", | ||
[ | ||
(torch.tensor([[0, 1], [2, 3]]), "a"), | ||
(torch.tensor([[0, 1], [2, 3]]), ""), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"data2,provenance2", | ||
[ | ||
(torch.tensor([0.0, 1.0]), "b"), | ||
(torch.tensor([0.0, 1.0]), ""), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"data3,provenance3", | ||
[ | ||
(torch.tensor([0, 1]), "c"), | ||
(torch.tensor([0, 1]), ""), | ||
], | ||
) | ||
def test_vindex(data1, provenance1, data2, provenance2, data3, provenance3): | ||
if provenance1: | ||
data1 = ProvenanceTensor(data1, frozenset(provenance1)) | ||
if provenance2: | ||
data2 = ProvenanceTensor(data2, frozenset(provenance2)) | ||
if provenance3: | ||
data3 = ProvenanceTensor(data3, frozenset(provenance3)) | ||
|
||
expected = frozenset(provenance1 + provenance2 + provenance3) | ||
result = Vindex(data1)[data2.long().unsqueeze(-1), data3] | ||
actual = getattr(result, "_provenance", frozenset()) | ||
assert actual == expected |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ordabayevy now that you've had a chance to play around with
__torch_function__
, I'm curious about whether you think we should add aFunsor.__torch_function__
method and attempt to use it in Pyro more directly in lieu of the combination ofProvenanceTensor
andto_data
/to_funsor
. I opened #546 to discuss.