-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[models/kernel_regression](feat) Add abstract memory manager
- Loading branch information
Showing
2 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from abc import ABC, abstractmethod | ||
from functools import wraps | ||
from typing import Callable | ||
|
||
from torch import Tensor | ||
|
||
|
||
def needs_tensor_with_dims(*expected_dims: tuple[int | None]) -> Callable: | ||
def decorator(func: Callable) -> Callable: | ||
@wraps(func) | ||
def wrapper(tensor: Tensor, *args, **kwargs): | ||
if not isinstance(tensor, Tensor): | ||
raise TypeError("Input should be a PyTorch Tensor") | ||
|
||
if tensor.dim() != len(expected_dims): | ||
raise ValueError(f"Expected tensor with {len(expected_dims)} dimensions, got {tensor.dim()}") | ||
|
||
for index, dim in enumerate(expected_dims): | ||
if dim is not None and tensor.size(index) != dim: | ||
raise ValueError(f"Expected size {dim} at dimension {index}, got {tensor.size(index)}") | ||
|
||
return func(tensor, *args, **kwargs) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
class MemoryManager(ABC): | ||
""" | ||
Interface for memory manager for non-parametric models (such as kernel regression). | ||
This class is used to enable non-parametric models to be used with large datasets, since they need to store all | ||
training data. Memory managers can be used to apply selection or search algorithms to reduce the number of samples | ||
return for each prediction, for example for kernel model returning only samples near the input points in given | ||
prediction batch. | ||
""" | ||
|
||
def __init__(self): | ||
... | ||
|
||
@abstractmethod | ||
def query_nearest(self, points: Tensor, k: int) -> [tuple[Tensor, Tensor]]: | ||
""" | ||
Query for K-nearest neighbors in memory given input points. | ||
:param points: input points for which to find nearest neighbours | ||
:param k: number of nearest neighbours to return | ||
""" | ||
... | ||
|
||
@abstractmethod | ||
def query_radius(self, points: Tensor, r: float) -> [tuple[Tensor, Tensor]]: | ||
""" | ||
Query for all points in memory within given radius of input points. | ||
:param points: input points for which to find neighbours | ||
:param r: radius of the neighbourhood | ||
""" | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import pytest | ||
import torch | ||
from torch import Tensor | ||
|
||
from pydentification.models.kernel_regression.memory import needs_tensor_with_dims | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"tensor, dims", | ||
[ | ||
(torch.zeros([2, 2]), (2, 2)), | ||
(torch.zeros([5, 5]), (5, None)), | ||
(torch.zeros([1, 1]), (1, None)), | ||
(torch.zeros([5, 1]), (None, 1)), | ||
(torch.zeros([2, 2]), (None, None)), | ||
(torch.zeros([2, 2, 2]), (2, 2, 2)), | ||
], | ||
) | ||
def test_needs_tensor_with_dims_ok(tensor: Tensor, dims: tuple[int | None]): | ||
@needs_tensor_with_dims(*dims) | ||
def func(_: Tensor) -> bool: | ||
return True # dummy function to test decorator | ||
|
||
assert func(tensor) # check if error was not raised | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"tensor, dims", | ||
[ | ||
(torch.zeros([10, 1, 1]), (None, None)), | ||
(torch.zeros([1, 10, 1]), (None, None)), | ||
(torch.zeros([10, 1]), (10, 2)), | ||
(torch.zeros([10, 1]), (2, 10)), | ||
(torch.zeros([10, 1]), (None, 2)), | ||
], | ||
) | ||
def test_needs_tensor_with_dims_not_ok(tensor: Tensor, dims: tuple[int | None]): | ||
@needs_tensor_with_dims(*dims) | ||
def func(_: Tensor) -> bool: | ||
return True # dummy function to test decorator | ||
|
||
with pytest.raises(ValueError): | ||
func(tensor) |