Skip to content

Commit

Permalink
Merge branch 'develop' into deprecated-deep-learning
Browse files Browse the repository at this point in the history
# Conflicts:
#	moabb/tests/util_braindecode.py
  • Loading branch information
bruAristimunha committed Oct 18, 2024
2 parents ce7d57c + 98d041a commit d473045
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 4 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ exclude: ".*svg"

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: check-json
Expand All @@ -35,7 +35,7 @@ repos:


- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.10.0
hooks:
- id: black
language_version: python3
Expand All @@ -54,7 +54,7 @@ repos:
- id: isort

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies: [
Expand All @@ -69,7 +69,7 @@ repos:
exclude: ^docs/ | ^setup\.py$ |

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.0
rev: v0.6.9
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix, --ignore, E501 ]
Expand Down
175 changes: 175 additions & 0 deletions moabb/tests/util_braindecode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import unittest

import numpy as np
import pytest
from mne import EpochsArray, create_info
from sklearn.preprocessing import LabelEncoder


try:
from braindecode.datasets.base import BaseConcatDataset, WindowsDataset
from braindecode.datasets.xy import create_from_X_y

from moabb.pipelines.utils_pytorch import BraindecodeDatasetLoader

no_braindecode = False
except ImportError:
no_braindecode = None


from moabb.datasets.fake import FakeDataset
from moabb.tests import SimpleMotorImagery


@pytest.fixture(scope="module")
def data():
"""Return EEG data from dataset to test transformer."""
paradigm = SimpleMotorImagery()
dataset = FakeDataset(paradigm="imagery")
X, labels, metadata = paradigm.get_data(dataset, subjects=[1], return_epochs=True)
y = LabelEncoder().fit_transform(labels)
return X, y, labels, metadata


@pytest.mark.skipif(no_braindecode is None, reason="Braindecode is not installed")
class TestTransformer:
def test_transform_input_and_output_shape(self, data):
X, y, _, info = data
transformer = BraindecodeDatasetLoader()
braindecode_dataset = transformer.fit_transform(X, y=y)
assert (
len(braindecode_dataset) == X.get_data().shape[0]
and braindecode_dataset[0][0].shape[0] == X.get_data().shape[1]
and braindecode_dataset[0][0].shape[1] == X.get_data().shape[2]
)

def test_sklearn_is_fitted(self, data):
transformer = BraindecodeDatasetLoader()
assert transformer.__sklearn_is_fitted__()

def test_transformer_fit(self, data):
"""Test whether transformer can fit to some training data."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()
assert transformer.fit(X_train, y_train) == transformer

def test_transformer_transform_returns_dataset(self, data):
"""Test whether the output of the transform method is a
BaseConcatDataset."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()
dataset = transformer.fit(X_train, y_train).transform(X_train, y_train)
assert isinstance(dataset, BaseConcatDataset)

def test_transformer_transform_contents(self, data):
"""Test whether the contents and metadata of a transformed dataset are
correct."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()
dataset = transformer.fit(X_train, y_train).transform(X_train, y_train)
assert len(dataset) == len(X_train)
# test the properties of one epoch - that they match the input MNE Epoch object
sample_epoch = dataset.datasets[0][0]
# assert with approximately equal values
assert np.allclose(sample_epoch[0], X_train.get_data()[0])
assert sample_epoch[1] == y_train[0]

def test_sfreq_passed_through(self, data):
"""Test if the sfreq parameter makes it through the transformer."""
sfreq = 128.0
info = create_info(ch_names=["test"], sfreq=sfreq, ch_types=["eeg"])
data = np.random.normal(size=(2, 1, 10 * int(sfreq))) * 1e-6
# create some noise data in a 10s window
epochs = EpochsArray(data, info=info)
y_train = np.array([0])
transformer = BraindecodeDatasetLoader()
dataset = transformer.fit(epochs, y_train).transform(epochs, y_train)

if not isinstance(dataset.datasets[0], WindowsDataset):
assert dataset.datasets[0].raw.info["sfreq"] == sfreq
else:
assert dataset.datasets[0].windows.info["sfreq"] == sfreq

def test_kw_args_initialization(self):
"""Test initializing the transformer with kw_args."""
kw_args = {"sampling_rate": 128}
transformer = BraindecodeDatasetLoader(kw_args=kw_args)
assert transformer.kw_args == kw_args

def test_is_fitted_method(self):
"""Test __sklearn_is_fitted__ returns True."""
transformer = BraindecodeDatasetLoader()
is_fitter = transformer.__sklearn_is_fitted__()
assert is_fitter

def test_assert_raises_value_error(self, data):
"""Test that an invalid argument gives a ValueError."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()
invalid_param_name = "invalid"
with pytest.raises(TypeError):
transformer.fit(X_train, y=y_train, **{invalid_param_name: None})

def test_type_create_from_X_y_vs_transfomer(self, data):
"""Test the type from create_from_X_y() and the transformer."""
X_train, y_train, _, _ = data

dataset = create_from_X_y(
X_train.get_data(),
y=y_train,
window_size_samples=X_train.get_data().shape[2],
window_stride_samples=X_train.get_data().shape[2],
drop_last_window=False,
sfreq=X_train.info["sfreq"],
)
transformer = BraindecodeDatasetLoader()
dataset_trans = transformer.fit(X=X_train, y=y_train).transform(X_train)
assert isinstance(dataset_trans, BaseConcatDataset)
assert isinstance(type(dataset_trans), type(dataset))

def test_wrong_input(self):
"""Test that an invalid input raises a ValueError."""
transformer = BraindecodeDatasetLoader()
with pytest.raises(ValueError):
transformer.fit_transform(np.random.normal(size=(2, 1, 10)), y=np.array([0]))

def test_transformer_transform_with_custom_y(self, data):
"""Test whether the provided y is used during transform."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()

# Create test data with different y values
X_test = X_train.copy()
y_test = y_train + 1

# Fit the transformer with training data and custom y
transformer.fit(X_train, y_train)

# Transform the test data with custom y
dataset_test = transformer.transform(X_test, y=y_test)

# Verify that the transformed dataset contains the test data's x values and the custom y values
assert len(dataset_test) == len(X_test)
assert np.array_equal(dataset_test[0][1], y_test[0])
assert np.array_equal(dataset_test[1][1], y_test[1])

def test_transformer_transform_with_default_y(self, data):
"""Test whether self.y is used when y is not provided during
transform."""
X_train, y_train, _, _ = data
transformer = BraindecodeDatasetLoader()

# Fit the transformer with training data and default y
transformer.fit(X_train, y_train)

# Transform the test data without providing y
dataset_test = transformer.transform(X_train)

# Verify that the transformed dataset contains the training data's x values and the default y values
assert len(dataset_test) == len(X_train)
assert np.array_equal(dataset_test[0][1], y_train[0])
assert np.array_equal(dataset_test[1][1], y_train[1])


if __name__ == "__main__":
unittest.main()

0 comments on commit d473045

Please sign in to comment.