Add Funsor (Constant) that is constant wrt to input variables #547

Closed
ordabayevy opened this issue Aug 25, 2021 · 1 comment
Labels
discussion enhancement New feature or request

Comments

ordabayevy (Member) commented Aug 25, 2021

Sometimes it is useful to keep track of certain Funsor inputs even though the underlying Funsor is constant with respect to those inputs. Examples are the Zero terms in Delta.eager_reduce and in contrib.funsor.TraceEnum_ELBO. These are currently implemented as expanded Tensor terms, which is inefficient in both memory and computation and only supports the Bint input type. The proposed Constant funsor has the following benefits:

  1. Declare constant inputs of a Funsor in a memory-efficient way: e.g. Constant(cost.inputs, Number(0))
  2. Allow variables with Real output type: e.g. Constant(OrderedDict(x=Bint[3], y=Real), Number(0))
  3. Define computationally efficient patterns for subs/reduce/unary/binary operations
  4. Convert to and from ProvenanceTensor (ProvenanceTensor #543), where the provenance of the tensor corresponds to const_inputs. Edit: I don’t think this is correct; the counterpart of ProvenanceTensor should be something similar to Delta.
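To make benefits 1 and 2 concrete, here is a toy comparison in plain Python. It assumes (hypothetically) that an expanded representation stores one entry per joint assignment of the inputs, and it encodes domains ad hoc: an int `n` stands for `Bint[n]` and the string `"real"` stands for `Real`. The helpers `expand` and `constant` are illustrative only and are not part of funsor.

```python
import itertools
from typing import Dict, Tuple, Union

# Hypothetical domain encoding for this sketch only:
# an int n stands for Bint[n], and the string "real" stands for Real.
Domain = Union[int, str]


def expand(value: float, inputs: Dict[str, Domain]) -> Dict[Tuple[int, ...], float]:
    """Store `value` once per joint assignment of the inputs, like a materialized expanded tensor."""
    sizes = []
    for name, domain in inputs.items():
        if not isinstance(domain, int):
            # A real-valued input has no finite set of values to enumerate.
            raise TypeError(f"cannot expand over real-valued input {name!r}")
        sizes.append(domain)
    return {idx: value for idx in itertools.product(*(range(n) for n in sizes))}


def constant(value: float, inputs: Dict[str, Domain]) -> Tuple[Dict[str, Domain], float]:
    """Store `value` once and keep the inputs only as metadata, in the spirit of Constant."""
    return dict(inputs), value


# Usage: 3 * 1000 stored entries versus a single stored value, and a Real input is accepted.
dense = expand(0.0, {"x": 3, "z": 1000})
meta, value = constant(0.0, {"x": 3, "y": "real"})
```

The expanded form grows with the product of the Bint sizes and cannot represent a Real input at all, while the constant form stays the same size regardless of its inputs.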
```python
# Substitution

# 0(x, y)
a = Constant(OrderedDict(x=Bint[3], y=Real), Number(0))
# 0(x=0, y) = 0(y)
assert a(x=0) is Constant(OrderedDict(y=Real), Number(0))
# 0(x=0, y=3.14) = 0
assert a(x=0, y=3.14) is Number(0)

# Binary ops

# 1(x, y)
a = Constant(OrderedDict(x=Bint[3], y=Real), Number(1))
# b(x)
b = Tensor(...)["x"]
# 1(x, y) * b(x) = (1 * b(x))(y) = b(x)(y)
c = a * b  # returns Constant(OrderedDict(y=Real), b*Number(1))

# Reduction

# 1(x, y)
a = Constant(OrderedDict(x=Bint[3], y=Real), Number(1))
# 1(x=0, y) + 1(x=1, y) + 1(x=2, y) = 3(y)
a.reduce(ops.add, "x")  # returns Constant(OrderedDict(y=Real), Number(3))
```
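The substitution and reduction rules above can be mimicked with a small pure-Python model. This is a sketch under stated assumptions, not funsor's implementation: `ToyBint`, `ToyReal` and `ToyConstant` are hypothetical stand-ins, the value is restricted to a plain float, and results are compared with `==` (the proposal uses `is` because funsor terms are cons-hashed).

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class ToyBint:
    """Hypothetical stand-in for Bint[size]."""
    size: int


@dataclass(frozen=True)
class ToyReal:
    """Hypothetical stand-in for Real."""


Domain = Union[ToyBint, ToyReal]


@dataclass(frozen=True)
class ToyConstant:
    """A scalar `value` that is constant wrt the named inputs in `const_inputs`."""
    const_inputs: Tuple[Tuple[str, Domain], ...]
    value: float

    def _drop(self, names):
        # Keep the remaining inputs in their original order.
        return tuple((k, d) for k, d in self.const_inputs if k not in names)

    def subs(self, **assignments):
        # Substituting a constant input never changes the value; it only removes that input.
        remaining = self._drop(assignments)
        return ToyConstant(remaining, self.value) if remaining else self.value

    def reduce_add(self, name):
        # Summing a constant over Bint[n] multiplies it by n, with no per-element work.
        domain = dict(self.const_inputs)[name]
        if not isinstance(domain, ToyBint):
            raise TypeError(f"sum over non-discrete input {name!r} is not defined in this sketch")
        remaining = self._drop({name})
        total = self.value * domain.size
        return ToyConstant(remaining, total) if remaining else total


# Usage mirroring the examples above.
a = ToyConstant((("x", ToyBint(3)), ("y", ToyReal())), 0.0)
one = ToyConstant((("x", ToyBint(3)), ("y", ToyReal())), 1.0)
```

Both operations run in time proportional to the number of inputs, not the size of their domains, which is the kind of efficient pattern benefit 3 refers to.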

The arg in Constant(const_inputs, arg) does not have to be a Number. It can be any Funsor, provided that const_inputs and arg.inputs are disjoint.
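A minimal sketch of that disjointness condition, and of how the binary-op example keeps it intact, assuming a hypothetical `ToyTerm` that stands in for an arbitrary Funsor by carrying only an expression string and its set of free inputs. `ToyConstant` and `multiply` are illustrative names, not funsor API.

```python
from dataclasses import dataclass
from typing import FrozenSet, Union


@dataclass(frozen=True)
class ToyTerm:
    """Hypothetical stand-in for an arbitrary Funsor: a printable expression plus its free inputs."""
    expr: str
    inputs: FrozenSet[str]


@dataclass(frozen=True)
class ToyConstant:
    """`arg` lifted to be constant wrt the names in `const_inputs`."""
    const_inputs: FrozenSet[str]
    arg: ToyTerm

    def __post_init__(self):
        # The condition from the proposal: const_inputs and arg.inputs must not overlap.
        overlap = self.const_inputs & self.arg.inputs
        if overlap:
            raise ValueError(f"const_inputs and arg.inputs must be disjoint, got {sorted(overlap)}")


def multiply(c: ToyConstant, other: ToyTerm) -> Union[ToyConstant, ToyTerm]:
    """Multiply without expanding: inputs that `other` depends on move from const_inputs into the arg."""
    new_arg = ToyTerm(f"({c.arg.expr} * {other.expr})", c.arg.inputs | other.inputs)
    remaining = c.const_inputs - other.inputs
    # With no constant inputs left, the wrapper is unnecessary and the product is just the arg.
    return ToyConstant(remaining, new_arg) if remaining else new_arg


# Usage: 1(x, y) * b(x) keeps only y as a constant input.
one = ToyConstant(frozenset({"x", "y"}), ToyTerm("1", frozenset()))
b = ToyTerm("b", frozenset({"x"}))
c = multiply(one, b)
```

Because the new const_inputs are the old ones minus `other.inputs`, and the old const_inputs were already disjoint from the old arg inputs, the product always satisfies the invariant again.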

The main motivation is to use Constant as a wrapper for ProvenanceTensor and as targets for log_measures in a general version of contrib.funsor.Trace_ELBO (pyro-ppl/pyro#2893).

ordabayevy added the enhancement and discussion labels on Aug 25, 2021
eb8680 (Member) commented Sep 23, 2021

Added in #548

eb8680 closed this as completed on Sep 23, 2021