From 8468c5e161a091398fb06fe8f5b10976e966baab Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 11:39:31 -0500 Subject: [PATCH 1/6] added cases, refactored SearchForExplanationTest --- chirho/explainable/handlers/explanation.py | 46 ++++++++++++- chirho/explainable/handlers/preemptions.py | 10 ++- .../explainable/test_handlers_explanation.py | 64 +++++++++++++++++-- 3 files changed, 110 insertions(+), 10 deletions(-) diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py index b23e39680..ef56e3791 100644 --- a/chirho/explainable/handlers/explanation.py +++ b/chirho/explainable/handlers/explanation.py @@ -1,6 +1,8 @@ import contextlib +import warnings from typing import Callable, Mapping, Optional, TypeVar, Union +import pyro.distributions as dist import pyro.distributions.constraints as constraints import torch @@ -27,6 +29,7 @@ def SplitSubsets( *, bias: float = 0.0, prefix: str = "__cause_split_", + cases: Optional[Mapping[str, torch.Tensor]] = None, ): """ A context manager used for a stochastic search of minimal but-for causes among potential interventions. @@ -46,7 +49,7 @@ def SplitSubsets( } with do(actions=actions): - with Preemptions(actions=preemptions, bias=bias, prefix=prefix): + with Preemptions(actions=preemptions, bias=bias, prefix=prefix, cases=cases): yield @@ -66,6 +69,8 @@ def SearchForExplanation( antecedent_bias: float = 0.0, witness_bias: float = 0.0, prefix: str = "__cause__", + num_samples: Optional[int] = None, + sampling_dim: Optional[int] = None, ): """ A handler for transforming causal explanation queries into probabilistic inferences. @@ -105,6 +110,10 @@ def SearchForExplanation( assert not set(witnesses.keys()) & set(consequents.keys()) else: # if witness candidates are not provided, use all non-consequent nodes + warnings.warn( + "Witness candidates were not provided. Using all non-consequent nodes.", + UserWarning, + ) witnesses = {w: None for w in set(supports.keys()) - set(consequents.keys())} ################################################################## @@ -131,12 +140,46 @@ def SearchForExplanation( for a in antecedents.keys() } + if num_samples is not None: + if sampling_dim is None: + raise ValueError("sampling_dim must be provided if num_samples is provided") + + case_shape = [1] * torch.abs(torch.tensor(sampling_dim)) + case_shape[sampling_dim] = num_samples + + antecedent_probs = torch.tensor( + [0.5 - antecedent_bias] + ([(0.5 + antecedent_bias)]) + ) + + antecedent_case_dist = dist.Categorical(probs=antecedent_probs) + + antecedent_cases = { + key: antecedent_case_dist.sample(case_shape) for key in antecedents.keys() + } + + # for key in antecedents.keys(): + # antecedent_cases[key] = antecedent_case_dist.sample(case_shape) + + witness_probs = torch.tensor([0.5 - witness_bias] + ([(0.5 + witness_bias)])) + + witness_case_dist = dist.Categorical(probs=witness_probs) + + witness_cases = { + key: witness_case_dist.sample(case_shape) for key in witnesses.keys() + } + + witness_cases = { + key: value * antecedent_cases[key] if key in antecedent_cases else value + for key, value in witness_cases.items() + } + # interventions on subsets of antecedents antecedent_handler = SplitSubsets( {a: supports[a] for a in antecedents.keys()}, {a: (alternatives[a], sufficiency_actions[a]) for a in antecedents.keys()}, # type: ignore bias=antecedent_bias, prefix=f"{prefix}__antecedent_", + cases=antecedent_cases if num_samples is not None else None, ) # defaults for witness_preemptions @@ -151,6 +194,7 @@ def SearchForExplanation( ), bias=witness_bias, prefix=f"{prefix}__witness_", + cases=witness_cases if num_samples is not None else None, ) # diff --git a/chirho/explainable/handlers/preemptions.py b/chirho/explainable/handlers/preemptions.py index 4268d730f..94cce70c6 100644 --- a/chirho/explainable/handlers/preemptions.py +++ b/chirho/explainable/handlers/preemptions.py @@ -1,4 +1,4 @@ -from typing import Generic, Mapping, TypeVar +from typing import Generic, Mapping, Optional, TypeVar import pyro import torch @@ -53,11 +53,13 @@ def __init__( *, prefix: str = "__witness_split_", bias: float = 0.0, + cases: Optional[Mapping[str, torch.Tensor]] = None, ): assert -0.5 <= bias <= 0.5, "bias must be between -0.5 and 0.5" self.actions = actions self.bias = bias self.prefix = prefix + self.cases = cases super().__init__() def _pyro_post_sample(self, msg): @@ -73,7 +75,11 @@ def _pyro_post_sample(self, msg): device=msg["value"].device, ) case_dist = pyro.distributions.Categorical(probs=weights) - case = pyro.sample(f"{self.prefix}{msg['name']}", case_dist) + case = pyro.sample( + f"{self.prefix}{msg['name']}", + case_dist, + obs=self.cases[msg["name"]] if self.cases is not None else None, + ) msg["value"] = preempt( msg["value"], diff --git a/tests/explainable/test_handlers_explanation.py b/tests/explainable/test_handlers_explanation.py index 3e48cac88..414cac027 100644 --- a/tests/explainable/test_handlers_explanation.py +++ b/tests/explainable/test_handlers_explanation.py @@ -63,8 +63,8 @@ def stones_bayesian_model(): "bottle_shatters": bottle_shatters, } - -def test_SearchForExplanation(): +@pytest.fixture +def test_search_setup(): supports = { "sally_throws": constraints.boolean, "bill_throws": constraints.boolean, @@ -74,12 +74,8 @@ def test_SearchForExplanation(): } antecedents = {"sally_throws": torch.tensor(1.0)} - consequents = {"bottle_shatters": torch.tensor(1.0)} - - witnesses = { - "bill_throws": None, - } + witnesses = {"bill_throws": None} observation_keys = [ "prob_sally_throws", @@ -97,6 +93,60 @@ def test_SearchForExplanation(): alternatives = {"sally_throws": 0.0} + return { + "supports": supports, + "antecedents": antecedents, + "consequents": consequents, + "witnesses": witnesses, + "observations_conditioning": observations_conditioning, + "alternatives": alternatives, + } + + +def test_SearchForExplanation(test_search_setup): + # supports = { + # "sally_throws": constraints.boolean, + # "bill_throws": constraints.boolean, + # "sally_hits": constraints.boolean, + # "bill_hits": constraints.boolean, + # "bottle_shatters": constraints.boolean, + # } + + # antecedents = {"sally_throws": torch.tensor(1.0)} + + # consequents = {"bottle_shatters": torch.tensor(1.0)} + + # witnesses = { + # "bill_throws": None, + # } + + # observation_keys = [ + # "prob_sally_throws", + # "prob_bill_throws", + # "prob_sally_hits", + # "prob_bill_hits", + # "prob_bottle_shatters_if_sally", + # "prob_bottle_shatters_if_bill", + # ] + # observations = {k: torch.tensor(1.0) for k in observation_keys} + + # observations_conditioning = condition( + # data={k: torch.as_tensor(v) for k, v in observations.items()} + # ) + + # alternatives = {"sally_throws": 0.0} + + + supports = test_search_setup["supports"] + antecedents = test_search_setup["antecedents"] + consequents = test_search_setup["consequents"] + witnesses = test_search_setup["witnesses"] + observations_conditioning = test_search_setup["observations_conditioning"] + alternatives = test_search_setup["alternatives"] + observations_conditioning = test_search_setup["observations_conditioning"] + + + with MultiWorldCounterfactual() as mwc: with SearchForExplanation( supports=supports, From b0d74eb8119ad54bbf3629972888dd0d9eb03859 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 12:32:30 -0500 Subject: [PATCH 2/6] added test for dependent sampling --- .../explainable/test_handlers_explanation.py | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/tests/explainable/test_handlers_explanation.py b/tests/explainable/test_handlers_explanation.py index 414cac027..52d27c6f1 100644 --- a/tests/explainable/test_handlers_explanation.py +++ b/tests/explainable/test_handlers_explanation.py @@ -63,6 +63,7 @@ def stones_bayesian_model(): "bottle_shatters": bottle_shatters, } + @pytest.fixture def test_search_setup(): supports = { @@ -76,6 +77,7 @@ def test_search_setup(): antecedents = {"sally_throws": torch.tensor(1.0)} consequents = {"bottle_shatters": torch.tensor(1.0)} witnesses = {"bill_throws": None} + wide_witness = {"sally_throws": torch.tensor(1.0), "bill_throws": None} observation_keys = [ "prob_sally_throws", @@ -98,44 +100,13 @@ def test_search_setup(): "antecedents": antecedents, "consequents": consequents, "witnesses": witnesses, + "wide_witness": wide_witness, "observations_conditioning": observations_conditioning, "alternatives": alternatives, } def test_SearchForExplanation(test_search_setup): - # supports = { - # "sally_throws": constraints.boolean, - # "bill_throws": constraints.boolean, - # "sally_hits": constraints.boolean, - # "bill_hits": constraints.boolean, - # "bottle_shatters": constraints.boolean, - # } - - # antecedents = {"sally_throws": torch.tensor(1.0)} - - # consequents = {"bottle_shatters": torch.tensor(1.0)} - - # witnesses = { - # "bill_throws": None, - # } - - # observation_keys = [ - # "prob_sally_throws", - # "prob_bill_throws", - # "prob_sally_hits", - # "prob_bill_hits", - # "prob_bottle_shatters_if_sally", - # "prob_bottle_shatters_if_bill", - # ] - # observations = {k: torch.tensor(1.0) for k in observation_keys} - - # observations_conditioning = condition( - # data={k: torch.as_tensor(v) for k, v in observations.items()} - # ) - - # alternatives = {"sally_throws": 0.0} - supports = test_search_setup["supports"] antecedents = test_search_setup["antecedents"] @@ -145,8 +116,6 @@ def test_SearchForExplanation(test_search_setup): alternatives = test_search_setup["alternatives"] observations_conditioning = test_search_setup["observations_conditioning"] - - with MultiWorldCounterfactual() as mwc: with SearchForExplanation( supports=supports, @@ -241,6 +210,48 @@ def test_SearchForExplanation(test_search_setup): assert suff_log_probs[step] <= -10 +def test_dependent_sampling(test_search_setup): + + supports = test_search_setup["supports"] + antecedents = test_search_setup["antecedents"] + consequents = test_search_setup["consequents"] + witnesses = test_search_setup[ + "wide_witness" + ] # this time we make sure `sally_throws` is in both. + observations_conditioning = test_search_setup["observations_conditioning"] + alternatives = test_search_setup["alternatives"] + observations_conditioning = test_search_setup["observations_conditioning"] + + with MultiWorldCounterfactual() as mwc: + with SearchForExplanation( + supports=supports, + antecedents=antecedents, + consequents=consequents, + witnesses=witnesses, + alternatives=alternatives, + antecedent_bias=0.1, + consequent_scale=1e-8, + num_samples=100, + sampling_dim=-1, + ): + with observations_conditioning: + with pyro.plate("sample", 100): + with pyro.poutine.trace() as tr: + stones_bayesian_model() + + tr.trace.compute_log_prob() + tr = tr.trace.nodes + + sally_antecedent_preemption = tr["__cause____antecedent_sally_throws"]["value"] + sally_witness_preemption = tr["__cause____witness_sally_throws"]["value"] + + assert torch.all( + torch.where( + sally_antecedent_preemption == 0, sally_witness_preemption == 0, True + ) + ) + + def test_SplitSubsets_single_layer(): observations = { "prob_sally_throws": 1.0, From b60b3f2aee9f1454b0373e60b5b9cf6f8fa7b987 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 12:49:52 -0500 Subject: [PATCH 3/6] update SearchForExplanation docstring --- chirho/explainable/handlers/explanation.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py index ef56e3791..c8052183c 100644 --- a/chirho/explainable/handlers/explanation.py +++ b/chirho/explainable/handlers/explanation.py @@ -76,8 +76,8 @@ def SearchForExplanation( A handler for transforming causal explanation queries into probabilistic inferences. When used as a context manager, ``SearchForExplanation`` yields a dictionary of observations - that can be used with ``condition`` to simultaneously impose an additional factivity constraint - alongside the necessity and sufficiency constraints implemented by ``SearchForExplanation`` :: + that can be used with ``condition`` to impose an additional factivity constraint + alongside the necessity and sufficiency constraints implemented by ``SearchForExplanation``:: with SearchForExplanation(supports, antecedents, consequents, ...) as evidence: with condition(data=evidence): @@ -90,13 +90,18 @@ def SearchForExplanation( :param alternatives: An optional mapping of names to alternative antecedent interventions. :param factors: An optional mapping of names to consequent constraint factors. :param preemptions: An optional mapping of names to witness preemption values. - :param antecedent_bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. - :param consequent_scale: The scale of the consequent factor functions, defaults to 1e-2. - :param witness_bias: The scalar bias towards not preempting. Must be between -0.5 and 0.5, defaults to 0.0. + :param antecedent_bias: A scalar bias towards not intervening. Must be between -0.5 and 0.5. Defaults to 0.0. + :param consequent_scale: The scale of the consequent factor functions. Defaults to 1e-2. + :param witness_bias: A scalar bias towards not preempting. Must be between -0.5 and 0.5. Defaults to 0.0. :param prefix: A prefix used for naming additional consequent nodes. Defaults to ``__consequent_``. + :param num_samples: The number of samples to be drawn for each antecedent and witness. Needed if witness and antecedent samples are to be coordinated. + :param sampling_dim: The dimension along which the antecedent and witness nodes will be sampled, to be kept consistent with case sampling. + + :note: If ``num_samples`` is not provided, the antecedent and witness nodes will be sampled independently. :return: A context manager that can be used to query the evidence. """ + ######################################## # Validate input arguments ######################################## @@ -157,8 +162,6 @@ def SearchForExplanation( key: antecedent_case_dist.sample(case_shape) for key in antecedents.keys() } - # for key in antecedents.keys(): - # antecedent_cases[key] = antecedent_case_dist.sample(case_shape) witness_probs = torch.tensor([0.5 - witness_bias] + ([(0.5 + witness_bias)])) From 317ed293ee2d9b2071829014ec4f1a66a45e395f Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 12:51:08 -0500 Subject: [PATCH 4/6] update SplitSubset docstring --- chirho/explainable/handlers/explanation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py index c8052183c..2edbd3122 100644 --- a/chirho/explainable/handlers/explanation.py +++ b/chirho/explainable/handlers/explanation.py @@ -42,6 +42,7 @@ def SplitSubsets( :param actions: A mapping of sites to interventions. :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. :param prefix: A prefix used for naming additional preemption nodes. Defaults to ``__cause_split_``. + :param cases: A mapping of sites to their preemption cases (for possible coordination between sites). """ preemptions = { antecedent: undo_split(supports[antecedent], antecedents=[antecedent]) From 9acf46579079b1a8b4e9490bd134500d81150c04 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 12:54:24 -0500 Subject: [PATCH 5/6] update Preemptions docstring --- chirho/explainable/handlers/preemptions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chirho/explainable/handlers/preemptions.py b/chirho/explainable/handlers/preemptions.py index 94cce70c6..da170b21e 100644 --- a/chirho/explainable/handlers/preemptions.py +++ b/chirho/explainable/handlers/preemptions.py @@ -38,9 +38,14 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger): and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``, where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1). + In tasks where a site is both a potential antecedent and a potential witness, + sampling needs to be coordinated so that the site is neither antecedent-preempted nor witness-preempted. In such + cases, the ``cases`` argument can be used to pass coordinated preemption case tensors for the sites. + :param actions: A mapping from sample site names to interventions. :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5. :param prefix: The prefix for naming the auxiliary discrete random variables. + :param cases: A mapping from sample site names to their preemption cases (for possible coordination between sites). """ actions: Mapping[str, Intervention[T]] From b68360c51514b9133cda925bba96374d7e94c954 Mon Sep 17 00:00:00 2001 From: rfl-urbaniak Date: Fri, 21 Feb 2025 14:00:00 -0500 Subject: [PATCH 6/6] format lint --- chirho/explainable/handlers/explanation.py | 3 +-- chirho/explainable/handlers/preemptions.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/chirho/explainable/handlers/explanation.py b/chirho/explainable/handlers/explanation.py index 2edbd3122..bcaf0590d 100644 --- a/chirho/explainable/handlers/explanation.py +++ b/chirho/explainable/handlers/explanation.py @@ -42,7 +42,7 @@ def SplitSubsets( :param actions: A mapping of sites to interventions. :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5, defaults to 0.0. :param prefix: A prefix used for naming additional preemption nodes. Defaults to ``__cause_split_``. - :param cases: A mapping of sites to their preemption cases (for possible coordination between sites). + :param cases: A mapping of sites to their preemption cases (for possible coordination between sites). """ preemptions = { antecedent: undo_split(supports[antecedent], antecedents=[antecedent]) @@ -163,7 +163,6 @@ def SearchForExplanation( key: antecedent_case_dist.sample(case_shape) for key in antecedents.keys() } - witness_probs = torch.tensor([0.5 - witness_bias] + ([(0.5 + witness_bias)])) witness_case_dist = dist.Categorical(probs=witness_probs) diff --git a/chirho/explainable/handlers/preemptions.py b/chirho/explainable/handlers/preemptions.py index da170b21e..53810c346 100644 --- a/chirho/explainable/handlers/preemptions.py +++ b/chirho/explainable/handlers/preemptions.py @@ -38,9 +38,9 @@ class Preemptions(Generic[T], pyro.poutine.messenger.Messenger): and the probability of each counterfactual case is ``(0.5 + bias) / num_actions``, where ``num_actions`` is the number of counterfactual actions for the sample site (usually 1). - In tasks where a site is both a potential antecedent and a potential witness, - sampling needs to be coordinated so that the site is neither antecedent-preempted nor witness-preempted. In such - cases, the ``cases`` argument can be used to pass coordinated preemption case tensors for the sites. + In tasks where a site is both a potential antecedent and a potential witness, + sampling needs to be coordinated so that the site is neither antecedent-preempted nor witness-preempted. In such + cases, the ``cases`` argument can be used to pass coordinated preemption case tensors for the sites. :param actions: A mapping from sample site names to interventions. :param bias: The scalar bias towards not intervening. Must be between -0.5 and 0.5.