diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index 9155cdc..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "kbest-assignment-enumeration-rust"]
-	path = kbest-assignment-enumeration-rust
-	url = https://github.com/studentofkyoto/kbest-assignment-enumeration-rust.git
diff --git a/Cargo.toml b/Cargo.toml
index 5381ec9..0611ba5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,12 @@ numpy = { git = "https://github.com/PyO3/rust-numpy", branch = "main" }
 ndarray-linalg = { git = "https://github.com/rust-ndarray/ndarray-linalg", features = ["openblas-static"] }
 
 [dependencies.kbest-lap]
-path = "kbest-assignment-enumeration-rust"
+git = "https://github.com/studentofkyoto/kbest-assignment-enumeration-rust"
+tag = "v0.1.1"
+
+[dependencies.all-lap-rust]
+git = "https://github.com/studentofkyoto/all-lap-rust"
+tag = "v0.1.1"
 
 [dependencies.pyo3]
 version = "0.14.0"
diff --git a/kbest-assignment-enumeration-rust b/kbest-assignment-enumeration-rust
deleted file mode 160000
index f7e9522..0000000
--- a/kbest-assignment-enumeration-rust
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit f7e9522216d15fe38077c107956088876fff2b6d
diff --git a/kbest_lap/__init__.py b/kbest_lap/__init__.py
index 7e5d9ca..a0837c1 100644
--- a/kbest_lap/__init__.py
+++ b/kbest_lap/__init__.py
@@ -1,8 +1,9 @@
 """K-th best matching enumeration"""
-from ._wrapper import CostMatrix, Edge, Matching, enumerate_kbest
-from .rust_ext import Iter, State
+from .rust_ext import Node, NodeSet, SortedMatchingIterator
+
+LEFT = True
+RIGHT = False
 
 __all__ = [
-    "CostMatrix", "Edge", "Matching", "enumerate_kbest",
-    "Iter", "State"
+    "SortedMatchingIterator", "NodeSet", "Node"
 ]
diff --git a/kbest_lap/_wrapper.py b/kbest_lap/_wrapper.py
deleted file mode 100644
index 43f2649..0000000
--- a/kbest_lap/_wrapper.py
+++ /dev/null
@@ -1,33 +0,0 @@
-"""Type aliases for semantic naming"""
-from typing import Callable, Iterator, List, Optional, Tuple, TypeVar
-
-import numpy as np
-
-from .rust_ext import Iter, get_costs_reduced
-
-T = TypeVar("T")
-Matching = List[Tuple[T, T]]
-MatchingIndices = Tuple[np.ndarray, np.ndarray]
-Edge = Tuple[T, T, float]
-
-EnumerationAlgorithm = Callable[[List[Edge[T]]], Iterator[Matching[T]]]
-CostMatrix = np.ndarray
-
-def enumerate_kbest(
-    cost_matrix: CostMatrix, *,
-    yield_iter: Optional[Callable[[List[Tuple[int, int]], MatchingIndices], MatchingIndices]] = None
-) -> Iterator[MatchingIndices]:
-    """
-    When `ignore_same_value` is set to True, yield only one matching for each cost.
-    Otherwise, return all possible matchings, even if some of them has the same value.
-
-    Iterate through triplets of (matching cost, row indices of a solution, column indices of a solution)
-    """
-    for state in Iter(cost_matrix):
-        a_solution = np.arange(len(state.a_solution)), state.a_solution
-        if yield_iter is None:
-            yield a_solution
-        else:
-            costs = get_costs_reduced(state)
-            rows, cols = np.nonzero(np.isclose(costs, 0.))
-            yield from yield_iter(list(zip(rows, cols)), a_solution)
diff --git a/src/lib.rs b/src/lib.rs
index 5495d34..a3afd2e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,69 +1,124 @@
-// use numpy::{PyReadonlyArray2};
-use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray2};
+use all_lap_rust::bipartite as bp;
+use all_lap_rust::contains::Contains;
+use kbest_lap as kl;
+use numpy::PyReadonlyArray2;
 use pyo3::prelude::*;
 use pyo3::PyIterProtocol;
+use std::ops::DerefMut;
 
-type Matrix = ndarray::Array2<f64>;
+#[pyclass]
+#[derive(Clone)]
+struct Node {
+    inner: bp::Node,
+}
+
+#[pymethods]
+impl Node {
+    #[new]
+    fn __new__(lr: bool, index: usize) -> Self {
+        let nodegroup = match lr {
+            false => bp::NodeGroup::Left,
+            true => bp::NodeGroup::Right,
+        };
+        let inner = bp::Node::new(nodegroup, index);
+        Self { inner }
+    }
+}
 
 #[pyclass]
-struct State {
-    #[pyo3(get)]
-    cost_solution: f64,
-    costs_reduced: Matrix,
-    #[pyo3(get)]
-    a_solution: Vec<usize>,
+#[derive(Clone)]
+struct Matching {
+    inner: kl::Matching,
 }
 
-impl From<kbest_lap::State<f64>> for State {
-    fn from(inner: kbest_lap::State<f64>) -> Self {
-        State {
-            cost_solution: *inner.cost_solution,
-            costs_reduced: inner.costs_reduced,
-            a_solution: inner.a_solution.0,
-        }
+#[pymethods]
+impl Matching {
+    #[new]
+    fn new(v: Vec<Option<usize>>) -> Self {
+        kl::Matching::new(v).into()
+    }
+    fn as_l2r(&self) -> PyResult<Vec<Option<usize>>> {
+        Ok(self.inner.l2r.clone())
+    }
+
+    fn as_sparse(&self) -> PyResult<Vec<(usize, usize)>> {
+        Ok(self.inner.iter_pairs().collect())
+    }
+}
+
+impl From<kl::Matching> for Matching {
+    fn from(val: kl::Matching) -> Self {
+        Self { inner: val }
     }
 }
 
 #[pyclass]
-struct Iter {
-    inner: kbest_lap::KBestEnumeration<f64>,
+#[derive(Clone)]
+struct NodeSet {
+    inner: bp::NodeSet,
 }
 
-#[pyproto]
-impl PyIterProtocol for Iter {
-    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
-        slf
+#[pymethods]
+impl NodeSet {
+    #[new]
+    fn __new__(nodes: Vec<Node>, lsize: usize) -> Self {
+        let hashset = nodes.into_iter().map(|x| x.inner).collect();
+        Self {
+            inner: bp::NodeSet::new(hashset, lsize),
+        }
+    }
+}
+
+impl Contains<bp::Node> for NodeSet {
+    fn contains_node(&self, item: &bp::Node) -> bool {
+        self.inner.contains_node(item)
     }
+}
 
-    fn __next__(mut slf: PyRefMut<Self>) -> Option<State> {
-        let s = slf.inner.next()?;
-        Some(s.into())
+impl Contains<usize> for NodeSet {
+    fn contains_node(&self, item: &usize) -> bool {
+        self.inner.contains_node(item)
     }
 }
 
+#[pyclass]
+struct SortedMatchingIterator {
+    inner: kl::SortedMatchingCalculator,
+    allowed_start_nodes: NodeSet,
+}
+
 #[pymethods]
-impl Iter {
+impl SortedMatchingIterator {
     #[new]
-    fn new(m: PyReadonlyArray2<f64>) -> Iter {
-        let arr = m.as_array().to_owned();
-        Iter {
-            inner: kbest_lap::KBestEnumeration::<f64>::new(arr).unwrap(),
+    fn new(m: PyReadonlyArray2<f64>, allowed_start_nodes: NodeSet) -> Self {
+        let costs = m.as_array().to_owned();
+        let inner = kl::SortedMatchingCalculator::from_costs(costs);
+        Self {
+            inner,
+            allowed_start_nodes,
+        }
     }
 }
 
+#[pyproto]
+impl PyIterProtocol for SortedMatchingIterator {
+    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
+        slf
+    }
+
+    fn __next__(mut slf: PyRefMut<Self>) -> Option<Matching> {
+        let _self = slf.deref_mut();
+        let m = _self.inner.next_item(&_self.allowed_start_nodes)?;
+        Some(m.into())
+    }
+}
+
 #[pymodule]
 fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
-    // wrapper of `Iter`
-    m.add_class::<Iter>()?;
-    m.add_class::<State>()?;
-
-    #[pyfn(m)]
-    #[pyo3(name = "get_costs_reduced")]
-    fn costs_reduced<'py>(py: Python<'py>, state: &State) -> &'py PyArrayDyn<f64> {
-        let dynmat = state.costs_reduced.clone().into_dyn();
-        dynmat.into_pyarray(py)
-    }
+    m.add_class::<Node>()?;
+    m.add_class::<NodeSet>()?;
+    m.add_class::<Matching>()?;
+    m.add_class::<SortedMatchingIterator>()?;
     Ok(())
 }
diff --git a/tests/naive.py b/tests/naive.py
index 53a4fd3..03c29a3 100644
--- a/tests/naive.py
+++ b/tests/naive.py
@@ -2,9 +2,7 @@
 Naive approach; find by enumeration of perfect matching
 """
 import heapq
-from typing import Callable, Iterator, List, TypeVar
-
-from kbest_lap import Edge, Matching
+from typing import Callable, Iterator, List, Tuple, TypeVar
 
 try:
     import networkx as nx
@@ -12,16 +10,20 @@
 except ImportError:
     print("Networkx not found. Regression test is not available")
 
-
 T = TypeVar("T")
-def enumerate_naive(edges: List[Edge[T]]) -> Iterator[Matching[T]]:
+
+Matching = List[Tuple[T, T]]
+
+MAX_HEAPSIZE = 100
+
+def enumerate_naive(edges: List[Tuple[T, T, float]]) -> Iterator[Matching
+]:
     """Enumerate best matchings"""
-    MAX_HEAPSIZE = 100
     graph = nx.Graph()
-    for n1, n2, w in edges:
-        graph.add_node(n1)
-        graph.add_node(n2)
-        graph.add_edge(n1, n2, weight=w)
+    for node1, node2, weight in edges:
+        graph.add_node(node1)
+        graph.add_node(node2)
+        graph.add_edge(node1, node2, weight=weight)
 
     if not bprt.is_bipartite(graph):
         raise RuntimeError("Not bipartite")
@@ -30,16 +32,16 @@ def enumerate_naive(edges: List[Edge[T]]) -> Iterator[Matching[T]]:
     lefts = sorted(left_set)
     rights = sorted(right_set)
    heap = []
-    for m in _naive(lefts, rights, is_valid=lambda x, y: (x, y) in graph.edges):
-        score = sum(graph.edges[(n1, n2)]['weight'] for n1, n2 in m)
+    for matching in _naive(lefts, rights, is_valid=lambda x, y: (x, y) in graph.edges):
+        score = sum(graph.edges[(n1, n2)]['weight'] for n1, n2 in matching)
         if len(heap) <= MAX_HEAPSIZE:
-            heapq.heappush(heap, (score, m))
+            heapq.heappush(heap, (score, matching))
         else:
-            heapq.heappushpop(heap, (score, m))
+            heapq.heappushpop(heap, (score, matching))
 
     while heap != []:
-        score, m = heapq.heappop(heap)
-        yield m
+        score, matching = heapq.heappop(heap)
+        yield matching
 
 
 def _naive(lefts: List[T], rights: List[T], is_valid: Callable[[T, T], bool]) -> Iterator[Matching[T]]:
@@ -59,5 +61,5 @@ def _naive(lefts: List[T], rights: List[T], is_valid: Callable[[T, T], bool]) ->
         if not is_valid(*pair):
             continue
 
-        for m in _naive(tail, rights[:i] + rights[i + 1:], is_valid=is_valid):
-            yield [pair] + m
+        for matching in _naive(tail, rights[:i] + rights[i + 1:], is_valid=is_valid):
+            yield [pair] + matching
diff --git a/tests/test_enumerate.py b/tests/test_enumerate.py
index c843abb..b116f15 100644
--- a/tests/test_enumerate.py
+++ b/tests/test_enumerate.py
@@ -1,24 +1,32 @@
+"""A little bit smarter way to enumerate"""
 import itertools as it
 
 import numpy as np
 import pytest
 
-from kbest_lap import enumerate_kbest
+
+import kbest_lap
 
 
 @pytest.mark.parametrize('size', [5,] * 10)
 def test_linear_sum_assignment(size: int) -> None:
+    """Enumerate everything!"""
     cost_matrix = np.random.random((size, size))
-    solutions = list(enumerate_kbest(cost_matrix, yield_iter=None))
+    nodes = [kbest_lap.Node(bool(b), i) for b in (0, 1) for i in range(size)]
+    nodes = kbest_lap.NodeSet(nodes, size)
+    solutions = list(kbest_lap.SortedMatchingIterator(cost_matrix, nodes))
 
     # solution count
     assert len(solutions) == np.math.factorial(size)
 
+    get_cost = lambda s: cost_matrix[tuple(zip(*s.as_sparse()))].sum()
+
     # solution sort
-    solutions_sorted = sorted(solutions, key=lambda x: cost_matrix[x].sum())
+    solution_costs = list(map(get_cost, solutions))
+    solutions_sorted = [x[1] for x in sorted(enumerate(solutions), key=lambda i_s: solution_costs[i_s[0]])]
     assert solutions == solutions_sorted
 
     # regression with brute-force
     bf_costs = sorted(cost_matrix[range(size), js].sum() for js in it.permutations(range(size)))
-    solution_costs = [cost_matrix[s].sum() for s in solutions]
+    solution_costs = [get_cost(s) for s in solutions]
     np.testing.assert_allclose(bf_costs, solution_costs)
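For quick reference, below is a minimal usage sketch of the Python API exposed after this change (`Node`, `NodeSet`, `SortedMatchingIterator`, and `Matching.as_sparse()`), adapted from the updated `tests/test_enumerate.py`. The matrix size and the printing loop are illustrative only and are not part of this patch:

```python
import numpy as np

import kbest_lap

size = 3
cost_matrix = np.random.random((size, size))

# Allow every node as a start node. Per the Rust constructor,
# Node(False, i) is a left-side node and Node(True, i) is a right-side node.
nodes = [kbest_lap.Node(bool(side), i) for side in (0, 1) for i in range(size)]
node_set = kbest_lap.NodeSet(nodes, size)

# Matchings are yielded in order of increasing total assignment cost.
for matching in kbest_lap.SortedMatchingIterator(cost_matrix, node_set):
    pairs = matching.as_sparse()  # list of (row, col) index pairs
    cost = cost_matrix[tuple(zip(*pairs))].sum()
    print(cost, pairs)
```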