-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add draft for directed acyclic graph class
- Loading branch information
Showing
6 changed files
with
245 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from hiclass.Node import Node | ||
|
||
|
||
class DirectedAcyclicGraph: | ||
""" | ||
Manages the directed acyclic graph used in HiClass. | ||
It tries to copy networkx API as much as possible, | ||
but extends it by adding support for multiple nodes with the same name, | ||
as long as they have different predecessors. | ||
""" | ||
|
||
|
||
def __init__(self, n_rows): | ||
""" | ||
Initialize a directed acyclic graph. | ||
Parameters | ||
---------- | ||
n_rows : int | ||
The number of rows in x and y, i.e., the features and labels matrices. | ||
""" | ||
self.root = Node(n_rows, "root") | ||
self.nodes = { | ||
"root": self.root | ||
} | ||
|
||
def add_node(self, node_name): | ||
""" | ||
Add a new as successor of the root node. | ||
Parameters | ||
---------- | ||
node_name : str | ||
The name of the node. | ||
""" | ||
if node_name != "": | ||
new_node = self.root.add_successor(node_name) | ||
self.nodes[node_name] = new_node | ||
|
||
def add_path(self, nodes): | ||
""" | ||
Add new nodes from a path. | ||
Parameters | ||
---------- | ||
nodes : np.ndarray | ||
The list with the path, e.g., [a b c] = a -> b -> c | ||
""" | ||
successor = nodes[0] | ||
leaf = self.root.add_successor(successor) | ||
self.nodes[successor] = leaf | ||
index = 0 | ||
while index < len(nodes) - 1 and nodes[index] != "": | ||
successor = nodes[index + 1] | ||
if successor != "": | ||
leaf = leaf.add_successor(successor) | ||
self.nodes[successor] = leaf | ||
index = index + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import numpy as np | ||
|
||
class Node: | ||
"""Manages data for an individual node in the hierarchy.""" | ||
|
||
def __init__(self, n_rows, name): | ||
""" | ||
Initialize an individual node. | ||
Parameters | ||
---------- | ||
n_rows : int | ||
The number of rows in x and y. | ||
""" | ||
self.n_rows = n_rows | ||
self.mask = np.full(n_rows, True) | ||
self.children = dict() | ||
self.name = name | ||
|
||
def add_successor(self, successor_name): | ||
""" | ||
Add a new successor. | ||
Parameters | ||
---------- | ||
node_name : str | ||
The name of the new successor. | ||
Returns | ||
------- | ||
successor : Node | ||
The new successor created. | ||
""" | ||
if successor_name != "": | ||
if not successor_name in self.children: | ||
new_successor = Node(self.n_rows, successor_name) | ||
self.children[successor_name] = new_successor | ||
return new_successor | ||
else: | ||
return self.children[successor_name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
|
||
from hiclass import DirectedAcyclicGraph | ||
|
||
|
||
def test_add_node(): | ||
n_rows = 3 | ||
dag = DirectedAcyclicGraph(n_rows) | ||
dag.add_node("node1") | ||
dag.add_node("node2") | ||
dag.add_node("node1") | ||
dag.add_node("node2") | ||
assert 3 == len(dag.nodes) | ||
assert "root" in dag.nodes | ||
assert "node1" in dag.nodes | ||
assert "node2" in dag.nodes | ||
|
||
|
||
def test_add_path(): | ||
paths = np.array([ | ||
["a", "c", "d"], | ||
["b", "c", "e"], | ||
["a", "c", "f"], | ||
["c", "", ""], | ||
["a", "c", "d"], | ||
["b", "c", "e"], | ||
["a", "c", "f"], | ||
["c", "", ""], | ||
["", "", ""], | ||
]) | ||
rows = paths.shape[0] | ||
dag = DirectedAcyclicGraph(rows) | ||
for row in range(rows): | ||
path = paths[row, :] | ||
dag.add_path(path) | ||
assert 8 == len(dag.nodes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from hiclass import Node | ||
|
||
|
||
def test_add_successor(): | ||
n_rows = 3 | ||
name = "root" | ||
node = Node(n_rows, name) | ||
assert node.name == "root" | ||
successor1 = node.add_successor("node1") | ||
successor2 = node.add_successor("node2") | ||
assert successor1 == node.add_successor("node1") | ||
assert successor2 == node.add_successor("node2") | ||
assert n_rows == node.n_rows | ||
assert 2 == len(node.children) |