Skip to content

Commit

Permalink
[MNT] Solving moabb and braindecode compatibility (#669)
Browse files Browse the repository at this point in the history
* first try

* done

* done

* whats new file

* whats new file

* whats new file

* Easy fix with Sosulski2019 ;)

* [FIX] Applying suggestions from the revision

* [FIX] fixing Ofner2017 too!

* [FIX] fixing GrosseWentrup2009 too!

* [FIX] Solving Liu dataset!
  • Loading branch information
bruAristimunha authored Oct 24, 2024
1 parent ef22069 commit 2f49da8
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 17 deletions.
4 changes: 3 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Bugs
- Fix Stieger2021 dataset bugs (:gh:`651` by `Martin Wimpff`_)
- Unpinning major version Scikit-learn and numpy (:gh:`652` by `Bruno Aristimunha`_)
- Replacing the func:`numpy.string_` to func:`numpy.bytes_` (:gh:`665` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)
- Creating stimulus channels in :class:`moabb.datasets.Zhou2016` and :class:`moabb.datasets.PhysionetMI` to allow braindecode compatibility (:gh:`669` by `Bruno Aristimunha`_)


API changes
~~~~~~~~~~~
Expand Down
6 changes: 5 additions & 1 deletion moabb/datasets/Zhou2016.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .base import BaseDataset
from .download import get_dataset_path
from .utils import stim_channels_with_selected_ids


DATA_PATH = "https://ndownloader.figshare.com/files/3662952"
Expand Down Expand Up @@ -88,6 +89,7 @@ def __init__(self):
paradigm="imagery",
doi="10.1371/journal.pone.0162657",
)
self.events = dict(left_hand=1, right_hand=2, feet=3)

def _get_single_subject_data(self, subject):
"""Return data for a single subject."""
Expand All @@ -105,7 +107,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "2"] = "right_hand"
stim[stim == "3"] = "feet"
raw.annotations.description = stim
out[sess_key][run_key] = raw
out[sess_key][run_key] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
out[sess_key][run_key].set_montage(make_standard_montage("standard_1005"))
return out

Expand Down
11 changes: 6 additions & 5 deletions moabb/datasets/liu2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids


# Link to the raw data
Expand Down Expand Up @@ -77,15 +78,15 @@ class Liu2024(BaseDataset):
def __init__(self, break_events=False, instr_events=False):
self.break_events = break_events
self.instr_events = instr_events
events = {"left_hand": 1, "right_hand": 2}
self.events = {"left_hand": 1, "right_hand": 2}
if break_events:
events["instr"] = 3
self.events["instr"] = 3
if instr_events:
events["break"] = 4
self.events["break"] = 4
super().__init__(
subjects=list(range(1, 50 + 1)),
sessions_per_subject=1,
events=events,
events=self.events,
code="Liu2024",
interval=(2, 6),
paradigm="imagery",
Expand Down Expand Up @@ -277,7 +278,7 @@ def _get_single_subject_data(self, subject):
# Loading dataset
raw = raw.load_data(verbose=False)
# There is only one session
sessions = {"0": {"0": raw}}
sessions = {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}}

return sessions

Expand Down
6 changes: 4 additions & 2 deletions moabb/datasets/mpi_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids
from moabb.utils import depreciated_alias


Expand Down Expand Up @@ -56,10 +57,11 @@ class GrosseWentrup2009(BaseDataset):
"""

def __init__(self):
self.events_id = dict(right_hand=2, left_hand=1)
super().__init__(
subjects=list(range(1, 11)),
sessions_per_subject=1,
events=dict(right_hand=2, left_hand=1),
events=self.events_id,
code="GrosseWentrup2009",
interval=[0, 7],
paradigm="imagery",
Expand All @@ -76,7 +78,7 @@ def _get_single_subject_data(self, subject):
stim[stim == "20"] = "right_hand"
stim[stim == "10"] = "left_hand"
raw.annotations.description = stim
return {"0": {"0": raw}}
return {"0": {"0": stim_channels_with_selected_ids(raw, self.event_id)}}

def data_path(
self, subject, path=None, force_update=False, update_path=None, verbose=None
Expand Down
11 changes: 8 additions & 3 deletions moabb/datasets/physionet_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from moabb.datasets.base import BaseDataset
from moabb.datasets.download import data_dl, get_dataset_path
from moabb.datasets.utils import stim_channels_with_selected_ids


BASE_URL = "https://physionet.org/files/eegmmidb/1.0.0/"
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self, imagined=True, executed=False):
paradigm="imagery",
doi="10.1109/TBME.2004.827072",
)

self.events = dict(left_hand=2, right_hand=3, feet=5, hands=4, rest=1)
self.imagined = imagined
self.executed = executed
self.feet_runs = []
Expand Down Expand Up @@ -123,7 +124,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "left_hand"
stim[stim == "T2"] = "right_hand"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

# feet runs
Expand All @@ -136,7 +139,9 @@ def _get_single_subject_data(self, subject):
stim[stim == "T1"] = "hands"
stim[stim == "T2"] = "feet"
raw.annotations.description = stim
data[str(idx)] = raw
data[str(idx)] = stim_channels_with_selected_ids(
raw, desired_event_id=self.events
)
idx += 1

return {"0": data}
Expand Down
6 changes: 4 additions & 2 deletions moabb/datasets/sosulski2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from moabb.datasets import download as dl
from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids


SPOT_PILOT_P300_URL = (
Expand Down Expand Up @@ -95,12 +96,13 @@ def __init__(
self.n_channels = 31
self.use_soas_as_sessions = use_soas_as_sessions
self.description_map = {"Stimulus/S 21": "Target", "Stimulus/S 1": "NonTarget"}
self.events = dict(Target=21, NonTarget=1)
code = "Sosulski2019"
interval = [-0.2, 1] if interval is None else interval
super().__init__(
subjects=list(range(1, 13 + 1)),
sessions_per_subject=1,
events=dict(Target=21, NonTarget=1),
events=self.events,
code=code,
interval=interval,
paradigm="p300",
Expand Down Expand Up @@ -133,7 +135,7 @@ def _get_single_run_data(self, file_path):
if self.reject_non_iid:
raw.set_annotations(raw.annotations[7:85]) # non-iid rejection
raw.annotations.rename(self.description_map)
return raw
return stim_channels_with_selected_ids(raw, self.events)

def _get_single_subject_data(self, subject):
"""Return data for a single subject."""
Expand Down
7 changes: 4 additions & 3 deletions moabb/datasets/upper_limb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from mne.io import read_raw_gdf

from moabb.datasets.base import BaseDataset
from moabb.datasets.utils import stim_channels_with_selected_ids

from . import download as dl

Expand Down Expand Up @@ -58,7 +59,7 @@ class Ofner2017(BaseDataset):
def __init__(self, imagined=True, executed=False):
self.imagined = imagined
self.executed = executed
event_id = {
self.event_id = {
"right_elbow_flexion": 1536,
"right_elbow_extension": 1537,
"right_supination": 1538,
Expand All @@ -72,7 +73,7 @@ def __init__(self, imagined=True, executed=False):
super().__init__(
subjects=list(range(1, 16)),
sessions_per_subject=n_sessions,
events=event_id,
events=self.event_id,
code="Ofner2017",
interval=[0, 3], # according to paper 2-5
paradigm="imagery",
Expand Down Expand Up @@ -114,7 +115,7 @@ def _get_single_subject_data(self, subject):
stim[stim == "1541"] = "right_hand_open"
stim[stim == "1542"] = "rest"
raw.annotations.description = stim
data[str(ii)] = raw
data[str(ii)] = stim_channels_with_selected_ids(raw, self.event_id)

out[session_name] = data
return out
Expand Down
50 changes: 50 additions & 0 deletions moabb/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect

import mne
import numpy as np
from mne import create_info
from mne.io import RawArray
Expand Down Expand Up @@ -273,3 +274,52 @@ def add_stim_channel_epoch(
)
raw = raw.add_channels([RawArray(data=stim_chan, info=info, verbose=False)])
return raw


def stim_channels_with_selected_ids(
raw: mne.io.BaseRaw, desired_event_id: dict, stim_channel_name="STIM"
):
"""
Add a stimulus channel with filtering and renaming based on events_ids.
Parameters
----------
raw: mne.Raw
The raw object to add the stimulus channel to.
desired_event_id: dict
Dictionary with events
"""

# Get events using the consistent event_id mapping
events, _ = mne.events_from_annotations(raw, event_id=desired_event_id)

# Filter the events array to include only desired events
desired_event_ids = list(desired_event_id.values())
filtered_events = events[np.isin(events[:, 2], desired_event_ids)]

# Create annotations from filtered events using the inverted mapping
event_desc = {v: k for k, v in desired_event_id.items()}
annot_from_events = mne.annotations_from_events(
events=filtered_events,
event_desc=event_desc,
sfreq=raw.info["sfreq"],
orig_time=raw.info["meas_date"],
)
raw.set_annotations(annot_from_events)

# Create the stim channel data array
stim_channs = np.zeros((1, raw.n_times))
for event in filtered_events:
sample_index = event[0]
event_code = event[2] # Consistent event IDs
stim_channs[0, sample_index] = event_code

# Create the stim channel and add it to raw

stim_info = mne.create_info(
[stim_channel_name], sfreq=raw.info["sfreq"], ch_types=["stim"]
)
stim_raw = mne.io.RawArray(stim_channs, stim_info, verbose=False)
raw_with_stim = raw.copy().add_channels([stim_raw], force_update_info=True)

return raw_with_stim

0 comments on commit 2f49da8

Please sign in to comment.