-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ProvenanceTensor #543
ProvenanceTensor #543
Conversation
@eb8680 is this along the lines what you were suggesting? Is this new tensor type supposed to be wrapped by |
funsor/torch/metadata.py
Outdated
if kwargs is None: | ||
kwargs = {} | ||
meta = frozenset().union( | ||
*tuple(a._metadata for a in args if hasattr(a, "_metadata")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Provenance of the output tensor is the union of provenances of input tensors.
funsor/constant.py
Outdated
return super(ConstantMeta, cls).__call__(const_inputs, arg) | ||
|
||
|
||
class Constant(Funsor, metaclass=ConstantMeta): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting! It would probably be easiest for us to go over this PR and pyro-ppl/pyro#2893 over Zoom, but one thing that would help me beforehand is if you could add a docstring here explaining how Constant
behaves differently from Delta
wrt Reduce/Contraction/Integrate
def __repr__(self): | ||
return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t) | ||
|
||
def __torch_function__(self, func, types, args=(), kwargs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ordabayevy now that you've had a chance to play around with __torch_function__
, I'm curious about whether you think we should add a Funsor.__torch_function__
method and attempt to use it in Pyro more directly in lieu of the combination of ProvenanceTensor
and to_data
/to_funsor
. I opened #546 to discuss.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implementation seems reasonable, and nicely separated from the rest of the code.
provenance = frozenset() | ||
# extract ProvenanceTensor._t data from args | ||
_args = [] | ||
for arg in args: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic is a bit convoluted. Maybe it could be simplified with some of the helpers in torch.overrides
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the helpers in torch.overrides?
That might be useful, I look more into torch.overrides
functionality.
This is an implementation of Provenance Tracking (https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.361.7132&rep=rep1&type=pdf) in Pytorch. The main idea is that provenance of the output tensor is the union of provenances of input tensors.
Tests: