Skip to content

Commit

Permalink
Add support for non-singleton treatment/outcome sets to "has_directed…
Browse files Browse the repository at this point in the history
…_path" method (#1247)

* Modifies has_directed_path and adds test

---------

Signed-off-by: Nicholas Parente <[email protected]>
  • Loading branch information
nparent1 authored Oct 9, 2024
1 parent 2f35b62 commit a5b56fb
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 12 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ venv.bak/
.spyderproject
.spyproject

# PyCharm
.idea/
*.iml
*.iws

# Rope project settings
.ropeproject

Expand Down
11 changes: 6 additions & 5 deletions dowhy/causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import networkx as nx

from dowhy.gcm.causal_models import ProbabilisticCausalModel
from dowhy.graph import has_directed_path
from dowhy.utils.api import parse_state
from dowhy.utils.graph_operations import daggity_to_dot
from dowhy.utils.plotting import plot
Expand Down Expand Up @@ -446,14 +447,14 @@ def get_all_directed_paths(self, nodes1, nodes2):
# convert the outputted generator into a list
return [p for p in nx.all_simple_paths(self._graph, source=node1, target=node2)]

def has_directed_path(self, nodes1, nodes2):
def has_directed_path(self, action_nodes, outcome_nodes):
"""Checks if there is any directed path between two sets of nodes.
Currently only supports singleton sets.
Returns True if and only if every one of the treatments has at least one direct
path to one of the outcomes. And, every one of the outcomes has a direct path from
at least one of the treatments.
"""
# dpaths = self.get_all_directed_paths(nodes1, nodes2)
# return len(dpaths) > 0
return nx.has_path(self._graph, nodes1[0], nodes2[0])
return has_directed_path(self._graph, action_nodes, outcome_nodes)

def get_adjacency_matrix(self, *args, **kwargs):
"""
Expand Down
16 changes: 11 additions & 5 deletions dowhy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,20 @@ def get_all_directed_paths(graph: nx.DiGraph, nodes1, nodes2):
return [p for p in nx.all_simple_paths(graph, source=nodes1[0], target=nodes2[0])]


def has_directed_path(graph: nx.DiGraph, nodes1, nodes2):
def has_directed_path(graph: nx.DiGraph, action_nodes, outcome_nodes):
"""Checks if there is any directed path between two sets of nodes.
Currently only supports singleton sets.
Returns True if and only if every one of the treatments has at least one direct
path to one of the outcomes. And, every one of the outcomes has a direct path from
at least one of the treatments.
"""
# dpaths = self.get_all_directed_paths(nodes1, nodes2)
# return len(dpaths) > 0
return nx.has_path(graph, nodes1[0], nodes2[0])
outcome_node_candidates = set()
action_node_candidates = set()
for node in action_nodes:
outcome_node_candidates.update(nx.descendants(graph, node))
for node in outcome_nodes:
action_node_candidates.update(nx.ancestors(graph, node))
return set(outcome_nodes).issubset(outcome_node_candidates) and set(action_nodes).issubset(action_node_candidates)


def check_valid_mediation_set(graph: nx.DiGraph, nodes1, nodes2, candidate_nodes, mediation_paths=None):
Expand Down
17 changes: 17 additions & 0 deletions tests/causal_identifiers/test_auto_identifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dowhy.causal_identifier import AutoIdentifier
from dowhy.causal_identifier.identify_effect import EstimandType
from dowhy.graph import build_graph_from_str


class TestAutoIdentification(object):
def test_auto_identify_identifies_no_directed_path(self):
# Test added for issue #1250
graph = build_graph_from_str("digraph{T->Y;A->Y;A->B;}")
identifier = AutoIdentifier(estimand_type=EstimandType.NONPARAMETRIC_ATE)

assert identifier.identify_effect(
graph, action_nodes=["T", "B"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"]
).no_directed_path
assert identifier.identify_effect(
graph, action_nodes=["B", "T"], outcome_nodes=["Y"], observed_nodes=["T", "Y", "A", "B"]
).no_directed_path
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def test_fail_multivar_outcome_efficient_backdoor_algorithms():
ident_eff.identify_effect(
build_graph_from_str(example["graph_str"]),
observed_nodes=example["observed_node_names"],
action_nodes=["X"],
outcome_nodes=["Y", "R"],
action_nodes=["U"],
outcome_nodes=["Y", "F"],
conditional_node_names=example["conditional_node_names"],
)
6 changes: 6 additions & 0 deletions tests/test_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,9 @@ def test_build_graph(self):
def test_build_graph_from_str(self):
build_graph_from_str(self.daggity_file)
build_graph_from_str(self.graph_str)

def test_has_path(self):
assert has_directed_path(self.nx_graph, ["X0"], ["y"])
assert has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0"])
assert not has_directed_path(self.nx_graph, [], ["y"])
assert not has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0", "Z0"])

0 comments on commit a5b56fb

Please sign in to comment.