Skip to content

Commit

Permalink
adds label to basis repr
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Jan 13, 2025
1 parent cd1f654 commit 470b77b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def __pow__(self, exponent: int) -> MultiplicativeBasis:
return result

def __repr__(self):
return format_repr(self, ["label"])
return format_repr(self)

def _get_feature_slicing(
self,
Expand Down
14 changes: 12 additions & 2 deletions src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,9 @@ def format_repr(
init_params = list(inspect.signature(obj.__init__).parameters.keys())
disp_params = []

for k, v in obj.get_params(deep=False).items():
all_params = obj.get_params(deep=False)
label = all_params.pop("label", None)
for k, v in all_params.items():
repr_param = (
k not in exclude_keys and not hasattr(v, "shape") and (v or v in (0, False))
)
Expand All @@ -526,8 +528,16 @@ def format_repr(
v = repr(v)
disp_params.append(f"{k}={v}")
disp_params = sorted(disp_params, key=lambda x: init_params.index(x.split("=")[0]))
cls_name = obj.__class__.__name__
# if label doesn't exist or is the same as the class name (as is the default for
# basis), then don't use it
if (label is not None) and (label != cls_name):
# else, label should replace the class name as being outside the parentheses and
# class name should come first within the parens
disp_params.insert(0, cls_name)
cls_name = label
disp_params = ", ".join(disp_params)
return f"{obj.__class__.__name__}({disp_params})"
return f"{cls_name}({disp_params})"


# enable concatenation for pynapple objects.
Expand Down
44 changes: 44 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,20 @@ def test_expected_output_split_by_feature(basis_instance, super_class):
np.testing.assert_array_equal(xx[~nans], x[~nans])


@pytest.mark.parametrize("label", [None, "", "default-behavior", "CoolFeature"])
def test_repr_label(label):
if label == "default-behavior":
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5)
else:
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5, label=label)
if label in [None, "default-behavior"]:
expected = "RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)"
else:
expected = f"{label}(RaisedCosineLinearEval, n_basis_funcs=5, width=2.0)"
out = repr(bas)
assert out == expected


@pytest.mark.parametrize(
"cls",
[
Expand Down Expand Up @@ -3334,6 +3348,21 @@ def test_repr_out(
basis_obj = basis_a_obj + basis_b_obj
assert repr(basis_obj) == expected_out[basis_a]

@pytest.mark.parametrize("label", [None, "", "default-behavior", "CoolFeature"])
def test_repr_label(self, label, basis_class_specific_params):
if label == "default-behavior":
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5)
else:
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5, label=label)
if label in [None, "default-behavior"]:
expected_a = "RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)"
else:
expected_a = f"{label}(RaisedCosineLinearEval, n_basis_funcs=5, width=2.0)"
bas = bas + self.instantiate_basis(6, basis.MSplineEval, basis_class_specific_params)
expected = f"AdditiveBasis(\n basis1={expected_a},\n basis2=MSplineEval(n_basis_funcs=6, order=4),\n)"
out = repr(bas)
assert out == expected


class TestMultiplicativeBasis(CombinedBasis):
cls = {"eval": MultiplicativeBasis, "conv": MultiplicativeBasis}
Expand Down Expand Up @@ -3395,6 +3424,21 @@ def test_repr_out(
basis_obj = basis_a_obj * basis_b_obj
assert repr(basis_obj) == expected_out[basis_a]

@pytest.mark.parametrize("label", [None, "", "default-behavior", "CoolFeature"])
def test_repr_label(self, label, basis_class_specific_params):
if label == "default-behavior":
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5)
else:
bas = basis.RaisedCosineLinearEval(n_basis_funcs=5, label=label)
if label in [None, "default-behavior"]:
expected_a = "RaisedCosineLinearEval(n_basis_funcs=5, width=2.0)"
else:
expected_a = f"{label}(RaisedCosineLinearEval, n_basis_funcs=5, width=2.0)"
bas = bas * self.instantiate_basis(6, basis.MSplineEval, basis_class_specific_params)
expected = f"MultiplicativeBasis(\n basis1={expected_a},\n basis2=MSplineEval(n_basis_funcs=6, order=4),\n)"
out = repr(bas)
assert out == expected

@pytest.mark.parametrize("n_basis_a", [5, 6])
@pytest.mark.parametrize("n_basis_b", [5, 6])
@pytest.mark.parametrize("sample_size", [10, 1000])
Expand Down

0 comments on commit 470b77b

Please sign in to comment.