diff --git a/chirho/counterfactual/handlers/counterfactual.py b/chirho/counterfactual/handlers/counterfactual.py index 0fbec2c25..e60e89e5b 100644 --- a/chirho/counterfactual/handlers/counterfactual.py +++ b/chirho/counterfactual/handlers/counterfactual.py @@ -196,6 +196,50 @@ def _pyro_split(cls, msg: Dict[str, Any]) -> None: msg["kwargs"]["name"] = msg["name"] = cls.default_name +class Splits(Generic[T], pyro.poutine.messenger.Messenger): + """ + Effect handler that applies the operation :func:`~chirho.counterfactual.ops.split` + to sample sites in a probabilistic program, + similar to the handler :func:`~chirho.observational.handlers.condition` + for :func:`~chirho.observational.ops.observe` . + or the handler :func:`~chirho.interventional.handlers.do` + for :func:`~chirho.interventional.ops.intervene` . + + See the documentation for :func:`~chirho.counterfactual.ops.split` for more details. + + :param actions: A mapping from sample site names to interventions. + :param prefix: An optional prefix for naming the auxiliary plates. + """ + + actions: Mapping[str, Intervention[T]] + prefix: str + + def __init__( + self, + actions: Mapping[str, Intervention[T]], + *, + prefix: str = "", + ): + self.actions = actions + self.prefix = prefix + super().__init__() + + def _pyro_post_sample(self, msg): + try: + action = self.actions[msg["name"]] + except KeyError: + return + + action = (action,) if not isinstance(action, tuple) else action + + msg["value"] = split( + msg["value"], + action, + event_dim=len(msg["fn"].event_shape), + name=f"{self.prefix}{msg['name']}", + ) + + class Preemptions(Generic[T], pyro.poutine.messenger.Messenger): """ Effect handler that applies the operation :func:`~chirho.counterfactual.ops.preempt`