Skip to content

Commit

Permalink
Merge pull request #5 from facusapienza21/name_bug_fix
Browse files Browse the repository at this point in the history
Bug regarding name handling
  • Loading branch information
esmucler authored Dec 3, 2022
2 parents e370635 + 9001d2b commit d8f9133
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
optimaladj 0.0.4
==============

* Fixed a bug reported by Sara Taheri, whereby due to mishandling of node names we could end up including forbidden variables in an adjustment set.
4 changes: 2 additions & 2 deletions optimaladj/CausalGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def forbidden(self, treatment, outcome):
forbidden = set()

for node in self.causal_vertices(treatment, outcome):
forbidden = forbidden.union(nx.descendants(self, node).union(node))
forbidden = forbidden.union(nx.descendants(self, node).union({node}))

return forbidden.union(treatment)
return forbidden.union({treatment})

def ignore(self, treatment, outcome, L, N):
"""Returns the set of ignorable vertices with respect to treatment, outcome,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="optimaladj",
version="0.0.3",
version="0.0.4",
author="Facundo Sapienza, Ezequiel Smucler",
author_email="[email protected]",
description="A package to compute optimal adjustment sets in causal graphs",
Expand Down
64 changes: 64 additions & 0 deletions tests/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def __init__(self, G, treatment, outcome, L, N):
G_9.add_nodes_from(costs_9)
G_9.add_edges_from(
[
("A", "Y"),
("K", "A"),
("W1", "K"),
("W2", "K"),
Expand All @@ -362,3 +363,66 @@ def __init__(self, G, treatment, outcome, L, N):
OPTIMALS_MINIMAL.append(set(["W7", "W8", "W9"]))
OPTIMALS_MINIMUM.append(set(["K"]))
OPTIMALS_MINCOST.append(set(["W7", "W8", "W6"]))

# Regression test for bug on name handling, spotted by Sara Taheri

G_10 = CausalGraph()
L_10 = []
N_10 = ["T", "Y", "M1", "M2", "Z1", "Z2", "Z3"]
costs_10 = [(node, {"cost": 1}) for node in N_10]
treatment_10 = "T"
outcome_10 = "Y"
G_10.add_nodes_from(costs_10)
G_10.add_edges_from(
[
("Z1", "Z2"),
("Z1", "T"),
("Z2", "Z3"),
("Z3", "Y"),
("T", "M1"),
("M1", "M2"),
("M2", "Y"),
]
)

EXAMPLES.append(CausalGraphExample(G_10, treatment_10, outcome_10, L_10, N_10))

OPTIMALS.append(set(["Z3"]))
OPTIMALS_MINIMAL.append(set(["Z3"]))
OPTIMALS_MINIMUM.append(set(["Z3"]))
OPTIMALS_MINCOST.append(set(["Z3"]))

# Another regression test for bug on name handling, spotted by Sara Taheri

G_11 = CausalGraph()
L_11 = []
N_11 = ["T", "Y", "M1", "M2", "M3", "Z1", "Z2", "Z3", "Z4", "Z5"]
costs_11 = [(node, {"cost": 1}) for node in N_11]
treatment_11 = "T"
outcome_11 = "Y"
G_11.add_nodes_from(costs_11)
G_11.add_edges_from(
[
("Z1", "Z2"),
("Z1", "T"),
("Z2", "Z3"),
("Z3", "Z4"),
("Z4", "Z5"),
("Z5", "Y"),
("T", "M1"),
("M1", "M2"),
("M2", "M3"),
("M3", "Y"),
("U1", "Z1"),
("U1", "T"),
("U2", "Z2"),
("U2", "M1"),
]
)

EXAMPLES.append(CausalGraphExample(G_11, treatment_11, outcome_11, L_11, N_11))

OPTIMALS.append(set(["Z1", "Z2", "Z5"]))
OPTIMALS_MINIMAL.append(set(["Z1"]))
OPTIMALS_MINIMUM.append(set(["Z1"]))
OPTIMALS_MINCOST.append(set(["Z1"]))
8 changes: 4 additions & 4 deletions tests/test_CausalGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_no_adj_minimum_optimal(example=EXAMPLES[0]):

@pytest.mark.parametrize(
"example, optimal_stored",
zip(EXAMPLES[1:4] + [EXAMPLES[8]], OPTIMALS[1:4] + [OPTIMALS[8]]),
zip(EXAMPLES[1:4] + EXAMPLES[8:12], OPTIMALS[1:4] + OPTIMALS[8:12]),
)
def test_optimal(example, optimal_stored):
optimal = example.G.optimal_adj_set(
Expand All @@ -63,7 +63,7 @@ def test_optimal_failure(example):


@pytest.mark.parametrize(
"example, optimal_minimal_stored", zip(EXAMPLES[1:8], OPTIMALS_MINIMAL[1:8])
"example, optimal_minimal_stored", zip(EXAMPLES[1:12], OPTIMALS_MINIMAL[1:12])
)
def test_optimal_minimal(example, optimal_minimal_stored):
optimal = example.G.optimal_minimal_adj_set(
Expand All @@ -73,7 +73,7 @@ def test_optimal_minimal(example, optimal_minimal_stored):


@pytest.mark.parametrize(
"example, optimal_minimum_stored", zip(EXAMPLES[1:8], OPTIMALS_MINIMUM[1:8])
"example, optimal_minimum_stored", zip(EXAMPLES[1:12], OPTIMALS_MINIMUM[1:12])
)
def test_optimal_minimum(example, optimal_minimum_stored):
optimal = example.G.optimal_minimum_adj_set(
Expand All @@ -83,7 +83,7 @@ def test_optimal_minimum(example, optimal_minimum_stored):


@pytest.mark.parametrize(
"example, optimal_mincost_stored", zip(EXAMPLES[1:8], OPTIMALS_MINCOST[1:8])
"example, optimal_mincost_stored", zip(EXAMPLES[1:12], OPTIMALS_MINCOST[1:12])
)
def test_optimal_mincost(example, optimal_mincost_stored):
optimal = example.G.optimal_mincost_adj_set(
Expand Down

0 comments on commit d8f9133

Please sign in to comment.