Skip to content

Commit

Permalink
Add draft for directed acyclic graph class
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 committed Feb 2, 2024
1 parent 9d40db5 commit 6c78704
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 3 deletions.
59 changes: 59 additions & 0 deletions hiclass/DirectedAcyclicGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from hiclass.Node import Node


class DirectedAcyclicGraph:
    """
    Manages the directed acyclic graph used in HiClass.

    It tries to copy the networkx API as much as possible,
    but extends it by adding support for multiple nodes with the same name,
    as long as they have different predecessors.

    Attributes
    ----------
    root : Node
        The artificial root node, named "root".
    nodes : dict
        Maps a node name to a node. Because the key is the bare name,
        a later node with the same name replaces the earlier entry here
        (the earlier node stays reachable through its predecessor).
    """

    def __init__(self, n_rows):
        """
        Initialize a directed acyclic graph.

        Parameters
        ----------
        n_rows : int
            The number of rows in x and y, i.e., the features and labels matrices.
        """
        self.root = Node(n_rows, "root")
        self.nodes = {"root": self.root}

    def add_node(self, node_name):
        """
        Add a new node as successor of the root node.

        Empty names ("") are ignored.

        Parameters
        ----------
        node_name : str
            The name of the node.
        """
        if node_name == "":
            return
        self.nodes[node_name] = self.root.add_successor(node_name)

    def add_path(self, nodes):
        """
        Add new nodes from a path.

        The first label becomes a successor of the root and every following
        label becomes a successor of the previous one. The path ends at the
        first empty label (""); anything after it is ignored.

        Parameters
        ----------
        nodes : np.ndarray
            The list with the path, e.g., [a b c] = a -> b -> c
        """
        # A zero-length path has no first label to attach to the root.
        if len(nodes) == 0:
            return

        first = nodes[0]
        leaf = self.root.add_successor(first)
        # NOTE(review): when the first label is "", add_successor returns None
        # and "" is registered here with a None value. Kept as-is because
        # len(self.nodes) is observable by callers -- confirm whether intended.
        self.nodes[first] = leaf
        if first == "":
            return

        for name in nodes[1:]:
            # Padding reached: the real path is over.
            if name == "":
                break
            leaf = leaf.add_successor(name)
            self.nodes[name] = leaf

    def is_acyclic_graph(self):
        """
        Check that no cycle is reachable from the root.

        Returns
        -------
        acyclic : bool
            True if following ``children`` links from the root never leads
            back to a node that is on the current path, False otherwise.
        """
        # Iterative depth-first search. Nodes are tracked by identity because
        # several nodes may share the same name.
        on_path = {id(self.root)}
        finished = set()
        stack = [(self.root, iter(self.root.children.values()))]
        while stack:
            node, pending = stack[-1]
            for child in pending:
                child_id = id(child)
                if child_id in on_path:
                    return False
                if child_id not in finished:
                    on_path.add(child_id)
                    stack.append((child, iter(child.children.values())))
                    break
            else:
                # All successors explored: leave the current path.
                stack.pop()
                on_path.discard(id(node))
                finished.add(id(node))
        return True
index = index + 1
95 changes: 92 additions & 3 deletions hiclass/LocalClassifierPerParentNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
Numeric and string output labels are both handled.
"""

from copy import deepcopy

import networkx as nx
import numpy as np
from copy import deepcopy
from sklearn.base import BaseEstimator
from sklearn.utils.validation import _check_sample_weight
from sklearn.utils.validation import check_array, check_is_fitted

from hiclass.ConstantClassifier import ConstantClassifier
from hiclass.DirectedAcyclicGraph import DirectedAcyclicGraph
from hiclass.HierarchicalClassifier import HierarchicalClassifier
from hiclass.HierarchicalClassifier import make_leveled


class LocalClassifierPerParentNode(BaseEstimator, HierarchicalClassifier):
Expand Down Expand Up @@ -98,7 +100,7 @@ def fit(self, X, y, sample_weight=None):
Fitted estimator.
"""
# Execute common methods necessary before fitting
super()._pre_fit(X, y, sample_weight)
self._pre_fit(X, y, sample_weight)

# Fit local classifiers in DAG
super().fit(X, y)
Expand Down Expand Up @@ -157,6 +159,93 @@ def predict(self, X):

return y

def _pre_fit(self, X, y, sample_weight):
    """
    Run the common steps that must happen before fitting.

    Stores the validated inputs in ``self.X_``, ``self.y_`` and
    ``self.sample_weight_``, then builds and prepares the hierarchy.

    Parameters
    ----------
    X : array-like
        The training input samples.
    y : array-like
        The target labels.
    sample_weight : array-like or None
        Optional per-sample weights.
    """
    if self.bert:
        # Skip scikit-learn validation and only convert to np.ndarray
        self.X_ = np.array(X)
        self.y_ = np.array(y)
    else:
        # Check that X and y have correct shape
        # and convert them to np.ndarray if need be
        self.X_, self.y_ = self._validate_data(
            X, y, multi_output=True, accept_sparse="csr", allow_nd=True
        )

    self.sample_weight_ = (
        None if sample_weight is None else _check_sample_weight(sample_weight, X)
    )

    self.y_ = make_leveled(self.y_)

    # The remaining steps depend on each other and must run in this order.

    # Create and configure logger
    self._create_logger()

    # Create DAG from self.y_ and store to self.hierarchy_
    self._create_digraph()

    # If user passes edge_list, then export
    # DAG to CSV file to visualize with Gephi
    self._export_digraph()

    # Assert that graph is directed acyclic
    self._assert_digraph_is_dag()

    # If y is 1D, convert to 2D for binary policies
    self._convert_1d_y_to_2d()

    # Detect root(s) and add artificial root to DAG
    self._add_artificial_root()

    # Initialize local classifiers in DAG
    self._initialize_local_classifiers()

def _create_digraph(self):
    """
    Build ``self.hierarchy_`` from the labels in ``self.y_``.

    Also records the dtype of the labels in ``self.dtype_``.

    Raises
    ------
    ValueError
        If ``self.y_`` has more than 2 dimensions.
    """
    # Start from a graph sized by the number of rows in X_
    n_samples = self.X_.shape[0]
    self.hierarchy_ = DirectedAcyclicGraph(n_samples)

    # Save dtype of y_
    self.dtype_ = self.y_.dtype

    # Each helper only acts when y_ has the matching dimensionality
    self._create_digraph_1d()
    self._create_digraph_2d()

    n_dims = self.y_.ndim
    if n_dims <= 2:
        return

    # Unsupported dimension
    self.logger_.error(f"y with {n_dims} dimensions detected")
    raise ValueError(
        f"Creating graph from y with {n_dims} dimensions is not supported"
    )

def _create_digraph_1d(self):
    """
    Add one node per label when ``self.y_`` is one-dimensional.

    A 2D ``self.y_`` with a single column is flattened to 1D first.
    Does nothing for any other shape.
    """
    # Flatten 1D disguised as 2D
    single_column = self.y_.ndim == 2 and self.y_.shape[1] == 1
    if single_column:
        self.logger_.info("Converting y to 1D")
        self.y_ = self.y_.flatten()

    if self.y_.ndim != 1:
        return

    # Create max_levels_ variable
    self.max_levels_ = 1
    self.logger_.info(f"Creating digraph from {self.y_.size} 1D labels")
    add_node = self.hierarchy_.add_node
    for label in self.y_:
        add_node(label)

def _create_digraph_2d(self):
    """
    Add one path per row when ``self.y_`` is two-dimensional.

    Does nothing for any other number of dimensions.
    """
    if self.y_.ndim != 2:
        return

    n_rows, n_columns = self.y_.shape
    # Create max_levels variable
    self.max_levels_ = n_columns
    self.logger_.info(f"Creating digraph from {n_rows} 2D labels")
    for index in range(n_rows):
        self.hierarchy_.add_path(self.y_[index, :])

def _assert_digraph_is_dag(self):
    """
    Assert that the hierarchy is a directed acyclic graph.

    Relies on ``self.hierarchy_`` exposing ``is_acyclic_graph()``.

    Raises
    ------
    ValueError
        If ``self.hierarchy_.is_acyclic_graph()`` returns a falsy value.
    """
    if self.hierarchy_.is_acyclic_graph():
        return
    self.logger_.error("Cycle detected in graph")
    raise ValueError("Graph is not directed acyclic")

def _predict_remaining_levels(self, X, y):
for level in range(1, y.shape[1]):
predecessors = set(y[:, level - 1])
Expand Down
40 changes: 40 additions & 0 deletions hiclass/Node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np

class Node:
    """Manages data for an individual node in the hierarchy."""

    def __init__(self, n_rows, name):
        """
        Initialize an individual node.

        Parameters
        ----------
        n_rows : int
            The number of rows in x and y.
        name : str
            The name of the node.
        """
        self.n_rows = n_rows
        # Boolean mask with one entry per row, initially all True
        self.mask = np.full(n_rows, True)
        # Successors of this node, keyed by their name
        self.children = dict()
        self.name = name

    def add_successor(self, successor_name):
        """
        Add a new successor, or fetch the existing one with that name.

        Parameters
        ----------
        successor_name : str
            The name of the new successor.

        Returns
        -------
        successor : Node or None
            The successor registered under ``successor_name``; it is created
            only if this node does not have one yet. None if
            ``successor_name`` is the empty string, in which case nothing
            is added.
        """
        if successor_name == "":
            return None
        successor = self.children.get(successor_name)
        if successor is None:
            successor = Node(self.n_rows, successor_name)
            self.children[successor_name] = successor
        return successor
4 changes: 4 additions & 0 deletions hiclass/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Init module for the library."""

from .DirectedAcyclicGraph import DirectedAcyclicGraph
from .LocalClassifierPerLevel import LocalClassifierPerLevel
from .LocalClassifierPerNode import LocalClassifierPerNode
from .LocalClassifierPerParentNode import LocalClassifierPerParentNode
from .MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
from .MultiLabelLocalClassifierPerParentNode import (
MultiLabelLocalClassifierPerParentNode,
)
from .Node import Node
from ._version import get_versions

__version__ = get_versions()["version"]
Expand All @@ -18,4 +20,6 @@
"LocalClassifierPerLevel",
"MultiLabelLocalClassifierPerNode",
"MultiLabelLocalClassifierPerParentNode",
"Node",
"DirectedAcyclicGraph",
]
36 changes: 36 additions & 0 deletions tests/test_DirectedAcyclicGraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np

from hiclass import DirectedAcyclicGraph


def test_add_node():
    """Adding the same names twice must not create duplicate entries."""
    n_rows = 3
    dag = DirectedAcyclicGraph(n_rows)
    for name in ("node1", "node2", "node1", "node2"):
        dag.add_node(name)
    # root plus the two distinct names
    assert 3 == len(dag.nodes)
    for expected in ("root", "node1", "node2"):
        assert expected in dag.nodes


def test_add_path():
    """Repeated, padded and all-empty paths are registered by name."""
    labels = [
        ["a", "c", "d"],
        ["b", "c", "e"],
        ["a", "c", "f"],
        ["c", "", ""],
        ["a", "c", "d"],
        ["b", "c", "e"],
        ["a", "c", "f"],
        ["c", "", ""],
        ["", "", ""],
    ]
    paths = np.array(labels)
    dag = DirectedAcyclicGraph(paths.shape[0])
    for path in paths:
        dag.add_path(path)
    # NOTE(review): the 8 keys are "root", "a".."f" and "" -- the last one
    # comes from the all-empty row; confirm that entry is intended.
    assert 8 == len(dag.nodes)
14 changes: 14 additions & 0 deletions tests/test_Node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from hiclass import Node


def test_add_successor():
    """Re-adding a successor name must return the node created first."""
    n_rows = 3
    node = Node(n_rows, "root")
    assert node.name == "root"
    created = {label: node.add_successor(label) for label in ("node1", "node2")}
    assert created["node1"] == node.add_successor("node1")
    assert created["node2"] == node.add_successor("node2")
    assert n_rows == node.n_rows
    assert 2 == len(node.children)

0 comments on commit 6c78704

Please sign in to comment.