Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basis refactor pr1 #273

Merged
merged 114 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
1e2932a
initial moving pieces
BalzaniEdoardo Nov 14, 2024
145ac9b
setup raised cosine
BalzaniEdoardo Nov 14, 2024
84c41d1
mreged sphinx
BalzaniEdoardo Nov 18, 2024
5a4bb55
fixed docs class orth exp
BalzaniEdoardo Nov 18, 2024
346f812
documentation fix
BalzaniEdoardo Nov 18, 2024
ef438d3
evaluate on grid orth exp
BalzaniEdoardo Nov 18, 2024
8028855
compute features orth exp
BalzaniEdoardo Nov 18, 2024
71ff34c
split by orth exp
BalzaniEdoardo Nov 18, 2024
7ce9bd0
merging sphinx
BalzaniEdoardo Nov 25, 2024
e3ad3ea
added changes
BalzaniEdoardo Nov 25, 2024
798bec0
updated basis old
BalzaniEdoardo Nov 25, 2024
dc46d93
fixed _basis.py docstrings
BalzaniEdoardo Nov 25, 2024
9d7d34f
fix minor changes
BalzaniEdoardo Nov 25, 2024
25abf3d
added back feature matrix
BalzaniEdoardo Nov 25, 2024
13bc62a
updated mixin description
BalzaniEdoardo Nov 25, 2024
7c45ab7
fixed docstrings orthexp
BalzaniEdoardo Nov 25, 2024
6dbe362
updated raised cos
BalzaniEdoardo Nov 25, 2024
113ca76
fixed spline docstrings
BalzaniEdoardo Nov 25, 2024
db188ad
moving stuff around
BalzaniEdoardo Nov 25, 2024
32bcf82
start editing docstrings
BalzaniEdoardo Nov 25, 2024
5761ccd
improved tests
BalzaniEdoardo Nov 25, 2024
649c0d6
improved tests
BalzaniEdoardo Nov 25, 2024
8a7f259
improved tests
BalzaniEdoardo Nov 25, 2024
ab56c9b
fixed eval changed testing
BalzaniEdoardo Nov 26, 2024
593d7a9
added some basic testing
BalzaniEdoardo Nov 26, 2024
5d3e746
linted
BalzaniEdoardo Nov 26, 2024
a1f6271
removed basis old
BalzaniEdoardo Nov 26, 2024
9927a2c
fixed first test
BalzaniEdoardo Nov 26, 2024
e89543b
fixed second test
BalzaniEdoardo Nov 26, 2024
b2dd18f
fixed test_set_width
BalzaniEdoardo Nov 26, 2024
0a288c6
fixed test_compute_features_axis
BalzaniEdoardo Nov 26, 2024
e70c07c
fixed test_compute_features_conv_input
BalzaniEdoardo Nov 26, 2024
417f67b
fixed test_compute_features_returns_expected_number_of_basis
BalzaniEdoardo Nov 26, 2024
77795b1
fixed test_number_of_required_inputs_compute_features
BalzaniEdoardo Nov 26, 2024
f6c7ead
fixed test_evaluate_on_grid_meshgrid_size
BalzaniEdoardo Nov 26, 2024
519c0d7
fixed test_evaluate_on_grid_input_number
BalzaniEdoardo Nov 26, 2024
1174447
fixed test_time_scaling_values
BalzaniEdoardo Nov 26, 2024
38b499b
fixed some test calls
BalzaniEdoardo Nov 26, 2024
33a9995
moved to self.cls
BalzaniEdoardo Nov 26, 2024
5bf53c8
refactored a bunch of raised cos tests
BalzaniEdoardo Nov 26, 2024
37ac508
finished test raised cos log
BalzaniEdoardo Nov 26, 2024
2624a0a
test jointly shared methods
BalzaniEdoardo Nov 26, 2024
0f9c28f
refactored all 1d basis
BalzaniEdoardo Nov 26, 2024
07cad18
fixed additive basis tests
BalzaniEdoardo Nov 27, 2024
ea8e208
fixed multiplicative basis tests
BalzaniEdoardo Nov 27, 2024
c112ab9
fixed test splitters
BalzaniEdoardo Nov 27, 2024
5f65b71
linted tests
BalzaniEdoardo Nov 27, 2024
2edb205
fixed some tests on docstrings
BalzaniEdoardo Nov 27, 2024
76a4b5c
fixed tests that assumed 1d
BalzaniEdoardo Nov 27, 2024
df1cb6b
fixed all basis tests and linted
BalzaniEdoardo Nov 27, 2024
839db69
fixed all basis tests and linted
BalzaniEdoardo Nov 27, 2024
d2440a6
fixed other tests relying on basis
BalzaniEdoardo Nov 27, 2024
3f10942
linted
BalzaniEdoardo Nov 27, 2024
17e7d06
fix basis naming
BalzaniEdoardo Nov 27, 2024
17a9b93
fixed bug
BalzaniEdoardo Nov 27, 2024
25f0814
Merge branch 'sphinx' into basis_refactor_pr1
BalzaniEdoardo Nov 27, 2024
93f1858
fix double plotting
BalzaniEdoardo Nov 27, 2024
1aaea3b
move transformer basis out
BalzaniEdoardo Nov 28, 2024
9985784
added to api refs
BalzaniEdoardo Nov 28, 2024
7cec3c0
fix auto-imports
BalzaniEdoardo Nov 28, 2024
7fd2ef7
fix ome links
BalzaniEdoardo Nov 28, 2024
6b7486a
simplified inheritance
BalzaniEdoardo Nov 28, 2024
0ed9040
fixed other links and added SplineBasis
BalzaniEdoardo Nov 28, 2024
7f927ff
fixed all relative links
BalzaniEdoardo Nov 28, 2024
92c0757
ignore timeouts
BalzaniEdoardo Nov 28, 2024
9217087
uniform caps
BalzaniEdoardo Nov 28, 2024
b95f2f9
fix text
BalzaniEdoardo Nov 28, 2024
db72fda
fix doctests
BalzaniEdoardo Nov 29, 2024
cbfc4a7
linted
BalzaniEdoardo Nov 29, 2024
462670b
fix test pipeline
BalzaniEdoardo Nov 29, 2024
6a0574d
fix warns
BalzaniEdoardo Nov 29, 2024
b7a7b60
generalized tests
BalzaniEdoardo Dec 2, 2024
053554f
added class lev docstrings for splines
BalzaniEdoardo Dec 2, 2024
04b529f
added class lev docstrings for orth exp
BalzaniEdoardo Dec 2, 2024
38c4138
fix all docstrings
BalzaniEdoardo Dec 3, 2024
7c754bf
removed unnecessary kwargs
BalzaniEdoardo Dec 3, 2024
4493196
removed unnecessary kwargs
BalzaniEdoardo Dec 3, 2024
00a80f5
pydocstyle
BalzaniEdoardo Dec 3, 2024
ff51725
linted
BalzaniEdoardo Dec 3, 2024
4a588f0
fixed naming
BalzaniEdoardo Dec 3, 2024
be5a5b8
moved docstrings from init
BalzaniEdoardo Dec 3, 2024
ed414be
removed attrs from class docstrings
BalzaniEdoardo Dec 3, 2024
2b7b123
removed args from docstrings
BalzaniEdoardo Dec 3, 2024
77009a7
Update typing.py
sjvenditto Dec 3, 2024
f96d08f
fixed tests and renamed funcs
BalzaniEdoardo Dec 3, 2024
550835c
fixed warns
BalzaniEdoardo Dec 3, 2024
f8d64ea
Merge branch 'basis_refactor_pr1' of github.com:flatironinstitute/nem…
BalzaniEdoardo Dec 4, 2024
fb3dd75
fix tutorials
BalzaniEdoardo Dec 4, 2024
892c412
fix all tutorials
BalzaniEdoardo Dec 4, 2024
743bd8d
fixed links htmlproofer
BalzaniEdoardo Dec 4, 2024
fb883ce
linted
BalzaniEdoardo Dec 4, 2024
d0af9c8
updated jax link
BalzaniEdoardo Dec 4, 2024
c0664cf
Update docs/api_reference.rst
BalzaniEdoardo Dec 9, 2024
629f8c5
Update docs/background/plot_01_1D_basis_function.md
BalzaniEdoardo Dec 9, 2024
1e59e59
Update docs/developers_notes/04-basis_module.md
BalzaniEdoardo Dec 9, 2024
5808b50
Update docs/background/plot_01_1D_basis_function.md
BalzaniEdoardo Dec 9, 2024
a683afa
Update docs/background/plot_03_1D_convolution.md
BalzaniEdoardo Dec 9, 2024
c1773bc
Update docs/developers_notes/04-basis_module.md
BalzaniEdoardo Dec 9, 2024
66cfa71
Update docs/developers_notes/04-basis_module.md
BalzaniEdoardo Dec 9, 2024
59abe28
Update docs/developers_notes/04-basis_module.md
BalzaniEdoardo Dec 9, 2024
f376bd6
Update src/nemos/basis/_basis_mixin.py
BalzaniEdoardo Dec 9, 2024
1aba9a5
Update src/nemos/basis/_basis_mixin.py
BalzaniEdoardo Dec 9, 2024
cb7e1e7
Update src/nemos/basis/_basis_mixin.py
BalzaniEdoardo Dec 9, 2024
8427d6e
Update src/nemos/basis/_transformer_basis.py
BalzaniEdoardo Dec 9, 2024
2b05994
edits to docstrings
BalzaniEdoardo Dec 9, 2024
98c2a74
Update src/nemos/basis/_basis_mixin.py
BalzaniEdoardo Dec 9, 2024
4d4c70f
fixed inheritance and removed TransformerMixin init call
BalzaniEdoardo Dec 9, 2024
ec3c4c2
linted
BalzaniEdoardo Dec 9, 2024
ef3e22b
merged conflicts
BalzaniEdoardo Dec 9, 2024
442fbfb
linted
BalzaniEdoardo Dec 9, 2024
a7c463b
fix tests
BalzaniEdoardo Dec 9, 2024
905b745
fix tests
BalzaniEdoardo Dec 9, 2024
6a6da11
fixed links
BalzaniEdoardo Dec 10, 2024
00a2437
updates some docstrings
billbrod Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build:
- gem install html-proofer -v ">= 5.0.9" # Ensure version >= 5.0.9
post_build:
# Check everything except 403s and a jneurosci, which returns 404 but the link works when clicking.
- htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/"
- htmlproofer $READTHEDOCS_OUTPUT/html --checks Links,Scripts,Images --ignore-urls "https://fonts.gstatic.com,https://celltypes.brain-map.org/experiment/electrophysiology/478498617,https://www.jneurosci.org/content/25/47/11003" --assume-extension --check-external-hash --ignore-status-codes 403,0 --ignore-files "/.+\/_static\/.+/","/.+\/stubs\/.+/","/.+\/tutorials/plot_02_head_direction.+/"
# The auto-generated animation doesn't have a alt or src/srcset; I am able to ignore missing alt, but I cannot work around a missing src/srcset
# therefore for this file I am not checking the figures.
- htmlproofer $READTHEDOCS_OUTPUT/html/tutorials/plot_02_head_direction.html --checks Links,Scripts --ignore-urls "https://www.jneurosci.org/content/25/47/11003"
Expand Down
82 changes: 73 additions & 9 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,88 @@ Classes for creating Generalized Linear Models (GLMs) for both single neurons an
The ``nemos.basis`` module
--------------------------
Provides basis function classes to construct and transform features for model inputs.
Basis can be grouped according to the mode of operation into basis that performs convolution and basis that operates
as non-linear maps.

.. currentmodule:: nemos.basis

**The Abstract Classes:**

These classes are the building blocks for the concrete basis classes.

.. currentmodule:: nemos.basis._basis

.. autosummary::
:toctree: generated/basis
:toctree: generated/_basis
:recursive:
:nosignatures:

Basis

.. currentmodule:: nemos.basis._spline_basis
.. autosummary::
:toctree: generated/_basis
:recursive:
:nosignatures:

SplineBasis
BSplineBasis
CyclicBSplineBasis
MSplineBasis
OrthExponentialBasis
RaisedCosineBasisLinear
RaisedCosineBasisLog


**Bases For Convolution:**

.. currentmodule:: nemos.basis

.. autosummary::
:toctree: generated/basis
:recursive:
:nosignatures:


MSplineConv
BSplineConv
CyclicBSplineConv
RaisedCosineLinearConv
RaisedCosineLogConv
OrthExponentialConv

.. check for a config that prints only nemos.basis.Name

**Bases For Non-Linear Mapping:**

.. currentmodule:: nemos.basis

.. autosummary::
:toctree: generated/basis
:recursive:
:nosignatures:

MSplineEval
BSplineEval
CyclicBSplineEval
RaisedCosineLinearEval
RaisedCosineLogEval
OrthExponentialEval

**Composite Bases:**

.. currentmodule:: nemos.basis._basis

.. autosummary::
:toctree: generated/_basis
:recursive:
:nosignatures:

AdditiveBasis
MultiplicativeBasis

**Basis As ``scikit-learn`` Tranformers:**

.. currentmodule:: nemos.basis._transformer_basis

.. autosummary::
:toctree: generated/_transformer_basis
:recursive:
:nosignatures:

TransformerBasis

.. _observation_models:
Expand Down Expand Up @@ -130,7 +194,7 @@ These objects can be provided as input to nemos GLM methods.
.. currentmodule:: nemos.pytrees

.. autosummary::
:toctree: generated/identifiability_constraints
:toctree: generated/pytree
:recursive:
:nosignatures:

Expand Down
114 changes: 53 additions & 61 deletions docs/background/plot_01_1D_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ warnings.filterwarnings(
),
category=RuntimeWarning,
)

```

(simple_basis_function)=
# Simple Basis Function

## Defining a 1D Basis Object

We'll start by defining a 1D basis function object of the type [`MSplineBasis`](nemos.basis.MSplineBasis).
We'll start by defining a 1D basis function object of the type [`MSplineEval`](nemos.basis.MSplineEval).
The hyperparameters required to initialize this class are:

- The number of basis functions, which should be a positive integer.
Expand All @@ -58,35 +59,26 @@ import pynapple as nap

import nemos as nmo

# configure plots some
plt.style.use(nmo.styles.plot_style)

# Initialize hyperparameters
order = 4
n_basis = 10

# Define the 1D basis function object
bspline = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order)
bspline = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order)
```

## Evaluating a Basis

The [`Basis`](nemos.basis.Basis) object is callable, and can be evaluated as a function. By default, the support of the basis
is defined by the samples that we input to the [`__call__`](nemos.basis.Basis.__call__) method, and covers from the smallest to the largest value.

We provide the convenience method `evaluate_on_grid` for evaluating the basis on an equi-spaced grid of points that makes it easier to plot and visualize all basis elements.

```{code-cell} ipython3
# evaluate the basis on 100 sample points
x, y = bspline.evaluate_on_grid(100)

# Generate a time series of sample points
samples = nap.Tsd(t=np.arange(1001), d=np.linspace(0, 1,1001))

# Evaluate the basis at the sample points
eval_basis = bspline(samples)

# Output information about the evaluated basis
print(f"Evaluated B-spline of order {order} with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples.")

fig = plt.figure()
plt.title("B-spline basis")
plt.plot(samples, eval_basis);
fig = plt.figure(figsize=(5, 3))
plt.plot(x, y, lw=2)
plt.title("B-Spline Basis")
```

```{code-cell} ipython3
Expand All @@ -111,49 +103,18 @@ if path.exists():
fig.savefig(path / "plot_01_1D_basis_function.svg")
```

## Setting the basis support
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization. Evaluating the basis at any sample outside the bounds will result in a NaN.


```{code-cell} ipython3
bspline_range = nmo.basis.BSplineBasis(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))

print("Evaluated basis:")
# 0.5 is within the support, 0.1 is outside the support
print(np.round(bspline_range([0.5, 0.1]), 3))
```

Let's compare the default behavior of basis (estimating the range from the samples) with
the fixed range basis.

## Feature Computation
The bases in the `nemos.basis` module can be grouped into two categories:

```{code-cell} ipython3
fig, axs = plt.subplots(2,1, sharex=True)
plt.suptitle("B-spline basis ")
axs[0].plot(samples, bspline(samples), color="k")
axs[0].set_title("default")
axs[1].plot(samples, bspline_range(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```
1. **Evaluation Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to evaluate the basis directly, applying a non-linear transformation to the input. Classes in this category have names ending with "Eval," such as `BSplineEval`.

## Basis `mode`
In constructing features, [`Basis`](nemos.basis.Basis) objects can be used in two modalities: `"eval"` for evaluate or `"conv"`
for convolve. These two modalities change the behavior of the [`compute_features`](nemos.basis.Basis.compute_features) method of [`Basis`](nemos.basis.Basis), in particular,

- If a basis is in mode `"eval"`, then [`compute_features`](nemos.basis.Basis.compute_features) simply returns the evaluated basis.
- If a basis is in mode `"conv"`, then [`compute_features`](nemos.basis.Basis.compute_features) will convolve the input with a kernel of basis
with `window_size` specified by the user.
2. **Convolution Bases**: These bases use the [`compute_features`](nemos.basis._basis.Basis.compute_features) method to convolve the input with a kernel of basis elements, using a `window_size` specified by the user. Classes in this category have names ending with "Conv," such as `BSplineConv`.

Let's see how this two modalities operate.


```{code-cell} ipython3
eval_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="eval")
conv_mode = nmo.basis.MSplineBasis(n_basis_funcs=n_basis, mode="conv", window_size=100)
eval_mode = nmo.basis.MSplineEval(n_basis_funcs=n_basis)
conv_mode = nmo.basis.MSplineConv(n_basis_funcs=n_basis, window_size=100)

# define an input
angles = np.linspace(0, np.pi*4, 201)
Expand Down Expand Up @@ -196,11 +157,10 @@ check out the tutorial on [1D convolutions](plot_03_1D_convolution).
:::



Plotting the Basis Function Elements:
--------------------------------------
We suggest visualizing the basis post-instantiation by evaluating each element on a set of equi-spaced sample points
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
and then plotting the result. The method [`Basis.evaluate_on_grid`](nemos.basis._basis.Basis.evaluate_on_grid) is designed for this, as it generates and returns
the equi-spaced samples along with the evaluated basis functions. The benefits of using Basis.evaluate_on_grid become
particularly evident when working with multidimensional basis functions. You can find more details and visual
background in the
Expand All @@ -219,6 +179,38 @@ plt.plot(equispaced_samples, eval_basis)
plt.show()
```


## Setting the basis support (Eval only)
Sometimes, it is useful to restrict the basis to a fixed range. This can help manage outliers or ensure that
your basis covers the same range across multiple experimental sessions.
You can specify a range for the support of your basis by setting the `bounds`
parameter at initialization of "Eval" type basis (it doesn't make sense for convolutions).
Evaluating the basis at any sample outside the bounds will result in a NaN.


```{code-cell} ipython3
bspline_range = nmo.basis.BSplineEval(n_basis_funcs=n_basis, order=order, bounds=(0.2, 0.8))

print("Evaluated basis:")
# 0.5 is within the support, 0.1 is outside the support
print(np.round(bspline_range.compute_features([0.5, 0.1]), 3))
```

Let's compare the default behavior of basis (estimating the range from the samples) with
the fixed range basis.


```{code-cell} ipython3
samples = np.linspace(0, 1, 200)
fig, axs = plt.subplots(2,1, sharex=True)
plt.suptitle("B-spline basis ")
axs[0].plot(samples, bspline.compute_features(samples), color="k")
axs[0].set_title("default")
axs[1].plot(samples, bspline_range.compute_features(samples), color="tomato")
axs[1].set_title("bounds=[0.2, 0.8]")
plt.tight_layout()
```

Other Basis Types
-----------------
Each basis type may necessitate specific hyperparameters for instantiation. For a comprehensive description,
Expand All @@ -228,8 +220,8 @@ evaluate a log-spaced cosine raised function basis.


```{code-cell} ipython3
# Instantiate the basis noting that the `RaisedCosineBasisLog` does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=10, width=1.5, time_scaling=50)
# Instantiate the basis noting that the `RaisedCosineLog` basis does not require an `order` parameter
raised_cosine_log = nmo.basis.RaisedCosineLogEval(n_basis_funcs=10, width=1.5, time_scaling=50)

# Evaluate the raised cosine basis at the equi-spaced sample points
# (same method in all Basis elements)
Expand Down
36 changes: 17 additions & 19 deletions docs/background/plot_02_ND_basis_function.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ Here, we simply add two basis objects, `a_basis` and `b_basis`, together to defi
```{code-cell} ipython3
import matplotlib.pyplot as plt
import numpy as np

import nemos as nmo

# Define 1D basis objects
a_basis = nmo.basis.MSplineBasis(n_basis_funcs=15, order=3)
b_basis = nmo.basis.RaisedCosineBasisLog(n_basis_funcs=14)
a_basis = nmo.basis.MSplineEval(n_basis_funcs=15, order=3)
b_basis = nmo.basis.RaisedCosineLogEval(n_basis_funcs=14)

# Define the 2D additive basis object
additive_basis = a_basis + b_basis
Expand All @@ -151,7 +150,7 @@ x_coord = np.linspace(0, 1, 1000)
y_coord = np.linspace(0, 1, 1000)

# Evaluate the basis functions for the given trajectory.
eval_basis = additive_basis(x_coord, y_coord)
eval_basis = additive_basis.compute_features(x_coord, y_coord)

print(f"Sum of two 1D splines with {eval_basis.shape[1]} "
f"basis element and {eval_basis.shape[0]} samples:\n"
Expand All @@ -170,13 +169,13 @@ basis_b_element = 1
fig, axs = plt.subplots(1, 2, figsize=(6, 3))

axs[0].set_title(f"$a_{{{basis_a_element}}}(x)$", color="b")
axs[0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis(x_coord)[:, basis_a_element], "b")
axs[0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[0].plot(x_coord, a_basis.compute_features(x_coord)[:, basis_a_element], "b")
axs[0].set_xlabel("x-coord")

axs[1].set_title(f"$b_{{{basis_b_element}}}(x)$", color="b")
axs[1].plot(y_coord, b_basis(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis(x_coord)[:, basis_b_element], "b")
axs[1].plot(y_coord, b_basis.compute_features(x_coord), "grey", alpha=.3)
axs[1].plot(y_coord, b_basis.compute_features(x_coord)[:, basis_b_element], "b")
axs[1].set_xlabel("y-coord")
plt.tight_layout()
```
Expand Down Expand Up @@ -243,7 +242,7 @@ The number of elements of the product basis will be the product of the elements

```{code-cell} ipython3
# Evaluate the product basis at the x and y coordinates
eval_basis = prod_basis(x_coord, y_coord)
eval_basis = prod_basis.compute_features(x_coord, y_coord)

# Output the number of elements and samples of the evaluated basis,
# as well as the number of elements in the original 1D basis objects
Expand All @@ -269,19 +268,19 @@ fig, axs = plt.subplots(3,3,figsize=(8, 6))
cc = 0
for i, j in element_pairs:
# plot the element form a_basis
axs[cc, 0].plot(x_coord, a_basis(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis(x_coord)[:, i], "b")
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord), "grey", alpha=.3)
axs[cc, 0].plot(x_coord, a_basis.compute_features(x_coord)[:, i], "b")
axs[cc, 0].set_title(f"$a_{{{i}}}(x)$",color='b')

# plot the element form b_basis
axs[cc, 1].plot(y_coord, b_basis(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis(y_coord)[:, j], "b")
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord), "grey", alpha=.3)
axs[cc, 1].plot(y_coord, b_basis.compute_features(y_coord)[:, j], "b")
axs[cc, 1].set_title(f"$b_{{{j}}}(y)$",color='b')

# select & plot the corresponding product basis element
k = i * b_basis.n_basis_funcs + j
axs[cc, 2].contourf(X, Y, Z[:, :, k], cmap='Blues')
axs[cc, 2].set_title(f"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b')
axs[cc, 2].set_title(fr"$A_{{{k}}}(x,y) = a_{{{i}}}(x) \cdot b_{{{j}}}(y)$", color='b')
axs[cc, 2].set_xlabel('x-coord')
axs[cc, 2].set_ylabel('y-coord')
axs[cc, 2].set_aspect("equal")
Expand Down Expand Up @@ -323,7 +322,6 @@ in a linear maze and the LFP phase angle.
:::



N-Dimensional Basis
-------------------
Sometimes it may be useful to model even higher dimensional interactions, for example between the heding direction of
Expand All @@ -341,13 +339,13 @@ will output a $K^N \times T$ matrix.
T = 10
n_basis = 8

a_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
b_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
c_basis = nmo.basis.RaisedCosineBasisLinear(n_basis_funcs=n_basis)
a_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
b_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)
c_basis = nmo.basis.RaisedCosineLinearEval(n_basis_funcs=n_basis)

prod_basis_3 = a_basis * b_basis * c_basis
samples = np.linspace(0, 1, T)
eval_basis = prod_basis_3(samples, samples, samples)
eval_basis = prod_basis_3.compute_features(samples, samples, samples)

print(f"Product of three 1D splines results in {prod_basis_3.n_basis_funcs} "
f"basis elements.\nEvaluation output of shape {eval_basis.shape}")
Expand Down
Loading
Loading