Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce sampling dependency for auxiliary nodes in SearchForExplanation #578

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions chirho/explainable/handlers/explanation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import warnings
from typing import Callable, Mapping, Optional, TypeVar, Union

import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

Expand All @@ -27,6 +29,7 @@ def SplitSubsets(
*,
bias: float = 0.0,
prefix: str = "__cause_split_",
cases: Optional[Mapping[str, torch.Tensor]] = None,
):
"""
A context manager used for a stochastic search of minimal but-for causes among potential interventions.
Expand All @@ -39,14 +42,15 @@ def SplitSubsets(
:param actions: A mapping of sites to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param prefix: A prefix used for naming additional preemption nodes. Defaults to ``__cause_split_``.
:param cases: A mapping of sites to their preemption cases (for possible coordination between sites).
"""
preemptions = {
antecedent: undo_split(supports[antecedent], antecedents=[antecedent])
for antecedent in actions.keys()
}

with do(actions=actions):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix):
with Preemptions(actions=preemptions, bias=bias, prefix=prefix, cases=cases):
yield


Expand All @@ -66,13 +70,15 @@ def SearchForExplanation(
antecedent_bias: float = 0.0,
witness_bias: float = 0.0,
prefix: str = "__cause__",
num_samples: Optional[int] = None,
sampling_dim: Optional[int] = None,
):
"""
A handler for transforming causal explanation queries into probabilistic inferences.

When used as a context manager, ``SearchForExplanation`` yields a dictionary of observations
that can be used with ``condition`` to simultaneously impose an additional factivity constraint
alongside the necessity and sufficiency constraints implemented by ``SearchForExplanation`` ::
that can be used with ``condition`` to impose an additional factivity constraint
alongside the necessity and sufficiency constraints implemented by ``SearchForExplanation``::

with SearchForExplanation(supports, antecedents, consequents, ...) as evidence:
with condition(data=evidence):
Expand All @@ -85,13 +91,18 @@ def SearchForExplanation(
:param alternatives: An optional mapping of names to alternative antecedent interventions.
:param factors: An optional mapping of names to consequent constraint factors.
:param preemptions: An optional mapping of names to witness preemption values.
:param antecedent_bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0.
:param consequent_scale: The scale of the consequent factor functions, defaults to 1e-2.
:param witness_bias: The scalar bias towards not preempting. Must be between -0.5 and 0.5, defaults to 0.0.
:param antecedent_bias: A scalar bias towards not intervening. Must be between -0.5 and 0.5. Defaults to 0.0.
:param consequent_scale: The scale of the consequent factor functions. Defaults to 1e-2.
:param witness_bias: A scalar bias towards not preempting. Must be between -0.5 and 0.5. Defaults to 0.0.
:param prefix: A prefix used for naming additional consequent nodes. Defaults to ``__consequent_``.
:param num_samples: The number of samples to be drawn for each antecedent and witness. Needed if witness and antecedent samples are to be coordinated.
:param sampling_dim: The dimension along which the antecedent and witness nodes will be sampled, to be kept consistent with case sampling.

:note: If ``num_samples`` is not provided, the antecedent and witness nodes will be sampled independently.

:return: A context manager that can be used to query the evidence.
"""

########################################
# Validate input arguments
########################################
Expand All @@ -105,6 +116,10 @@ def SearchForExplanation(
assert not set(witnesses.keys()) & set(consequents.keys())
else:
# if witness candidates are not provided, use all non-consequent nodes
warnings.warn(
"Witness candidates were not provided. Using all non-consequent nodes.",
UserWarning,
)
witnesses = {w: None for w in set(supports.keys()) - set(consequents.keys())}

##################################################################
Expand All @@ -131,12 +146,43 @@ def SearchForExplanation(
for a in antecedents.keys()
}

if num_samples is not None:
if sampling_dim is None:
raise ValueError("sampling_dim must be provided if num_samples is provided")

case_shape = [1] * torch.abs(torch.tensor(sampling_dim))
case_shape[sampling_dim] = num_samples

antecedent_probs = torch.tensor(
[0.5 - antecedent_bias] + ([(0.5 + antecedent_bias)])
)

antecedent_case_dist = dist.Categorical(probs=antecedent_probs)

antecedent_cases = {
key: antecedent_case_dist.sample(case_shape) for key in antecedents.keys()
}

witness_probs = torch.tensor([0.5 - witness_bias] + ([(0.5 + witness_bias)]))

witness_case_dist = dist.Categorical(probs=witness_probs)

witness_cases = {
key: witness_case_dist.sample(case_shape) for key in witnesses.keys()
}

witness_cases = {
key: value * antecedent_cases[key] if key in antecedent_cases else value
for key, value in witness_cases.items()
}

# interventions on subsets of antecedents
antecedent_handler = SplitSubsets(
{a: supports[a] for a in antecedents.keys()},
{a: (alternatives[a], sufficiency_actions[a]) for a in antecedents.keys()}, # type: ignore
bias=antecedent_bias,
prefix=f"{prefix}__antecedent_",
cases=antecedent_cases if num_samples is not None else None,
)

# defaults for witness_preemptions
Expand All @@ -151,6 +197,7 @@ def SearchForExplanation(
),
bias=witness_bias,
prefix=f"{prefix}__witness_",
cases=witness_cases if num_samples is not None else None,
)

#
Expand Down
15 changes: 13 additions & 2 deletions chirho/explainable/handlers/preemptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Mapping, TypeVar
from typing import Generic, Mapping, Optional, TypeVar

import pyro
import torch
Expand Down Expand Up @@ -38,9 +38,14 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger):
and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``,
where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1).

In tasks where a site is both a potential antecedent and a potential witness,
sampling needs to be coordinated so that the site is neither antecedent-preempted nor witness-preempted. In such
cases, the ``cases`` argument can be used to pass coordinated preemption case tensors for the sites.

:param actions: A mapping from sample site names to interventions.
:param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5.
:param prefix: The prefix for naming the auxiliary discrete random variables.
:param cases: A mapping from sample site names to their preemption cases (for possible coordination between sites).
"""

actions: Mapping[str, Intervention[T]]
Expand All @@ -53,11 +58,13 @@ def __init__(
*,
prefix: str = "__witness_split_",
bias: float = 0.0,
cases: Optional[Mapping[str, torch.Tensor]] = None,
):
assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5"
self.actions = actions
self.bias = bias
self.prefix = prefix
self.cases = cases
super().__init__()

def _pyro_post_sample(self, msg):
Expand All @@ -73,7 +80,11 @@ def _pyro_post_sample(self, msg):
device=msg["value"].device,
)
case_dist = pyro.distributions.Categorical(probs=weights)
case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist)
case = pyro.sample(
f"{self.prefix}{msg['name']}",
case_dist,
obs=self.cases[msg["name"]] if self.cases is not None else None,
)

msg["value"] = preempt(
msg["value"],
Expand Down
73 changes: 67 additions & 6 deletions tests/explainable/test_handlers_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def stones_bayesian_model():
}


def test_SearchForExplanation():
@pytest.fixture
def test_search_setup():
supports = {
"sally_throws": constraints.boolean,
"bill_throws": constraints.boolean,
Expand All @@ -74,12 +75,9 @@ def test_SearchForExplanation():
}

antecedents = {"sally_throws": torch.tensor(1.0)}

consequents = {"bottle_shatters": torch.tensor(1.0)}

witnesses = {
"bill_throws": None,
}
witnesses = {"bill_throws": None}
wide_witness = {"sally_throws": torch.tensor(1.0), "bill_throws": None}

observation_keys = [
"prob_sally_throws",
Expand All @@ -97,6 +95,27 @@ def test_SearchForExplanation():

alternatives = {"sally_throws": 0.0}

return {
"supports": supports,
"antecedents": antecedents,
"consequents": consequents,
"witnesses": witnesses,
"wide_witness": wide_witness,
"observations_conditioning": observations_conditioning,
"alternatives": alternatives,
}


def test_SearchForExplanation(test_search_setup):

supports = test_search_setup["supports"]
antecedents = test_search_setup["antecedents"]
consequents = test_search_setup["consequents"]
witnesses = test_search_setup["witnesses"]
observations_conditioning = test_search_setup["observations_conditioning"]
alternatives = test_search_setup["alternatives"]
observations_conditioning = test_search_setup["observations_conditioning"]

with MultiWorldCounterfactual() as mwc:
with SearchForExplanation(
supports=supports,
Expand Down Expand Up @@ -191,6 +210,48 @@ def test_SearchForExplanation():
assert suff_log_probs[step] <= -10


def test_dependent_sampling(test_search_setup):
    """Check that antecedent and witness preemption cases are coordinated.

    ``sally_throws`` is deliberately placed in *both* the antecedent and the
    witness candidate sets (via the ``wide_witness`` fixture entry). With
    ``num_samples``/``sampling_dim`` supplied, ``SearchForExplanation`` is
    expected to sample the preemption cases jointly so that a site is never
    antecedent-preempted and witness-preempted at the same time.
    """
    supports = test_search_setup["supports"]
    antecedents = test_search_setup["antecedents"]
    consequents = test_search_setup["consequents"]
    # This time we make sure `sally_throws` is in both candidate sets.
    witnesses = test_search_setup["wide_witness"]
    alternatives = test_search_setup["alternatives"]
    observations_conditioning = test_search_setup["observations_conditioning"]

    with MultiWorldCounterfactual():
        with SearchForExplanation(
            supports=supports,
            antecedents=antecedents,
            consequents=consequents,
            witnesses=witnesses,
            alternatives=alternatives,
            antecedent_bias=0.1,
            consequent_scale=1e-8,
            num_samples=100,
            sampling_dim=-1,
        ):
            with observations_conditioning:
                with pyro.plate("sample", 100):
                    with pyro.poutine.trace() as tr:
                        stones_bayesian_model()

    tr.trace.compute_log_prob()
    tr = tr.trace.nodes

    sally_antecedent_preemption = tr["__cause____antecedent_sally_throws"]["value"]
    sally_witness_preemption = tr["__cause____witness_sally_throws"]["value"]

    # Whenever the antecedent case is 0 (antecedent intervention active),
    # the witness case must also be 0 — i.e. the site is never both
    # antecedent-preempted and witness-preempted in the same sample.
    assert torch.all(
        torch.where(
            sally_antecedent_preemption == 0, sally_witness_preemption == 0, True
        )
    )


def test_SplitSubsets_single_layer():
observations = {
"prob_sally_throws": 1.0,
Expand Down
Loading