From e8e380b6810ec3fd876e33850810ab0656b64bc6 Mon Sep 17 00:00:00 2001 From: Pieter Eendebak Date: Fri, 19 Jan 2024 23:08:08 +0100 Subject: [PATCH] Add method to return edge indices from endpoints (#1055) * add edge_indices_from_endpoints for graph and digraph * fix tests * Add release notes * linter * black * Fix text_signature for new method --------- Co-authored-by: Matthew Treinish --- ...ints-method-to-graph-d58dc98719c4db39.yaml | 6 ++++++ rustworkx/digraph.pyi | 1 + rustworkx/graph.pyi | 1 + src/digraph.rs | 18 +++++++++++++++++ src/graph.rs | 18 +++++++++++++++++ tests/rustworkx_tests/digraph/test_edges.py | 19 ++++++++++++++++++ tests/rustworkx_tests/graph/test_edges.py | 20 +++++++++++++++++++ 7 files changed, 83 insertions(+) create mode 100644 releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml diff --git a/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml b/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml new file mode 100644 index 0000000000..2a7082b481 --- /dev/null +++ b/releasenotes/notes/add-edge_indices_from_endpoints-method-to-graph-d58dc98719c4db39.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added method :meth:`~rustworkx.PyGraph.edge_indices_from_endpoints` which returns the indices of all edges + between the specified endpoints. For :class:`~rustworkx.PyDiGraph` there is a corresponding method that returns the + directed edges. diff --git a/rustworkx/digraph.pyi b/rustworkx/digraph.pyi index ae470846f7..9648ca62e6 100644 --- a/rustworkx/digraph.pyi +++ b/rustworkx/digraph.pyi @@ -68,6 +68,7 @@ class PyDiGraph(Generic[S, T]): def copy(self) -> PyDiGraph[S, T]: ... def edge_index_map(self) -> EdgeIndexMap[T]: ... def edge_indices(self) -> EdgeIndices: ... + def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ... def edge_list(self) -> EdgeList: ... def edges(self) -> list[T]: ... def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyDiGraph[S, T]: ... diff --git a/rustworkx/graph.pyi b/rustworkx/graph.pyi index d710f6a1d8..956c39c8c9 100644 --- a/rustworkx/graph.pyi +++ b/rustworkx/graph.pyi @@ -66,6 +66,7 @@ class PyGraph(Generic[S, T]): def degree(self, node: int, /) -> int: ... def edge_index_map(self) -> EdgeIndexMap[T]: ... def edge_indices(self) -> EdgeIndices: ... + def edge_indices_from_endpoints(self, node_a: int, node_b: int) -> EdgeIndices: ... def edge_list(self) -> EdgeList: ... def edges(self) -> list[T]: ... def edge_subgraph(self, edge_list: Sequence[tuple[int, int]], /) -> PyGraph[S, T]: ... diff --git a/src/digraph.rs b/src/digraph.rs index 2d8a97d18a..1afb1ba98b 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -547,6 +547,24 @@ impl PyDiGraph { } } + /// Return a list of indices of all directed edges between specified nodes + /// + /// :returns: A list of all the edge indices connecting the specified start and end node + /// :rtype: EdgeIndices + pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices { + let node_a_index = NodeIndex::new(node_a); + let node_b_index = NodeIndex::new(node_b); + + EdgeIndices { + edges: self + .graph + .edges_directed(node_a_index, petgraph::Direction::Outgoing) + .filter(|edge| edge.target() == node_b_index) + .map(|edge| edge.id().index()) + .collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/src/graph.rs b/src/graph.rs index 45d8902a71..0ad81c25b6 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -422,6 +422,24 @@ impl PyGraph { } } + /// Return a list of indices of all edges between specified nodes + /// + /// :returns: A list of all the edge indices connecting the specified start and end node + /// :rtype: EdgeIndices + pub fn edge_indices_from_endpoints(&self, node_a: usize, node_b: usize) -> EdgeIndices { + let node_a_index = NodeIndex::new(node_a); + let node_b_index = NodeIndex::new(node_b); + + EdgeIndices { + edges: self + .graph + .edges_directed(node_a_index, petgraph::Direction::Outgoing) + .filter(|edge| edge.target() == node_b_index) + .map(|edge| edge.id().index()) + .collect(), + } + } + /// Return a list of all node data. /// /// :returns: A list of all the node data objects in the graph diff --git a/tests/rustworkx_tests/digraph/test_edges.py b/tests/rustworkx_tests/digraph/test_edges.py index ef0af66dd6..9ec06d31a3 100644 --- a/tests/rustworkx_tests/digraph/test_edges.py +++ b/tests/rustworkx_tests/digraph/test_edges.py @@ -370,6 +370,25 @@ def test_weighted_edge_list_empty(self): dag = rustworkx.PyDiGraph() self.assertEqual([], dag.weighted_edge_list()) + def test_edge_indices_from_endpoints(self): + dag = rustworkx.PyDiGraph() + dag.add_nodes_from(list(range(4))) + edge_list = [ + (0, 1, None), + (1, 2, None), + (0, 2, None), + (2, 3, None), + (0, 3, None), + (0, 2, None), + ] + dag.add_edges_from(edge_list) + indices = dag.edge_indices_from_endpoints(0, 0) + self.assertEqual(indices, []) + indices = dag.edge_indices_from_endpoints(0, 1) + self.assertEqual(indices, [0]) + indices = dag.edge_indices_from_endpoints(0, 2) + self.assertEqual(set(indices), {2, 5}) + def test_extend_from_edge_list(self): dag = rustworkx.PyDAG() edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)] diff --git a/tests/rustworkx_tests/graph/test_edges.py b/tests/rustworkx_tests/graph/test_edges.py index 04f24af1a3..628a9788b9 100644 --- a/tests/rustworkx_tests/graph/test_edges.py +++ b/tests/rustworkx_tests/graph/test_edges.py @@ -331,6 +331,26 @@ def test_weighted_edge_list_empty(self): graph = rustworkx.PyGraph() self.assertEqual([], graph.weighted_edge_list()) + def test_edge_indices_from_endpoints(self): + dag = rustworkx.PyGraph() + dag.add_nodes_from(list(range(4))) + edge_list = [ + (0, 1, None), + (1, 2, None), + (0, 2, None), + (2, 3, None), + (0, 3, None), + (0, 2, None), + (2, 0, None), + ] + dag.add_edges_from(edge_list) + indices = dag.edge_indices_from_endpoints(0, 0) + self.assertEqual(indices, []) + indices = dag.edge_indices_from_endpoints(0, 1) + self.assertEqual(set(indices), {0}) + indices = dag.edge_indices_from_endpoints(0, 2) + self.assertEqual(set(indices), {2, 5, 6}) + def test_extend_from_edge_list(self): graph = rustworkx.PyGraph() edge_list = [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)]