Skip to content

Commit

Permalink
Merge pull request #214 from py-why/revert-206-main
Browse files Browse the repository at this point in the history
Revert "Graph operations compatible with np array"
  • Loading branch information
kunwuz authored Jan 10, 2025
2 parents 67637e6 + f9caf66 commit bdfe3c8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 40 deletions.
21 changes: 3 additions & 18 deletions causallearn/graph/Dag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
from itertools import combinations
from typing import List, Optional, Union
from typing import List

import networkx as nx
import numpy as np
Expand All @@ -18,23 +18,8 @@
# or latent, with at most one edge per node pair, and no edges to self.
class Dag(GeneralGraph):

def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None):
if nodes is not None:
self._init_from_nodes(nodes)
elif graph is not None:
if isinstance(graph, np.ndarray):
nodes = [Node(node_name=str(i)) for i in range(len(graph))]
self._init_from_nodes(nodes)
for i in range(len(nodes)):
for j in range(len(nodes)):
if graph[i, j] == 1:
self.add_directed_edge(nodes[i], nodes[j])
else:
pass
else:
raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'")

def _init_from_nodes(self, nodes: List[Node]):
def __init__(self, nodes: List[Node]):

# for node in nodes:
# if not isinstance(node, type(GraphNode)):
# raise TypeError("Graphs must be instantiated with a list of GraphNodes")
Expand Down
19 changes: 6 additions & 13 deletions causallearn/graph/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,27 @@

# Represents an object with a name, node type, and position that can serve as a
# node in a graph.
from typing import Optional
from causallearn.graph.NodeType import NodeType
from causallearn.graph.NodeVariableType import NodeVariableType


class Node:
node_type: NodeType
node_name: str

def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None:
self.node_name = node_name
self.node_type = node_type

# @return the name of the variable.
def get_name(self) -> str:
return self.node_name
pass

# set the name of the variable
def set_name(self, name: str):
self.node_name = name
pass

# @return the node type of the variable
def get_node_type(self) -> NodeType:
return self.node_type
pass

# set the node type of the variable
def set_node_type(self, node_type: NodeType):
self.node_type = node_type
pass

# @return the intervention type
def get_node_variable_type(self) -> NodeVariableType:
Expand All @@ -42,7 +35,7 @@ def set_node_variable_type(self, var_type: NodeVariableType):

# @return the name of the node as its string representation
def __str__(self):
return self.node_name
pass

# @return the x coordinate of the center of the node
def get_center_x(self) -> int:
Expand All @@ -66,7 +59,7 @@ def set_center(self, center_x: int, center_y: int):

# @return a hashcode for this variable
def __hash__(self):
return hash(self.node_name)
pass

# @return true iff this variable is equal to the given variable
def __eq__(self, other):
Expand Down
11 changes: 2 additions & 9 deletions causallearn/utils/DAG2CPDAG.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Union
import numpy as np

from causallearn.graph.Dag import Dag
Expand All @@ -7,7 +6,7 @@
from causallearn.graph.GeneralGraph import GeneralGraph


def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
def dag2cpdag(G: Dag) -> GeneralGraph:
"""
Convert a DAG to its corresponding PDAG
Expand All @@ -23,13 +22,7 @@ def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
-------
Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU
"""

if isinstance(G, np.ndarray):
# convert np array to Dag graph
G = Dag(graph=G)
elif not isinstance(G, Dag):
raise TypeError("parameter graph should be `Dag` or `np.ndarry`")


# order the edges in G
nodes_order = list(
map(lambda x: G.node_map[x], G.get_causal_ordering())
Expand Down

0 comments on commit bdfe3c8

Please sign in to comment.