-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ProvenanceTensor #543
Merged
ProvenanceTensor #543
Changes from 14 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
ec6fe58
metadatatensor
de39abf
lint
ac14606
Constant funsor
4a480d8
Merge branch 'metadata' into constant
58ae920
subclass torch.Tensor
5eec3e0
save
1b754ca
subclass torch.Tensor
fb5d30a
rename to ProvenanceTensor
43c17fb
merge provenance
b3bb934
working version
cc86ee3
pass second test
59a1baa
clean up constant.py
9c2fcb2
more clean up
9d15024
rm metadata file
9481077
add Constant docstring
00c5051
add one more example to docstring
4d4d90c
dice factors as importance weights
b1f3488
revert old changes
b7680f9
fixes
85917b5
Merge branch 'master' of https://github.com/pyro-ppl/funsor into meta…
2acb8b3
revert extra changes
a6dc56e
add indexing tests
aceb7d2
remove Vindex test
87eacbe
move to test/torch
447c489
simplify logic
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections import OrderedDict | ||
|
||
from funsor.distribution import Distribution | ||
from funsor.tensor import Tensor | ||
from funsor.terms import ( | ||
Binary, | ||
Funsor, | ||
FunsorMeta, | ||
Number, | ||
Unary, | ||
Variable, | ||
eager, | ||
to_data, | ||
to_funsor, | ||
) | ||
from funsor.torch.provenance import ProvenanceTensor | ||
|
||
from .ops import BinaryOp, UnaryOp | ||
|
||
|
||
class ConstantMeta(FunsorMeta): | ||
""" | ||
Wrapper to convert ``const_inputs`` to a tuple. | ||
""" | ||
|
||
def __call__(cls, const_inputs, arg): | ||
if isinstance(const_inputs, dict): | ||
const_inputs = tuple(const_inputs.items()) | ||
|
||
return super(ConstantMeta, cls).__call__(const_inputs, arg) | ||
|
||
|
||
class Constant(Funsor, metaclass=ConstantMeta): | ||
def __init__(self, const_inputs, arg): | ||
assert isinstance(arg, Funsor) | ||
assert isinstance(const_inputs, tuple) | ||
assert set(const_inputs).isdisjoint(arg.inputs) | ||
const_inputs = OrderedDict(const_inputs) | ||
inputs = const_inputs.copy() | ||
inputs.update(arg.inputs) | ||
output = arg.output | ||
fresh = frozenset(const_inputs) | ||
bound = {} | ||
super(Constant, self).__init__(inputs, output, fresh, bound) | ||
self.arg = arg | ||
self.const_vars = frozenset(Variable(k, v) for k, v in const_inputs.items()) | ||
self.const_inputs = const_inputs | ||
|
||
def eager_subs(self, subs): | ||
assert isinstance(subs, tuple) | ||
subs = OrderedDict((k, v) for k, v in subs) | ||
const_inputs = OrderedDict() | ||
for k, d in self.const_inputs.items(): | ||
# handle when subs is in self.arg.inputs | ||
if k in subs: | ||
v = subs[k] | ||
if isinstance(v, Variable): | ||
del subs[k] | ||
k = v.name | ||
const_inputs[k] = d | ||
if const_inputs: | ||
return Constant(const_inputs, self.arg) | ||
return self.arg | ||
|
||
def eager_reduce(self, op, reduced_vars): | ||
assert reduced_vars.issubset(self.inputs) | ||
const_inputs = OrderedDict( | ||
(k, v) for k, v in self.const_inputs.items() if k not in reduced_vars | ||
) | ||
reduced_vars = reduced_vars - frozenset(self.const_inputs) | ||
reduced_arg = self.arg.reduce(op, reduced_vars) | ||
if const_inputs: | ||
return Constant(const_inputs, reduced_arg) | ||
return reduced_arg | ||
|
||
|
||
@eager.register(Binary, BinaryOp, Constant, Constant) | ||
def eager_binary_constant_constant(op, lhs, rhs): | ||
const_inputs = OrderedDict( | ||
(k, v) for k, v in lhs.const_inputs.items() if k not in rhs.const_inputs | ||
) | ||
const_inputs.update( | ||
(k, v) for k, v in rhs.const_inputs.items() if k not in lhs.const_inputs | ||
) | ||
if const_inputs: | ||
return Constant(const_inputs, op(lhs.arg, rhs.arg)) | ||
return op(lhs.arg, rhs.arg) | ||
|
||
|
||
@eager.register(Binary, BinaryOp, Constant, (Number, Tensor, Distribution)) | ||
def eager_binary_constant_tensor(op, lhs, rhs): | ||
const_inputs = OrderedDict( | ||
(k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs | ||
) | ||
if const_inputs: | ||
return Constant(const_inputs, op(lhs.arg, rhs)) | ||
return op(lhs.arg, rhs) | ||
|
||
|
||
@eager.register(Binary, BinaryOp, (Number, Tensor, Distribution), Constant) | ||
def eager_binary_tensor_constant(op, lhs, rhs): | ||
const_inputs = OrderedDict( | ||
(k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs | ||
) | ||
if const_inputs: | ||
return Constant(const_inputs, op(lhs, rhs.arg)) | ||
return op(lhs, rhs.arg) | ||
|
||
|
||
@eager.register(Unary, UnaryOp, Constant) | ||
def eager_binary_tensor_constant(op, arg): | ||
return Constant(arg.const_inputs, op(arg.arg)) | ||
|
||
|
||
@to_data.register(Constant) | ||
def constant_to_data(x, name_to_dim=None): | ||
data = to_data(x.arg, name_to_dim=name_to_dim) | ||
return ProvenanceTensor(data, provenance=frozenset(x.const_inputs.items())) | ||
|
||
|
||
@to_funsor.register(ProvenanceTensor) | ||
def provenance_to_funsor(x, output=None, dim_to_name=None): | ||
if isinstance(x, ProvenanceTensor): | ||
ret = to_funsor(x._t, output=output, dim_to_name=dim_to_name) | ||
return Constant(OrderedDict(x._provenance), ret) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting! It would probably be easiest for us to go over this PR and pyro-ppl/pyro#2893 over Zoom, but one thing that would help me beforehand is if you could add a docstring here explaining how
Constant
behaves differently fromDelta
wrtReduce/Contraction/Integrate