Skip to content

Commit

Permalink
[models/kernel_regression](feat) Add abstract memory manager
Browse files Browse the repository at this point in the history
  • Loading branch information
kzajac97 committed Dec 10, 2023
1 parent aba77ba commit dc1794b
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
60 changes: 60 additions & 0 deletions pydentification/models/kernel_regression/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from functools import wraps
from typing import Callable

from torch import Tensor


def needs_tensor_with_dims(*expected_dims: tuple[int | None]) -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(tensor: Tensor, *args, **kwargs):
if not isinstance(tensor, Tensor):
raise TypeError("Input should be a PyTorch Tensor")

if tensor.dim() != len(expected_dims):
raise ValueError(f"Expected tensor with {len(expected_dims)} dimensions, got {tensor.dim()}")

for index, dim in enumerate(expected_dims):
if dim is not None and tensor.size(index) != dim:
raise ValueError(f"Expected size {dim} at dimension {index}, got {tensor.size(index)}")

return func(tensor, *args, **kwargs)

return wrapper

return decorator


class MemoryManager(ABC):
"""
Interface for memory manager for non-parametric models (such as kernel regression).
This class is used to enable non-parametric models to be used with large datasets, since they need to store all
training data. Memory managers can be used to apply selection or search algorithms to reduce the number of samples
return for each prediction, for example for kernel model returning only samples near the input points in given
prediction batch.
"""

def __init__(self):
...

@abstractmethod
def query_nearest(self, points: Tensor, k: int) -> [tuple[Tensor, Tensor]]:
"""
Query for K-nearest neighbors in memory given input points.
:param points: input points for which to find nearest neighbours
:param k: number of nearest neighbours to return
"""
...

@abstractmethod
def query_radius(self, points: Tensor, r: float) -> [tuple[Tensor, Tensor]]:
"""
Query for all points in memory within given radius of input points.
:param points: input points for which to find neighbours
:param r: radius of the neighbourhood
"""
...
43 changes: 43 additions & 0 deletions tests/test_models/test_kernel_regression/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import torch
from torch import Tensor

from pydentification.models.kernel_regression.memory import needs_tensor_with_dims


@pytest.mark.parametrize(
"tensor, dims",
[
(torch.zeros([2, 2]), (2, 2)),
(torch.zeros([5, 5]), (5, None)),
(torch.zeros([1, 1]), (1, None)),
(torch.zeros([5, 1]), (None, 1)),
(torch.zeros([2, 2]), (None, None)),
(torch.zeros([2, 2, 2]), (2, 2, 2)),
],
)
def test_needs_tensor_with_dims_ok(tensor: Tensor, dims: tuple[int | None]):
@needs_tensor_with_dims(*dims)
def func(_: Tensor) -> bool:
return True # dummy function to test decorator

assert func(tensor) # check if error was not raised


@pytest.mark.parametrize(
"tensor, dims",
[
(torch.zeros([10, 1, 1]), (None, None)),
(torch.zeros([1, 10, 1]), (None, None)),
(torch.zeros([10, 1]), (10, 2)),
(torch.zeros([10, 1]), (2, 10)),
(torch.zeros([10, 1]), (None, 2)),
],
)
def test_needs_tensor_with_dims_not_ok(tensor: Tensor, dims: tuple[int | None]):
@needs_tensor_with_dims(*dims)
def func(_: Tensor) -> bool:
return True # dummy function to test decorator

with pytest.raises(ValueError):
func(tensor)

0 comments on commit dc1794b

Please sign in to comment.