Commit
[models/kernel_regression](feat) Implement MemoryManager using NNDescent algorithm
kzajac97 committed Dec 10, 2023
1 parent beb159d commit 2747389
Showing 3 changed files with 107 additions and 10 deletions.
1 change: 1 addition & 0 deletions pydentification/models/kernel_regression/extra-requirements.txt
@@ -0,0 +1 @@
+pynndescent==0.5.11
72 changes: 67 additions & 5 deletions pydentification/models/kernel_regression/memory.py
@@ -1,16 +1,27 @@
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import Callable
+from typing import Callable, Any
 
 import torch
 from torch import Tensor
 
+try:
+    from pynndescent import NNDescent
+except ImportError as ex:
+    message = (
+        "Missing optional dependency, to install all optionals from experiment module run:\n"
+        "`pip install -r pydentification/models/kernel_regression/extra-requirements.txt`"
+    )
 
-def needs_tensor_with_dims(*expected_dims: tuple[int | None]) -> Callable:
+    raise ImportError(message) from ex
+
+
+def needs_tensor_with_dims(*expected_dims: tuple[int | None, ...]) -> Callable:
     def decorator(func: Callable) -> Callable:
         @wraps(func)
-        def wrapper(tensor: Tensor, *args, **kwargs):
+        def wrapper(self: Any, tensor: Tensor, *args, **kwargs):
             if not isinstance(tensor, Tensor):
-                raise TypeError("Input should be a PyTorch Tensor")
+                raise TypeError(f"Input should be a PyTorch Tensor, got {type(tensor)}")
 
             if tensor.dim() != len(expected_dims):
                 raise ValueError(f"Expected tensor with {len(expected_dims)} dimensions, got {tensor.dim()}")
@@ -19,7 +30,7 @@ def wrapper(tensor: Tensor, *args, **kwargs):
                 if dim is not None and tensor.size(index) != dim:
                     raise ValueError(f"Expected size {dim} at dimension {index}, got {tensor.size(index)}")
 
-            return func(tensor, *args, **kwargs)
+            return func(self, tensor, *args, **kwargs)
 
         return wrapper
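For context, a minimal sketch of how the updated decorator is meant to be applied, with `self` now passed through to the wrapped method; the `DummyManager` class and the `(None, 1)` shape specification below are illustrative assumptions, not code from this commit.

```python
import torch
from torch import Tensor

from pydentification.models.kernel_regression.memory import needs_tensor_with_dims


class DummyManager:  # hypothetical stand-in for a MemoryManager subclass
    @needs_tensor_with_dims(None, 1)  # expect a 2D tensor: any batch size, exactly one feature column
    def query(self, points: Tensor) -> Tensor:
        return points


manager = DummyManager()
manager.query(torch.rand(8, 1))  # passes: 2 dimensions, size 1 at dimension 1

try:
    manager.query(torch.rand(8, 2))
except ValueError as err:
    print(err)  # Expected size 1 at dimension 1, got 2
```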

@@ -58,3 +69,54 @@ def query_radius(self, points: Tensor, r: float) -> [tuple[Tensor, Tensor]]:
         :param r: radius of the neighbourhood
         """
         ...
+
+
+class NNDescentMemoryManager(MemoryManager):
+    def __init__(self, memory: Tensor, targets: Tensor | tuple[Tensor, ...], **nn_descent_params):
+        """
+        :param memory: tensor of indexed data points to search for nearest neighbors
+        :param targets: tensor of target values corresponding to the memory points, can be any number of tensors
+        :param nn_descent_params: parameters for NNDescent algorithm,
+                                  see: https://pynndescent.readthedocs.io/en/latest/api.html
+        """
+        super().__init__()
+
+        self.memory = memory
+        self.targets = targets if isinstance(targets, tuple) else (targets,)  # store targets as tuple
+
+        self.index = None  # index build is deferred until the first query, since it takes a significant amount of time
+        self.nn_descent_params = nn_descent_params
+
+    def _build_index(self) -> None:
+        """Build index for nearest neighbors search over memory using `NNDescent`"""
+        self.index = NNDescent(self.memory, **self.nn_descent_params)
+        self.index.prepare()
+
+    @property
+    def neighbor_graph(self) -> tuple[Tensor, Tensor]:
+        """Returns the neighbor graph of the memory points"""
+        if self.index is None:
+            self._build_index()
+
+        return self.index.neighbor_graph
+
+    def query_nearest(self, points: Tensor, k: int, epsilon: float = 0.1) -> [tuple[Tensor, ...]]:
+        """
+        Query for K-nearest neighbors in memory given input points.
+        :param points: input points for which to find nearest neighbours
+        :param k: number of nearest neighbours to return
+        :param epsilon: search parameter for NNDescent, see: https://pynndescent.readthedocs.io/en/latest/api.html
+        """
+        if self.index is None:
+            self._build_index()
+
+        indexed, _ = self.index.query(points, k=k, epsilon=epsilon)
+        # memory manager returns flat memory for all query points
+        # duplicates are removed and the dimensionality is reduced to 1
+        indexed = torch.unique(torch.from_numpy(indexed.flatten()))
+        # return found nearest points from memory and the corresponding rows from all target tensors
+        return self.memory[indexed, :], *(target[indexed, :] for target in self.targets)
+
+    def query_radius(self, points: Tensor, r: float) -> [tuple[Tensor, Tensor]]:
+        raise NotImplementedError("Radius query is not implemented for NNDescentMemoryManager!")
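Below the new class, a minimal usage sketch (not from this commit); the random data, shapes, and parameter values are illustrative, and `pynndescent` needs to be installed from `extra-requirements.txt` first.

```python
import torch

from pydentification.models.kernel_regression.memory import NNDescentMemoryManager

# toy memory of 1000 two-dimensional points with a single matching target tensor
memory = torch.rand(1000, 2)
targets = torch.rand(1000, 1)

manager = NNDescentMemoryManager(memory, targets, metric="euclidean")

# the NNDescent index is built lazily on the first query and reused afterwards
nearest_memory, nearest_targets = manager.query_nearest(torch.rand(5, 2), k=10, epsilon=0.1)

# query_nearest returns a flat, de-duplicated set of memory points and their targets,
# so at most 5 * 10 = 50 rows come back here
print(nearest_memory.shape, nearest_targets.shape)
```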
44 changes: 39 additions & 5 deletions tests/test_models/test_kernel_regression/test_memory.py
@@ -2,7 +2,7 @@
 import torch
 from torch import Tensor
 
-from pydentification.models.kernel_regression.memory import needs_tensor_with_dims
+from pydentification.models.kernel_regression.memory import needs_tensor_with_dims, NNDescentMemoryManager
 
 
 @pytest.mark.parametrize(
@@ -18,10 +18,10 @@
 )
 def test_needs_tensor_with_dims_ok(tensor: Tensor, dims: tuple[int | None]):
     @needs_tensor_with_dims(*dims)
-    def func(_: Tensor) -> bool:
+    def func(_, t: Tensor) -> bool:  # first argument is self of MemoryManager, which is not used
         return True  # dummy function to test decorator
 
-    assert func(tensor)  # check if error was not raised
+    assert func(None, tensor)  # pass None as self and check if error was not raised
 
 
 @pytest.mark.parametrize(
@@ -36,8 +36,42 @@ def func(_: Tensor) -> bool:
 )
 def test_needs_tensor_with_dims_not_ok(tensor: Tensor, dims: tuple[int | None]):
     @needs_tensor_with_dims(*dims)
-    def func(_: Tensor) -> bool:
+    def func(_, t: Tensor) -> bool:  # first argument is self of MemoryManager, which is not used
         return True  # dummy function to test decorator
 
     with pytest.raises(ValueError):
-        func(tensor)
+        func(None, tensor)  # pass None as self
+
+
+@pytest.fixture(scope="module")
+def linspace_nn_descent_memory_manager():
+    memory = torch.linspace(0, 1, 101).unsqueeze(-1)  # 101 points in [0, 1] range spaced by 0.01 and shape [101, 1]
+    targets = 2 * memory  # dummy targets
+    return NNDescentMemoryManager(memory, targets, metric="euclidean")
+
+
+@pytest.mark.parametrize(
+    "points, k, expected",
+    (
+        # # query for single point
+        # (torch.tensor([[0.5]]), 1, torch.tensor([[0.5]])),
+        # # query for multiple points
+        # (torch.tensor([[0.5]]), 3, torch.tensor([[0.49], [0.5], [0.51]])),
+        # # query on the edge of range
+        # (torch.tensor([[0.0]]), 1, torch.tensor([[0.0]])),
+        # (torch.tensor([[1.0]]), 1, torch.tensor([[1.0]])),
+        # # query for multiple points on the edge of range
+        # (torch.tensor([[0.0]]), 5, torch.tensor([[0.0], [0.01], [0.02], [0.03], [0.04]])),
+        # (torch.tensor([[1.0]]), 5, torch.tensor([[0.96], [0.97], [0.98], [0.99], [1.0]])),
+        # # batch query
+        # (torch.tensor([[0.5], [0.25]]), 1, torch.tensor([[0.25], [0.5]])),
+        # # batch query with multiple points
+        # (torch.tensor([[0.5], [0.25]]), 3, torch.tensor([[0.24], [0.25], [0.26], [0.49], [0.5], [0.51]])),
+        # batch query with multiple points without exact matches to memory
+        (torch.tensor([[0.501], [0.2501]]), 3, torch.tensor([[0.24], [0.25], [0.26], [0.49], [0.5], [0.51]])),
+    )
+)
+def test_nn_descent_memory_manager(points: Tensor, k: int, expected: Tensor, linspace_nn_descent_memory_manager):
+    # query with high epsilon to get certain results
+    memory, _ = linspace_nn_descent_memory_manager.query_nearest(points, k, epsilon=1.0)  # ignore targets
+    torch.testing.assert_close(memory, expected)
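As a sanity check on the expected values in the only active parametrize case, an exact brute-force nearest-neighbour search over the same linspace memory yields the same six points; this is a sketch for illustration, not part of the test file.

```python
import torch

memory = torch.linspace(0, 1, 101).unsqueeze(-1)  # same memory as the fixture above
points = torch.tensor([[0.501], [0.2501]])        # the active parametrize case

distances = torch.cdist(points, memory)           # exact pairwise euclidean distances, shape [2, 101]
_, indices = distances.topk(k=3, largest=False)   # 3 nearest memory indices per query point
indices = torch.unique(indices.flatten())         # mimic the manager's flat, de-duplicated output

print(memory[indices, :].squeeze(-1))
# approximately: tensor([0.2400, 0.2500, 0.2600, 0.4900, 0.5000, 0.5100])
```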
