Importance funsor #578

Draft: wants to merge 8 commits into base: master
8 changes: 8 additions & 0 deletions docs/source/funsors.rst
@@ -64,3 +64,11 @@ Constant
     :undoc-members:
     :show-inheritance:
     :member-order: bysource
+
+Importance
+----------
+.. automodule:: funsor.importance
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
2 changes: 2 additions & 0 deletions funsor/__init__.py
@@ -4,6 +4,7 @@
 from funsor.constant import Constant
 from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals
 from funsor.factory import make_funsor
+from funsor.importance import Importance
 from funsor.integrate import Integrate
 from funsor.interpreter import interpretation, reinterpret
 from funsor.op_factory import make_op
@@ -61,6 +62,7 @@
"Constant",
"Domain",
"Funsor",
"Importance",
"Independent",
"Integrate",
"Lambda",
5 changes: 3 additions & 2 deletions funsor/constant.py
@@ -5,6 +5,7 @@
 from functools import reduce
 
 import funsor.ops as ops
+from funsor.importance import Importance
 from funsor.tensor import Tensor
 from funsor.terms import (
     Binary,
@@ -165,7 +166,7 @@ def eager_binary_constant_constant(op, lhs, rhs):
     return op(lhs.arg, rhs.arg)
 
 
-@eager.register(Binary, ops.BinaryOp, Constant, (Number, Tensor))
+@eager.register(Binary, ops.BinaryOp, Constant, (Importance, Number, Tensor))
 def eager_binary_constant_tensor(op, lhs, rhs):
     const_inputs = OrderedDict(
         (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs
@@ -175,7 +176,7 @@ def eager_binary_constant_tensor(op, lhs, rhs):
     return op(lhs.arg, rhs)
 
 
-@eager.register(Binary, ops.BinaryOp, (Number, Tensor), Constant)
+@eager.register(Binary, ops.BinaryOp, (Importance, Number, Tensor), Constant)
 def eager_binary_tensor_constant(op, lhs, rhs):
     const_inputs = OrderedDict(
         (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs
7 changes: 2 additions & 5 deletions funsor/delta.py
@@ -145,11 +145,8 @@ def eager_subs(self, subs):
                 new_terms.append((value.name, (point, log_density)))
                 continue
 
-            if not any(
-                d.dtype == "real"
-                for side in (value, point)
-                for d in side.inputs.values()
-            ):
+            var_diff = value.input_vars ^ point.input_vars
+            if not any(d.output.dtype == "real" for d in var_diff):
                 dtype = get_default_dtype()
                 is_equal = ops.astype((value == point).all(), dtype)
                 log_densities.append(is_equal.log() + log_density)
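Note on the eager_subs change above (an illustrative sketch, not part of the diff):
the old predicate rejected the fast equality branch whenever either side had any
real-typed input, whereas the new one inspects only the symmetric difference of
input variables, so real inputs shared by both value and point no longer block it:

import funsor
from funsor.domains import Real
from funsor.terms import Variable

funsor.set_backend("torch")

value = Variable("x", Real) + 1.0
point = Variable("x", Real) + 2.0

# "x" appears on both sides, so it cancels in the symmetric difference
# and no real-typed variable remains to block the equality branch.
var_diff = value.input_vars ^ point.input_vars
assert not any(d.output.dtype == "real" for d in var_diff)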
9 changes: 6 additions & 3 deletions funsor/distribution.py
@@ -235,12 +235,15 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
         if not raw_dist.has_rsample:
             # scaling of dice_factor by num samples should already be handled by Funsor.sample
             raw_log_prob = raw_dist.log_prob(raw_value)
-            dice_factor = to_funsor(
-                raw_log_prob - ops.detach(raw_log_prob),
+            log_prob = to_funsor(
+                raw_log_prob,
                 output=self.output,
                 dim_to_name=dim_to_name,
             )
-            result = funsor.delta.Delta(value_name, funsor_value, dice_factor)
+            model = funsor.delta.Delta(value_name, funsor_value, log_prob)
+            guide = funsor.delta.Delta(value_name, funsor_value, ops.detach(log_prob))
+            sampled_var = frozenset({Variable(value_name, self.inputs[value_name])})
+            result = funsor.Importance(model, guide, sampled_var)
         else:
             result = funsor.delta.Delta(value_name, funsor_value)
         return result
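Background on the non-reparameterized branch (a sketch, not from the diff): pairing
a log_prob Delta (the model) with a detached copy (the guide) recovers the old dice
factor through the eager Importance rule below, since the log importance weight
model - guide works out to log_prob - detach(log_prob): identically zero in value,
but with the score function as its gradient. In plain PyTorch terms:

import torch

logits = torch.tensor([0.2, 0.8], requires_grad=True)
dist = torch.distributions.Categorical(logits=logits)
value = dist.sample()

log_prob = dist.log_prob(value)
dice_factor = log_prob - log_prob.detach()

# The log weight is exactly zero, so the sample weight exp(0) = 1 ...
assert dice_factor.item() == 0.0

# ... but its gradient is the score function d log_prob / d logits.
dice_factor.backward()
expected = torch.autograd.grad(dist.log_prob(value), logits)[0]
assert torch.allclose(logits.grad, expected)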
57 changes: 57 additions & 0 deletions funsor/importance.py
@@ -0,0 +1,57 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from funsor.delta import Delta
+from funsor.interpretations import DispatchedInterpretation
+from funsor.terms import Funsor, eager, reflect
+
+
+class Importance(Funsor):
+    """
+    Importance sampling for approximating integrals with respect to a set of variables.
+
+    When the proposal distribution (guide) is a ``Delta``, the eager
+    interpretation is ``Delta + log_importance_weight``.
+    The user-facing interface is the :meth:`Funsor.approximate` method.
+
+    :param Funsor model: A funsor depending on ``sampled_vars``.
+    :param Funsor guide: A proposal distribution.
+    :param frozenset sampled_vars: A set of input variables to sample.
+    """
+
+    def __init__(self, model, guide, sampled_vars):
+        assert isinstance(model, Funsor)
+        assert isinstance(guide, Funsor)
+        assert isinstance(sampled_vars, frozenset), sampled_vars
+        inputs = model.inputs.copy()
+        inputs.update(guide.inputs)
+        output = model.output
+        super().__init__(inputs, output)
+        self.model = model
+        self.guide = guide
+        self.sampled_vars = sampled_vars
+
+    def eager_reduce(self, op, reduced_vars):
+        assert reduced_vars.issubset(self.inputs)
+        if not reduced_vars:
+            return self
+
+        return self.model.reduce(op, reduced_vars)
+
+
+@eager.register(Importance, Funsor, Delta, frozenset)
+def eager_importance(model, guide, sampled_vars):
+    # Delta + log_importance_weight
+    return guide + model - guide
+
+
+lazy_importance = DispatchedInterpretation("lazy_importance")
+"""
+Lazy interpretation of ``Importance`` with a ``Delta`` guide.
+"""
+
+
+@lazy_importance.register(Importance, Funsor, Delta, frozenset)
+def _lazy_importance(model, guide, sampled_vars):
+    return reflect.interpret(Importance, model, guide, sampled_vars)
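A usage sketch (hypothetical, not part of the PR; the toy log density and names
are assumptions): constructing an Importance term with a Delta guide, eagerly and
under the lazy_importance interpretation:

import torch

import funsor
from funsor import Importance, Real
from funsor.delta import Delta
from funsor.importance import lazy_importance
from funsor.tensor import Tensor
from funsor.terms import Variable

funsor.set_backend("torch")

x = Variable("x", Real)
model = -((x - 1.0) ** 2)                      # toy unnormalized log density
guide = Delta("x", Tensor(torch.tensor(0.5)))  # proposal: point mass at x = 0.5

# Eagerly, the registered rule rewrites this to guide + model - guide,
# i.e. the Delta plus a log importance weight.
approx = Importance(model, guide, frozenset({x}))

# Under lazy_importance the term is reflected and left unevaluated.
with lazy_importance:
    lazy = Importance(model, guide, frozenset({x}))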
2 changes: 1 addition & 1 deletion funsor/integrate.py
@@ -183,7 +183,7 @@ def eager_integrate(delta, integrand, reduced_vars):
         if name in reduced_names
     )
     new_integrand = Subs(integrand, subs)
-    new_log_measure = Subs(delta, subs)
+    new_log_measure = delta.reduce(ops.logaddexp, reduced_names)
     result = Integrate(new_log_measure, new_integrand, reduced_vars - delta_fresh)
     return result

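For intuition (a worked identity, not from the diff): for a Delta measure over x
with point a and log density l,

    Integrate(Delta(x, a, l), f(x), {x}) = exp(l) * f(a)

Subs(integrand, subs) supplies the f(a) factor, while the new
delta.reduce(ops.logaddexp, reduced_names) computes the Delta's total log mass l
directly rather than substituting the Delta at its own point.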
2 changes: 1 addition & 1 deletion test/test_distribution.py
@@ -1426,7 +1426,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape):

     name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)}
     rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
-    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0]
+    data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0]
 
     actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim)
     expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob(