diff --git a/causallearn/graph/Dag.py b/causallearn/graph/Dag.py index 7df8cc32..1f549ff6 100644 --- a/causallearn/graph/Dag.py +++ b/causallearn/graph/Dag.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 from itertools import combinations -from typing import List, Optional, Union +from typing import List import networkx as nx import numpy as np @@ -18,23 +18,8 @@ # or latent, with at most one edge per node pair, and no edges to self. class Dag(GeneralGraph): - def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None): - if nodes is not None: - self._init_from_nodes(nodes) - elif graph is not None: - if isinstance(graph, np.ndarray): - nodes = [Node(node_name=str(i)) for i in range(len(graph))] - self._init_from_nodes(nodes) - for i in range(len(nodes)): - for j in range(len(nodes)): - if graph[i, j] == 1: - self.add_directed_edge(nodes[i], nodes[j]) - else: - pass - else: - raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'") - - def _init_from_nodes(self, nodes: List[Node]): + def __init__(self, nodes: List[Node]): + # for node in nodes: # if not isinstance(node, type(GraphNode)): # raise TypeError("Graphs must be instantiated with a list of GraphNodes") diff --git a/causallearn/graph/Node.py b/causallearn/graph/Node.py index 798a616f..3b9b3e2a 100644 --- a/causallearn/graph/Node.py +++ b/causallearn/graph/Node.py @@ -2,34 +2,27 @@ # Represents an object with a name, node type, and position that can serve as a # node in a graph. -from typing import Optional from causallearn.graph.NodeType import NodeType from causallearn.graph.NodeVariableType import NodeVariableType class Node: - node_type: NodeType - node_name: str - def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None: - self.node_name = node_name - self.node_type = node_type - # @return the name of the variable. def get_name(self) -> str: - return self.node_name + pass # set the name of the variable def set_name(self, name: str): - self.node_name = name + pass # @return the node type of the variable def get_node_type(self) -> NodeType: - return self.node_type + pass # set the node type of the variable def set_node_type(self, node_type: NodeType): - self.node_type = node_type + pass # @return the intervention type def get_node_variable_type(self) -> NodeVariableType: @@ -42,7 +35,7 @@ def set_node_variable_type(self, var_type: NodeVariableType): # @return the name of the node as its string representation def __str__(self): - return self.node_name + pass # @return the x coordinate of the center of the node def get_center_x(self) -> int: @@ -66,7 +59,7 @@ def set_center(self, center_x: int, center_y: int): # @return a hashcode for this variable def __hash__(self): - return hash(self.node_name) + pass # @return true iff this variable is equal to the given variable def __eq__(self, other): diff --git a/causallearn/utils/DAG2CPDAG.py b/causallearn/utils/DAG2CPDAG.py index 01d2f2ee..2604b1f7 100644 --- a/causallearn/utils/DAG2CPDAG.py +++ b/causallearn/utils/DAG2CPDAG.py @@ -1,4 +1,3 @@ -from typing import Union import numpy as np from causallearn.graph.Dag import Dag @@ -7,7 +6,7 @@ from causallearn.graph.GeneralGraph import GeneralGraph -def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph: +def dag2cpdag(G: Dag) -> GeneralGraph: """ Convert a DAG to its corresponding PDAG @@ -23,13 +22,7 @@ def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph: ------- Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU """ - - if isinstance(G, np.ndarray): - # convert np array to Dag graph - G = Dag(graph=G) - elif not isinstance(G, Dag): - raise TypeError("parameter graph should be `Dag` or `np.ndarry`") - + # order the edges in G nodes_order = list( map(lambda x: G.node_map[x], G.get_causal_ordering())