Merge pull request #2 from studentofkyoto/adapt
Versioning
calcoloergosum authored Aug 5, 2021
2 parents c567298 + db60b39 commit b98ff90
Showing 8 changed files with 138 additions and 104 deletions.
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

7 changes: 6 additions & 1 deletion Cargo.toml
@@ -15,7 +15,12 @@ numpy = { git = "https://github.com/PyO3/rust-numpy", branch = "main" }
ndarray-linalg = { git = "https://github.com/rust-ndarray/ndarray-linalg", features = ["openblas-static"] }

[dependencies.kbest-lap]
path = "kbest-assignment-enumeration-rust"
git = "https://github.com/studentofkyoto/kbest-assignment-enumeration-rust"
tag = "v0.1.1"

[dependencies.all-lap-rust]
git = "https://github.com/studentofkyoto/all-lap-rust"
tag = "v0.1.1"

[dependencies.pyo3]
version = "0.14.0"
1 change: 0 additions & 1 deletion kbest-assignment-enumeration-rust
Submodule kbest-assignment-enumeration-rust deleted from f7e952
9 changes: 5 additions & 4 deletions kbest_lap/__init__.py
@@ -1,8 +1,9 @@
"""K-th best matching enumeration"""
from ._wrapper import CostMatrix, Edge, Matching, enumerate_kbest
from .rust_ext import Iter, State
from .rust_ext import Node, NodeSet, SortedMatchingIterator

LEFT = True
RIGHT = False

__all__ = [
"CostMatrix", "Edge", "Matching", "enumerate_kbest",
"Iter", "State"
"SortedMatchingIterator", "NodeSet", "Node"
]
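With this change the pure-Python CostMatrix/Edge/Matching/enumerate_kbest wrapper is gone and the pyo3 classes from rust_ext are re-exported at the package top level. A minimal sanity check of the new surface, as a sketch that uses only the names added in this file:

    from kbest_lap import LEFT, RIGHT, Node, NodeSet, SortedMatchingIterator

    # Node takes a boolean group flag and an index within that group.
    node = Node(LEFT, 0)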
33 changes: 0 additions & 33 deletions kbest_lap/_wrapper.py

This file was deleted.

135 changes: 95 additions & 40 deletions src/lib.rs
@@ -1,69 +1,124 @@
// use numpy::{PyReadonlyArray2};
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray2};
use all_lap_rust::bipartite as bp;
use all_lap_rust::contains::Contains;
use kbest_lap as kl;
use numpy::PyReadonlyArray2;
use pyo3::prelude::*;
use pyo3::PyIterProtocol;
use std::ops::DerefMut;

type Matrix = ndarray::Array2<f64>;
#[pyclass]
#[derive(Clone)]
struct Node {
inner: bp::Node,
}

#[pymethods]
impl Node {
#[new]
fn __new__(lr: bool, index: usize) -> Self {
let nodegroup = match lr {
false => bp::NodeGroup::Left,
true => bp::NodeGroup::Right,
};
let inner = bp::Node::new(nodegroup, index);
Self { inner }
}
}

#[pyclass]
struct State {
#[pyo3(get)]
cost_solution: f64,
costs_reduced: Matrix,
#[pyo3(get)]
a_solution: Vec<usize>,
#[derive(Clone)]
struct Matching {
inner: kl::Matching,
}

impl From<kbest_lap::State<f64>> for State {
fn from(inner: kbest_lap::State<f64>) -> Self {
State {
cost_solution: *inner.cost_solution,
costs_reduced: inner.costs_reduced,
a_solution: inner.a_solution.0,
}
#[pymethods]
impl Matching {
#[new]
fn new(v: Vec<Option<usize>>) -> Self {
kl::Matching::new(v).into()
}
fn as_l2r(&self) -> PyResult<Vec<Option<usize>>> {
Ok(self.inner.l2r.clone())
}

fn as_sparse(&self) -> PyResult<Vec<(usize, usize)>> {
Ok(self.inner.iter_pairs().collect())
}
}

impl From<kl::Matching> for Matching {
fn from(val: kl::Matching) -> Self {
Self { inner: val }
}
}

#[pyclass]
struct Iter {
inner: kbest_lap::KBestEnumeration<f64>,
#[derive(Clone)]
struct NodeSet {
inner: bp::NodeSet,
}

#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
#[pymethods]
impl NodeSet {
#[new]
fn __new__(nodes: Vec<Node>, lsize: usize) -> Self {
let hashset = nodes.into_iter().map(|x| x.inner).collect();
Self {
inner: bp::NodeSet::new(hashset, lsize),
}
}
}

impl Contains<bp::Node> for NodeSet {
fn contains_node(&self, item: &bp::Node) -> bool {
self.inner.contains_node(item)
}
}

fn __next__(mut slf: PyRefMut<Self>) -> Option<State> {
let s = slf.inner.next()?;
Some(s.into())
impl Contains<usize> for NodeSet {
fn contains_node(&self, item: &usize) -> bool {
self.inner.contains_node(item)
}
}

#[pyclass]
struct SortedMatchingIterator {
inner: kl::SortedMatchingCalculator,
allowed_start_nodes: NodeSet,
}

#[pymethods]
impl Iter {
impl SortedMatchingIterator {
#[new]
fn new(m: PyReadonlyArray2<f64>) -> Iter {
let arr = m.as_array().to_owned();
Iter {
inner: kbest_lap::KBestEnumeration::<f64>::new(arr).unwrap(),
fn new(m: PyReadonlyArray2<f64>, allowed_start_nodes: NodeSet) -> Self {
let costs = m.as_array().to_owned();
let inner = kl::SortedMatchingCalculator::from_costs(costs);
Self {
inner,
allowed_start_nodes,
}
}
}

#[pyproto]
impl PyIterProtocol for SortedMatchingIterator {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}

fn __next__(mut slf: PyRefMut<Self>) -> Option<Matching> {
let _self = slf.deref_mut();
let m = _self.inner.next_item(&_self.allowed_start_nodes)?;
Some(m.into())
}
}

#[pymodule]
fn rust_ext(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// wrapper of `Iter`
m.add_class::<Iter>()?;
m.add_class::<State>()?;

#[pyfn(m)]
#[pyo3(name = "get_costs_reduced")]
fn costs_reduced<'py>(py: Python<'py>, state: &State) -> &'py PyArrayDyn<f64> {
let dynmat = state.costs_reduced.clone().into_dyn();
dynmat.into_pyarray(py)
}
m.add_class::<Matching>()?;
m.add_class::<Node>()?;
m.add_class::<NodeSet>()?;
m.add_class::<SortedMatchingIterator>()?;

Ok(())
}
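Taken together, the new bindings expose SortedMatchingIterator as a Python iterator that yields Matching objects, restricted by an allowed-start NodeSet. A usage sketch based only on the signatures above (the 3x3 random cost matrix is illustrative):

    import numpy as np
    from kbest_lap import Node, NodeSet, SortedMatchingIterator

    size = 3
    costs = np.random.random((size, size))
    # Build nodes for both groups; False selects the left group, True the right group.
    nodes = [Node(bool(side), i) for side in (0, 1) for i in range(size)]
    allowed = NodeSet(nodes, size)

    # Matchings come out sorted by total cost; as_sparse() gives (row, col) pairs.
    for matching in SortedMatchingIterator(costs, allowed):
        rows, cols = zip(*matching.as_sparse())
        print(matching.as_l2r(), costs[rows, cols].sum())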
38 changes: 20 additions & 18 deletions tests/naive.py
@@ -2,26 +2,28 @@
Naive approach; find by enumeration of perfect matching
"""
import heapq
from typing import Callable, Iterator, List, TypeVar

from kbest_lap import Edge, Matching
from typing import Callable, Iterator, List, Tuple, TypeVar

try:
import networkx as nx
import networkx.algorithms.bipartite as bprt
except ImportError:
print("Networkx not found. Regression test is not available")


T = TypeVar("T")
def enumerate_naive(edges: List[Edge[T]]) -> Iterator[Matching[T]]:

Matching = List[Tuple[T, T]]

MAX_HEAPSIZE = 100

def enumerate_naive(edges: List[Tuple[T, T, float]]) -> Iterator[Matching]:
"""Enumerate best matchings"""
MAX_HEAPSIZE = 100
graph = nx.Graph()
for n1, n2, w in edges:
graph.add_node(n1)
graph.add_node(n2)
graph.add_edge(n1, n2, weight=w)
for node1, node2, weight in edges:
graph.add_node(node1)
graph.add_node(node2)
graph.add_edge(node1, node2, weight=weight)

if not bprt.is_bipartite(graph):
raise RuntimeError("Not bipartite")
@@ -30,16 +32,16 @@ def enumerate_naive(edges: List[Edge[T]]) -> Iterator[Matching[T]]:
lefts = sorted(left_set)
rights = sorted(right_set)
heap = []
for m in _naive(lefts, rights, is_valid=lambda x, y: (x, y) in graph.edges):
score = sum(graph.edges[(n1, n2)]['weight'] for n1, n2 in m)
for matching in _naive(lefts, rights, is_valid=lambda x, y: (x, y) in graph.edges):
score = sum(graph.edges[(n1, n2)]['weight'] for n1, n2 in matching)
if len(heap) <= MAX_HEAPSIZE:
heapq.heappush(heap, (score, m))
heapq.heappush(heap, (score, matching))
else:
heapq.heappushpop(heap, (score, m))
heapq.heappushpop(heap, (score, matching))

while heap != []:
score, m = heapq.heappop(heap)
yield m
score, matching = heapq.heappop(heap)
yield matching


def _naive(lefts: List[T], rights: List[T], is_valid: Callable[[T, T], bool]) -> Iterator[Matching[T]]:
@@ -59,5 +61,5 @@ def _naive(lefts: List[T], rights: List[T], is_valid: Callable[[T, T], bool]) ->
if not is_valid(*pair):
continue

for m in _naive(tail, rights[:i] + rights[i + 1:], is_valid=is_valid):
yield [pair] + m
for matching in _naive(tail, rights[:i] + rights[i + 1:], is_valid=is_valid):
yield [pair] + matching
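As a concrete illustration of the brute-force path, a hypothetical 2x2 instance (labels and weights are made up; assumes tests/naive.py is importable and networkx is installed):

    from naive import enumerate_naive  # tests/naive.py

    # (left, right, weight) triples of a complete 2x2 bipartite graph.
    edges = [("a", "x", 1.0), ("a", "y", 2.0), ("b", "x", 3.0), ("b", "y", 0.5)]
    for matching in enumerate_naive(edges):
        print(matching)  # two perfect matchings, cheapest total weight (1.5) first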
16 changes: 12 additions & 4 deletions tests/test_enumerate.py
@@ -1,24 +1,32 @@
"""A little bit smarter way to enumerate"""
import itertools as it

import numpy as np
import pytest
from kbest_lap import enumerate_kbest

import kbest_lap


@pytest.mark.parametrize('size', [5,] * 10)
def test_linear_sum_assignment(size: int) -> None:
"""Enumerate everything!"""
cost_matrix = np.random.random((size, size))
solutions = list(enumerate_kbest(cost_matrix, yield_iter=None))
nodes = [kbest_lap.Node(bool(b), i) for b in (0, 1) for i in range(size)]
nodes = kbest_lap.NodeSet(nodes, size)
solutions = list(kbest_lap.SortedMatchingIterator(cost_matrix, nodes))

# solution count
assert len(solutions) == np.math.factorial(size)

get_cost = lambda s: cost_matrix[tuple(zip(*s.as_sparse()))].sum()

# solution sort
solutions_sorted = sorted(solutions, key=lambda x: cost_matrix[x].sum())
solution_costs = list(map(get_cost, solutions))
solutions_sorted = [x[1] for x in sorted(enumerate(solutions), key=lambda i_s: solution_costs[i_s[0]])]
assert solutions == solutions_sorted

# regression with brute-force
bf_costs = sorted(cost_matrix[range(size), js].sum() for js in it.permutations(range(size)))
solution_costs = [cost_matrix[s].sum() for s in solutions]
solution_costs = [get_cost(s) for s in solutions]

np.testing.assert_allclose(bf_costs, solution_costs)
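
The get_cost helper relies on NumPy advanced indexing: zipping the (row, col) pairs from as_sparse() yields a tuple of row indices and a tuple of column indices, so the matched entries can be selected and summed in one step. A small self-contained illustration (values are arbitrary):

    import numpy as np

    cost_matrix = np.arange(9.0).reshape(3, 3)
    pairs = [(0, 2), (1, 0), (2, 1)]       # shape of what Matching.as_sparse() returns
    rows, cols = zip(*pairs)               # (0, 1, 2) and (2, 0, 1)
    total = cost_matrix[rows, cols].sum()  # 2.0 + 3.0 + 7.0 = 12.0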
