Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds __repr__ for our custom classes #286

Merged
merged 38 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
6cceaf9
adds basis repr
billbrod Jan 9, 2025
bdcf81b
adds repr for glm, observationmodel, regularizer
billbrod Jan 9, 2025
9e5c889
update for function kwarg name
billbrod Jan 10, 2025
7dfe801
improved format_repr
BalzaniEdoardo Jan 10, 2025
87c3c1b
Merge branch 'development' into repr
BalzaniEdoardo Jan 10, 2025
9a12c1f
added tests, numpy docstrings and repr Falsey values not None
BalzaniEdoardo Jan 10, 2025
1037ca0
add test for repr atomic
BalzaniEdoardo Jan 10, 2025
22c7662
add test composite
BalzaniEdoardo Jan 10, 2025
d56a02d
add test transformer repr
BalzaniEdoardo Jan 10, 2025
e19aae1
added edge case
BalzaniEdoardo Jan 10, 2025
f661dd2
repr string fix, added test for glm
BalzaniEdoardo Jan 10, 2025
236b57f
added test for pop glm
BalzaniEdoardo Jan 10, 2025
f2b8a75
test obs models repr
BalzaniEdoardo Jan 10, 2025
ce5e6fb
test regularizer repr
BalzaniEdoardo Jan 10, 2025
24911a3
linted
BalzaniEdoardo Jan 10, 2025
6234ad9
fixed doctests
BalzaniEdoardo Jan 10, 2025
2bc7622
fixed tests
BalzaniEdoardo Jan 10, 2025
cd1f654
Merge branch 'repr' of github.com:flatironinstitute/generalized-linea…
billbrod Jan 13, 2025
470b77b
adds label to basis repr
billbrod Jan 13, 2025
a06d271
improved tests
BalzaniEdoardo Jan 14, 2025
a14beed
Merge branch 'repr' of github.com:flatironinstitute/nemos into repr
BalzaniEdoardo Jan 14, 2025
6392abd
merged split by feature fixes
BalzaniEdoardo Jan 14, 2025
9408683
modified label repr
BalzaniEdoardo Jan 14, 2025
f8771dd
Update tests/test_observation_models.py
BalzaniEdoardo Jan 15, 2025
b83f0dd
Update src/nemos/glm.py
BalzaniEdoardo Jan 15, 2025
8befa06
Update src/nemos/utils.py
BalzaniEdoardo Jan 15, 2025
d831904
Update tests/test_observation_models.py
BalzaniEdoardo Jan 15, 2025
5a2c838
added multiline
BalzaniEdoardo Jan 15, 2025
ba41424
Merge branch 'repr' of github.com:flatironinstitute/nemos into repr
BalzaniEdoardo Jan 15, 2025
c5e4aa3
added multiline
BalzaniEdoardo Jan 15, 2025
74d9d07
linted
BalzaniEdoardo Jan 15, 2025
fa0cb5e
text fixes
BalzaniEdoardo Jan 15, 2025
90c3cf7
Update src/nemos/basis/_basis_mixin.py
BalzaniEdoardo Jan 15, 2025
2a88d14
Start improving repr
BalzaniEdoardo Jan 15, 2025
d1832ea
Merge branch 'repr' of github.com:flatironinstitute/nemos into repr
BalzaniEdoardo Jan 15, 2025
14c0ba8
added line wrapping
BalzaniEdoardo Jan 15, 2025
505bd4d
simplified test
BalzaniEdoardo Jan 15, 2025
67da7b4
use generator
BalzaniEdoardo Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..base_class import Base
from ..type_casting import support_pynapple
from ..typing import FeatureMatrix
from ..utils import row_wise_kron
from ..utils import format_repr, row_wise_kron
from ..validation import check_fraction_valid_samples
from ._basis_mixin import BasisTransformerMixin, CompositeBasisMixin

Expand Down Expand Up @@ -521,6 +521,9 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis:
result = result * self
return result

def __repr__(self):
return format_repr(self)

def _get_feature_slicing(
self,
n_inputs: Optional[tuple] = None,
Expand Down Expand Up @@ -736,11 +739,23 @@ class AdditiveBasis(CompositeBasisMixin, Basis):
>>> basis_1 = nmo.basis.BSplineEval(10)
>>> basis_2 = nmo.basis.RaisedCosineLinearEval(15)
>>> additive_basis = basis_1 + basis_2

>>> additive_basis
AdditiveBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2=RaisedCosineLinearEval(n_basis_funcs=15, width=2.0),
)
>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineLogEval(100)
>>> additive_basis_2 = additive_basis + basis_3
>>> additive_basis_2
AdditiveBasis(
basis1=AdditiveBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2=RaisedCosineLinearEval(n_basis_funcs=15, width=2.0),
),
basis2=RaisedCosineLogEval(n_basis_funcs=100, width=2.0, time_scaling=50.0, enforce_decay_to_zero=True),
)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1158,11 +1173,23 @@ class MultiplicativeBasis(CompositeBasisMixin, Basis):
>>> basis_1 = nmo.basis.BSplineEval(10)
>>> basis_2 = nmo.basis.RaisedCosineLinearEval(15)
>>> multiplicative_basis = basis_1 * basis_2

>>> multiplicative_basis
MultiplicativeBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2=RaisedCosineLinearEval(n_basis_funcs=15, width=2.0),
)
>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineLogEval(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
>>> multiplicative_basis_2
MultiplicativeBasis(
basis1=MultiplicativeBasis(
basis1=BSplineEval(n_basis_funcs=10, order=4),
basis2=RaisedCosineLinearEval(n_basis_funcs=15, width=2.0),
),
basis2=RaisedCosineLogEval(n_basis_funcs=100, width=2.0, time_scaling=50.0, enforce_decay_to_zero=True),
)
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down
23 changes: 23 additions & 0 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pynapple import Tsd, TsdFrame, TsdTensor

from ..convolve import create_convolutional_predictor
from ..utils import _get_terminal_size
from ._transformer_basis import TransformerBasis

if TYPE_CHECKING:
Expand Down Expand Up @@ -661,3 +662,25 @@ def set_input_shape(self, *xi: int | tuple[int, ...] | NDArray) -> Basis:
)._input_shape_product,
)
return self

def __repr__(self, n=0):
BalzaniEdoardo marked this conversation as resolved.
Show resolved Hide resolved
_, rows = _get_terminal_size()
rows = rows // 4
# number of nested composite bases
n += 1
tab = " "
try:
basis1 = self.basis1.__repr__(n=n)
except TypeError:
basis1 = self.basis1
try:
basis2 = self.basis2.__repr__(n=n)
except TypeError:
basis2 = self.basis2
if n < rows:
rep = f"{self.__class__.__name__}(\n{n*tab}basis1={basis1},\n{n*tab}basis2={basis2},\n{(n-1)*tab})"
elif n == rows:
rep = f"{self.__class__.__name__}(\n{n*tab}...\n{(n-1)*tab})"
else:
rep = None
return rep
3 changes: 3 additions & 0 deletions src/nemos/basis/_transformer_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ def __sklearn_clone__(self) -> TransformerBasis:
cloned_obj = TransformerBasis(self.basis.__sklearn_clone__())
return cloned_obj

def __repr__(self):
return f"Transformer({self.basis})"

def set_params(self, **parameters) -> TransformerBasis:
"""
Set TransformerBasis parameters.
Expand Down
26 changes: 25 additions & 1 deletion src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class BSplineEval(EvalBasisMixin, BSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> bspline_basis = BSplineEval(n_basis_funcs, order=order)
>>> bspline_basis
BSplineEval(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -226,6 +228,8 @@ class BSplineConv(ConvBasisMixin, BSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> bspline_basis = BSplineConv(n_basis_funcs, order=order, window_size=10)
>>> bspline_basis
BSplineConv(n_basis_funcs=5, window_size=10, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> features = bspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -365,6 +369,8 @@ class CyclicBSplineEval(EvalBasisMixin, CyclicBSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> cyclic_bspline_basis = CyclicBSplineEval(n_basis_funcs, order=order)
>>> cyclic_bspline_basis
CyclicBSplineEval(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> features = cyclic_bspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -507,6 +513,8 @@ class CyclicBSplineConv(ConvBasisMixin, CyclicBSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> cyclic_bspline_basis = CyclicBSplineConv(n_basis_funcs, order=order, window_size=10)
>>> cyclic_bspline_basis
CyclicBSplineConv(n_basis_funcs=5, window_size=10, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> features = cyclic_bspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -670,6 +678,8 @@ class MSplineEval(EvalBasisMixin, MSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> mspline_basis = MSplineEval(n_basis_funcs, order=order)
>>> mspline_basis
MSplineEval(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> features = mspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -836,6 +846,8 @@ class MSplineConv(ConvBasisMixin, MSplineBasis):
>>> n_basis_funcs = 5
>>> order = 3
>>> mspline_basis = MSplineConv(n_basis_funcs, order=order, window_size=10)
>>> mspline_basis
MSplineConv(n_basis_funcs=5, window_size=10, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> features = mspline_basis.compute_features(sample_points)
"""
Expand Down Expand Up @@ -982,6 +994,8 @@ class RaisedCosineLinearEval(EvalBasisMixin, RaisedCosineBasisLinear):
>>> from nemos.basis import RaisedCosineLinearEval
>>> n_basis_funcs = 5
>>> raised_cosine_basis = RaisedCosineLinearEval(n_basis_funcs)
>>> raised_cosine_basis
RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)
>>> sample_points = np.random.randn(100)
>>> # convolve the basis
>>> features = raised_cosine_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1125,6 +1139,8 @@ class RaisedCosineLinearConv(ConvBasisMixin, RaisedCosineBasisLinear):
>>> from nemos.basis import RaisedCosineLinearConv
>>> n_basis_funcs = 5
>>> raised_cosine_basis = RaisedCosineLinearConv(n_basis_funcs, window_size=10)
>>> raised_cosine_basis
RaisedCosineLinearConv(n_basis_funcs=5, window_size=10, width=2.0)
>>> sample_points = np.random.randn(100)
>>> # convolve the basis
>>> features = raised_cosine_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1269,6 +1285,8 @@ class RaisedCosineLogEval(EvalBasisMixin, RaisedCosineBasisLog):
>>> from nemos.basis import RaisedCosineLogEval
>>> n_basis_funcs = 5
>>> raised_cosine_basis = RaisedCosineLogEval(n_basis_funcs)
>>> raised_cosine_basis
RaisedCosineLogEval(n_basis_funcs=5, width=2.0, time_scaling=50.0, enforce_decay_to_zero=True)
>>> sample_points = np.random.randn(100)
>>> # convolve the basis
>>> features = raised_cosine_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1424,6 +1442,8 @@ class RaisedCosineLogConv(ConvBasisMixin, RaisedCosineBasisLog):
>>> from nemos.basis import RaisedCosineLogConv
>>> n_basis_funcs = 5
>>> raised_cosine_basis = RaisedCosineLogConv(n_basis_funcs, window_size=10)
>>> raised_cosine_basis
RaisedCosineLogConv(n_basis_funcs=5, window_size=10, width=2.0, time_scaling=50.0, enforce_decay_to_zero=True)
>>> sample_points = np.random.randn(100)
>>> # convolve the basis
>>> features = raised_cosine_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1561,6 +1581,8 @@ class OrthExponentialEval(EvalBasisMixin, OrthExponentialBasis):
>>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates
>>> window_size = 10
>>> ortho_basis = OrthExponentialEval(n_basis_funcs, decay_rates)
>>> ortho_basis
OrthExponentialEval(n_basis_funcs=5)
>>> sample_points = linspace(0, 1, 100)
>>> # evaluate the basis
>>> features = ortho_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1701,6 +1723,8 @@ class OrthExponentialConv(ConvBasisMixin, OrthExponentialBasis):
>>> decay_rates = np.array([0.01, 0.02, 0.03, 0.04, 0.05]) # sample decay rates
>>> window_size = 10
>>> ortho_basis = OrthExponentialConv(n_basis_funcs, window_size, decay_rates)
>>> ortho_basis
OrthExponentialConv(n_basis_funcs=5, window_size=10)
>>> sample_points = np.random.randn(100)
>>> # convolve the basis
>>> features = ortho_basis.compute_features(sample_points)
Expand Down Expand Up @@ -1967,7 +1991,7 @@ class HistoryConv(ConvBasisMixin, HistoryBasis):
def __init__(
self,
window_size: int,
label: Optional[str] = "IdentityEval",
label: Optional[str] = "HistoryConv",
conv_kwargs: Optional[dict] = None,
):
ConvBasisMixin.__init__(self, window_size=window_size, conv_kwargs=conv_kwargs)
Expand Down
16 changes: 16 additions & 0 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .solvers._compute_defaults import glm_compute_optimal_stepsize_configs
from .type_casting import jnp_asarray_if, support_pynapple
from .typing import DESIGN_INPUT_TYPE
from .utils import format_repr

ModelParams = Tuple[jnp.ndarray, jnp.ndarray]

Expand Down Expand Up @@ -145,6 +146,12 @@ class GLM(BaseRegressor):
>>> import nemos as nmo
>>> # define single neuron GLM model
>>> model = nmo.glm.GLM()
>>> model
GLM(
observation_model=PoissonObservations(inverse_link_function=exp),
regularizer=UnRegularized(),
solver_name='GradientDescent'
)
>>> print("Regularizer type: ", type(model.regularizer))
Regularizer type: <class 'nemos.regularizer.UnRegularized'>
>>> print("Observation model: ", type(model.observation_model))
Expand Down Expand Up @@ -1116,6 +1123,9 @@ def _get_optimal_solver_params_config(self):
"""Return the functions for computing default step and batch size for the solver."""
return glm_compute_optimal_stepsize_configs(self)

def __repr__(self):
return format_repr(self, multiline=True)


class PopulationGLM(GLM):
"""
Expand Down Expand Up @@ -1233,6 +1243,12 @@ class PopulationGLM(GLM):
[0, 1]], dtype=int32)
>>> # Create and fit the model
>>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
>>> model
PopulationGLM(
observation_model=PoissonObservations(inverse_link_function=exp),
regularizer=UnRegularized(),
solver_name='GradientDescent'
)
>>> # Check the fitted coefficients
>>> print(model.coef_.shape)
(3, 2)
Expand Down
3 changes: 3 additions & 0 deletions src/nemos/observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __init__(self, inverse_link_function: Callable, **kwargs):
self.inverse_link_function = inverse_link_function
self.scale = 1.0

def __repr__(self):
return utils.format_repr(self, use_name_keys=["inverse_link_function"])

@property
def inverse_link_function(self):
"""Getter for the inverse link function for the model."""
Expand Down
4 changes: 4 additions & 0 deletions src/nemos/regularizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .base_class import Base
from .proximal_operator import prox_group_lasso
from .typing import DESIGN_INPUT_TYPE, ProximalOperator
from .utils import format_repr

__all__ = ["UnRegularized", "Ridge", "Lasso", "GroupLasso"]

Expand Down Expand Up @@ -93,6 +94,9 @@ def get_proximal_operator(
"""
pass

def __repr__(self):
return format_repr(self)


class UnRegularized(Regularizer):
"""
Expand Down
Loading
Loading