Skip to content

Commit

Permalink
Merge pull request #194 from EvieQ01/main
Browse files Browse the repository at this point in the history
Add orientation rules 567 for Augmented FCI
  • Loading branch information
kunwuz authored Oct 8, 2024
2 parents 0021a13 + 0e076fc commit 7924253
Show file tree
Hide file tree
Showing 15 changed files with 297 additions and 59 deletions.
219 changes: 206 additions & 13 deletions causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from queue import Queue
from typing import List, Set, Tuple, Dict
from typing import List, Set, Tuple, Dict, Generator
from numpy import ndarray

from causallearn.graph.Edge import Edge
Expand All @@ -17,6 +17,19 @@
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from itertools import combinations

def is_uncovered_path(nodes: List[Node], G: Graph) -> bool:
    """
    Check whether the given path is an uncovered path in graph G.

    The path is passed as an ordered list of nodes. It is uncovered when, for every
    three consecutive nodes (V_{i-1}, V_i, V_{i+1}) on it, the two outer nodes
    V_{i-1} and V_{i+1} are not adjacent in G. A path with fewer than three nodes
    has no such triple and is therefore reported as uncovered.
    """
    # Pair every node with the node two positions further along the path.
    outer_pairs = zip(nodes, nodes[2:])
    return not any(G.is_adjacent_to(left, right) for left, right in outer_pairs)


def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
if edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE:
Expand All @@ -26,8 +39,17 @@ def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
return edge.get_node1()
return None

def traverseCircle(node: Node, edge: Edge) -> Node | None:
    """
    Step across `edge` starting from `node`, but only along a circle-circle (o-o) edge.

    Returns the node at the other end of `edge` when `node` is one of its two end
    nodes and both endpoints of the edge are Endpoint.CIRCLE; returns None otherwise.
    """
    # Work out which end we are standing on; bail out if `node` is not on the edge.
    if node == edge.get_node1():
        opposite = edge.get_node2()
    elif node == edge.get_node2():
        opposite = edge.get_node1()
    else:
        return None

    both_circles = (edge.get_endpoint1() == Endpoint.CIRCLE
                    and edge.get_endpoint2() == Endpoint.CIRCLE)
    return opposite if both_circles else None


def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:
def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## TODO: Now it does not detect whether the path is an uncovered path
Q = Queue()
V = set()

Expand Down Expand Up @@ -60,6 +82,42 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool:

return False

def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph,
                           exclude_node: List[Node]) -> Generator[List[Node], None, None]:
    """
    Lazily enumerate the uncovered circle paths from `node_from` to `node_to` in G.

    A circle path only uses edges whose two endpoints are both Endpoint.CIRCLE (o-o),
    as decided by traverseCircle. A path is uncovered when the outer nodes of every
    three consecutive nodes on it are not adjacent.

    Parameters
    ----------
    node_from : first node of every yielded path
    node_to : last node of every yielded path
    G : graph to search
    exclude_node : nodes that must not appear on a path

    Yields
    ------
    Lists of nodes [node_from, ..., node_to], shortest paths first (breadth-first
    order). Every yielded path is simple (no repeated node). Nothing is yielded
    when no such path exists or when `node_from` and `node_to` coincide.

    Notes
    -----
    The search state is kept per partial path rather than in one global visited
    set: with a shared visited set only the first-discovered route to a node
    survives, and if that route turns out to be covered an existing uncovered path
    through the same node would be missed.

    This is a generator, so the graph is read while the caller iterates. Callers
    that re-orient edges should materialise the result (e.g. with list()) first.
    """
    pending = Queue()
    pending.put([node_from])

    while not pending.empty():
        path = pending.get_nowait()
        node_t = path[-1]

        if node_t == node_to and len(path) > 1:
            # A simple path ends once it reaches its target; do not extend it.
            yield path
            continue

        for node_u in G.get_adjacent_nodes(node_t):
            if node_u in exclude_node or node_u in path:
                continue

            node_c = traverseCircle(node_t, G.get_edge(node_t, node_u))
            if node_c is None or node_c in exclude_node:
                continue

            # Prune as soon as the newest triple (path[-2], node_t, node_c) is
            # shielded: no extension of a covered prefix can be uncovered.
            if len(path) >= 2 and G.is_adjacent_to(path[-2], node_c):
                continue

            pending.put(path + [node_c])



def existOnePathWithPossibleParents(previous, node_w: Node, node_x: Node, node_b: Node, graph: Graph) -> bool:
if node_w == node_x:
Expand Down Expand Up @@ -371,6 +429,131 @@ def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: Backgrou
changeFlag = True
return changeFlag

def ruleR5(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Rule R5 of the FCI algorithm.
    By Jiji Zhang, 2008, "On the completeness of orientation rules for causal discovery
    in the presence of latent confounders and selection bias".

    For every edge A o-o B: if there is an uncovered circle path
    <A, C, ..., D, B> between A and B such that A is not adjacent to D and B is not
    adjacent to C, then A o-o B and every edge on that path are oriented as
    double-tail (undirected) edges.

    Parameters
    ----------
    graph : graph whose edges are re-oriented in place
    changeFlag : flag carried over from previously applied rules
    verbose : when True, print every orientation that is made

    Returns
    -------
    True if this rule oriented at least one edge, otherwise the incoming `changeFlag`.
    """
    def orient_double_tail(node_x, node_y):
        # Replace the edge between the two nodes by a tail-tail edge.
        graph.remove_edge(graph.get_edge(node_x, node_y))
        graph.add_edge(Edge(node_x, node_y, Endpoint.TAIL, Endpoint.TAIL))

    def orient_on_path_helper(path, node_A, node_B):
        # orient A - C, D - B
        orient_double_tail(node_A, path[0])
        orient_double_tail(node_B, path[-1])
        if verbose:
            print("Orienting edge A - C (Double tail): " + str(graph.get_edge(node_A, path[0])))
            print("Orienting edge B - D (Double tail): " + str(graph.get_edge(node_B, path[-1])))

        # orient everything on the path to both tails
        for node_x, node_y in zip(path, path[1:]):
            orient_double_tail(node_x, node_y)
            if verbose:
                print("Orienting edge (Double tail): " + str(graph.get_edge(node_x, node_y)))

    def circle_neighbours(node, excluded_idx):
        # Neighbours joined to `node` by an o-o edge, leaving out A and B themselves.
        return [other for other in graph.get_adjacent_nodes(node)
                if graph.node_map[other] not in excluded_idx
                and graph.get_endpoint(other, node) == Endpoint.CIRCLE
                and graph.get_endpoint(node, other) == Endpoint.CIRCLE]

    for node_B in graph.get_nodes():
        for node_A in graph.get_nodes_into(node_B, Endpoint.CIRCLE):
            # Only A o-o B qualifies. The list above is a snapshot, so the edge may
            # already have been oriented earlier in this pass; re-check both ends.
            if (graph.get_endpoint(node_B, node_A) != Endpoint.CIRCLE
                    or graph.get_endpoint(node_A, node_B) != Endpoint.CIRCLE):
                continue

            excluded_idx = (graph.node_map[node_A], graph.node_map[node_B])
            a_circle_adj_nodes = circle_neighbours(node_A, excluded_idx)
            b_circle_adj_nodes = circle_neighbours(node_B, excluded_idx)

            # Collect all qualifying paths BEFORE orienting anything: searching
            # lazily while edges are being turned into tails would hide circle
            # paths that share an edge with an already oriented path.
            found_paths_between_AB = []
            for node_C in a_circle_adj_nodes:
                if graph.is_adjacent_to(node_B, node_C):
                    continue
                for node_D in b_circle_adj_nodes:
                    if graph.is_adjacent_to(node_A, node_D):
                        continue
                    # uncovered circle paths between C and D, excluding A and B
                    for path in GetUncoveredCirclePath(node_from=node_C, node_to=node_D, G=graph,
                                                       exclude_node=[node_A, node_B]):
                        # The rule needs the WHOLE path A, C, ..., D, B to be
                        # uncovered, which also covers the triples at both ends.
                        if is_uncovered_path([node_A, *path, node_B], graph):
                            found_paths_between_AB.append(path)

            # Orient the uncovered circle paths between A and B
            for path in found_paths_between_AB:
                changeFlag = True
                if verbose:
                    print("Find uncovered circle path between A and B: " + str(graph.get_edge(node_A, node_B)))
                orient_double_tail(node_A, node_B)
                orient_on_path_helper(path, node_A, node_B)

    return changeFlag

def ruleR6(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Rule R6 of the FCI algorithm: if A - B o-* C, orient B o-* C as B -* C.

    For every node B that has at least one tail-tail edge A - B, the circle mark at
    B on each edge B o-* C is replaced by a tail; the mark at C is kept.

    Parameters
    ----------
    graph : graph whose edges are re-oriented in place
    changeFlag : flag carried over from previously applied rules
    verbose : when True, print every orientation that is made

    Returns
    -------
    True if this rule oriented at least one edge, otherwise the incoming `changeFlag`.
    """
    for node_B in graph.get_nodes():
        # Find A - B
        has_undirected_edge = any(
            graph.get_endpoint(node_B, node_A) == Endpoint.TAIL
            for node_A in graph.get_nodes_into(node_B, Endpoint.TAIL)
        )
        if not has_undirected_edge:
            continue

        # Find B o-*C
        for node_C in graph.get_nodes_into(node_B, Endpoint.CIRCLE):
            old_edge = graph.get_edge(node_B, node_C)
            mark_at_C = old_edge.get_proximal_endpoint(node_C)
            graph.remove_edge(old_edge)
            graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, mark_at_C))
            changeFlag = True
            if verbose:
                print("Orienting edge by rule 6): " + str(graph.get_edge(node_B, node_C)))

    return changeFlag


def ruleR7(graph: Graph, changeFlag: bool,
           verbose: bool = False) -> bool:
    """
    Rule R7 of the FCI algorithm: if A -o B o-* C and A, C are not adjacent,
    orient B o-* C as B -* C.

    Parameters
    ----------
    graph : graph whose edges are re-oriented in place
    changeFlag : flag carried over from previously applied rules
    verbose : when True, print every orientation that is made

    Returns
    -------
    True if this rule oriented at least one edge, otherwise the incoming `changeFlag`.
    """
    for node_B in graph.get_nodes():
        # Find A -o B
        intoBCircles = graph.get_nodes_into(node_B, Endpoint.CIRCLE)
        node_A_list = [node for node in intoBCircles if graph.get_endpoint(node_B, node) == Endpoint.TAIL]
        if not node_A_list:
            continue

        # Find B o-*C
        for node_C in intoBCircles:
            for node_A in node_A_list:
                # A and C must be two different nodes. Without this guard the
                # edge A -o B would serve as its own witness (a node is not
                # reported as adjacent to itself) and be turned into A - B.
                if node_A == node_C:
                    continue
                if graph.is_adjacent_to(node_A, node_C):
                    continue

                edge = graph.get_edge(node_B, node_C)
                graph.remove_edge(edge)
                graph.add_edge(Edge(node_B, node_C, Endpoint.TAIL, edge.get_proximal_endpoint(node_C)))
                changeFlag = True
                if verbose:
                    print("Orienting edge by rule 7): " + str(graph.get_edge(node_B, node_C)))
                # One witness A is enough; the edge B -* C is now oriented.
                break
    return changeFlag

def getPath(node_c: Node, previous) -> List[Node]:
l = []
Expand Down Expand Up @@ -544,9 +727,8 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m



def rule8(graph: Graph, nodes: List[Node]):
nodes = graph.get_nodes()
changeFlag = False
def rule8(graph: Graph, nodes: List[Node], changeFlag):
nodes = graph.get_nodes() if nodes is None else nodes
for node_B in nodes:
adj = graph.get_adjacent_nodes(node_B)
if len(adj) < 2:
Expand Down Expand Up @@ -601,9 +783,9 @@ def find_possible_children(graph: Graph, parent_node, en_nodes=None):

return potential_child_nodes

def rule9(graph: Graph, nodes: List[Node]):
changeFlag = False
nodes = graph.get_nodes()
def rule9(graph: Graph, nodes: List[Node], changeFlag):
# changeFlag = False
nodes = graph.get_nodes() if nodes is None else nodes
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
for node_A in intoCArrows:
Expand All @@ -629,8 +811,8 @@ def rule9(graph: Graph, nodes: List[Node]):
return changeFlag


def rule10(graph: Graph):
changeFlag = False
def rule10(graph: Graph, changeFlag):
# changeFlag = False
nodes = graph.get_nodes()
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
Expand Down Expand Up @@ -895,6 +1077,7 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
graph, sep_sets, test_results = fas(dataset, nodes, independence_test_method=independence_test_method, alpha=alpha,
knowledge=background_knowledge, depth=depth, verbose=verbose, show_progress=show_progress)

# pdb.set_trace()
reorientAllWith(graph, Endpoint.CIRCLE)

rule0(graph, nodes, sep_sets, background_knowledge, verbose)
Expand Down Expand Up @@ -925,12 +1108,22 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if verbose:
print("Epoch")

# rule 5
change_flag = ruleR5(graph, change_flag, verbose)

# rule 6
change_flag = ruleR6(graph, change_flag, verbose)

# rule 7
change_flag = ruleR7(graph, change_flag, verbose)

# rule 8
change_flag = rule8(graph,nodes)
change_flag = rule8(graph,nodes, change_flag)

# rule 9
change_flag = rule9(graph, nodes)
change_flag = rule9(graph, nodes, change_flag)
# rule 10
change_flag = rule10(graph)
change_flag = rule10(graph, change_flag)

graph.set_pag(True)

Expand Down
18 changes: 13 additions & 5 deletions causallearn/utils/DAG2PAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from causallearn.graph.Endpoint import Endpoint
from causallearn.graph.GeneralGraph import GeneralGraph
from causallearn.graph.Node import Node
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B
from causallearn.search.ConstraintBased.FCI import rule0, rulesR1R2cycle, ruleR3, ruleR4B, ruleR5, ruleR6, ruleR7, rule8, rule9, rule10
from causallearn.utils.cit import CIT, d_separation

def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:

def dag2pag(dag: Dag, islatent: List[Node], isselection: List[Node] = []) -> GeneralGraph:
"""
Convert a DAG to its corresponding PAG
Parameters
Expand All @@ -27,8 +28,8 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
dg = nx.DiGraph()
true_dag = nx.DiGraph()
nodes = dag.get_nodes()
observed_nodes = list(set(nodes) - set(islatent))
mod_nodes = observed_nodes + islatent
observed_nodes = list(set(nodes) - set(islatent) - set(isselection))
mod_nodes = observed_nodes + islatent + isselection
nodes = dag.get_nodes()
nodes_ids = {node: i for i, node in enumerate(nodes)}
mod_nodeids = {node: i for i, node in enumerate(mod_nodes)}
Expand Down Expand Up @@ -65,7 +66,7 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
for Z in combinations(observed_nodes, l):
if nodex in Z or nodey in Z:
continue
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z)):
if d_separated(dg, {nodes_ids[nodex]}, {nodes_ids[nodey]}, set(nodes_ids[z] for z in Z) | set([nodes_ids[s] for s in isselection])):
if edge:
PAG.remove_edge(edge)
sepset[(nodes_ids[nodex], nodes_ids[nodey])] |= set(Z)
Expand Down Expand Up @@ -105,6 +106,13 @@ def dag2pag(dag: Dag, islatent: List[Node]) -> GeneralGraph:
change_flag = ruleR4B(PAG, -1, data, independence_test_method, 0.05, sep_sets=sepset_reindexed,
change_flag=change_flag,
bk=None, verbose=False)
change_flag = ruleR5(PAG, changeFlag=change_flag, verbose=True)
change_flag = ruleR6(PAG, changeFlag=change_flag)
change_flag = ruleR7(PAG, changeFlag=change_flag)
change_flag = rule8(PAG, nodes=observed_nodes, changeFlag=change_flag)
change_flag = rule9(PAG, nodes=observed_nodes, changeFlag=change_flag)
change_flag = rule10(PAG, changeFlag=change_flag)

return PAG


Expand Down
14 changes: 14 additions & 0 deletions tests/TestDAG2PAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,17 @@ def test_case3(self):
print(pag)
graphviz_pag = GraphUtils.to_pgv(pag)
graphviz_pag.draw("pag.png", prog='dot', format='png')

def test_case_selection(self):
    """Build the chain 0 -> 1 -> 2 -> 3 with node 4 as a selection node (children of 3 and 0) and print its PAG."""
    nodes = [GraphNode(str(i)) for i in range(5)]
    dag = Dag(nodes)
    for parent, child in ((0, 1), (1, 2), (2, 3)):
        dag.add_directed_edge(nodes[parent], nodes[child])
    # Selection nodes
    selection_node = nodes[4]
    for parent in (3, 0):
        dag.add_directed_edge(nodes[parent], selection_node)
    pag = dag2pag(dag, islatent=[], isselection=[selection_node])
    print(pag)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 -1 -1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2 0 0 0 -1 0 -1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0
0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0
Expand All @@ -12,17 +12,17 @@
0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 2
0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 -1
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 -1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 -1 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 0 0 0 2 0 2 0 0 0 0 2 -1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 -1 0 0 0 0 0 -1 0 -1 0 0 0 0 2 -1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 -1 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0
Expand Down
Loading

0 comments on commit 7924253

Please sign in to comment.