From b179813cab599b5ddc516f5390a5b0d404ee9723 Mon Sep 17 00:00:00 2001 From: Jan Sosulski Date: Thu, 20 Feb 2020 10:11:55 +0100 Subject: [PATCH] allow retrieval of epochs instead of np.ndarray in process_raw --- README.md | 2 +- moabb/datasets/gigadb.py | 2 +- moabb/paradigms/base.py | 33 +++++++++++++++++++++++---------- moabb/paradigms/p300.py | 7 +++++-- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 987b500e8..341bbceee 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ If you want to report a problem or suggest an enhancement we'd love for you to [ You might be interested in: -* [MOABB documentaion][link_moabb_docs] +* [MOABB documentation][link_moabb_docs] And of course, you'll want to know our: diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py index 1234211c8..b07c5b73b 100644 --- a/moabb/datasets/gigadb.py +++ b/moabb/datasets/gigadb.py @@ -80,7 +80,7 @@ def _get_single_subject_data(self, subject): 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', - 'FPz', 'FP2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', + 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2'] diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 9292a1360..3946682dc 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -62,7 +62,7 @@ def prepare_process(self, dataset): """ pass - def process_raw(self, raw, dataset): + def process_raw(self, raw, dataset, return_epochs=False): """ Process one raw data file. @@ -83,10 +83,16 @@ def process_raw(self, raw, dataset): The dataset corresponding to the raw file. mainly use to access dataset specific information. + return_epochs: boolean + This flag specifies whether to return only the data array or the + complete processed mne.Epochs + returns ------- - X : np.ndarray + X : Union[np.ndarray, mne.Epochs] the data that will be used as features for the model + Note: if return_epochs=True, this is mne.Epochs + if return_epochs=False, this is np.ndarray labels: np.ndarray the labels for training / evaluating the model @@ -141,7 +147,10 @@ def process_raw(self, raw, dataset): if self.resample is not None: epochs = epochs.resample(self.resample) # rescale to work with uV - X.append(dataset.unit_factor * epochs.get_data()) + if return_epochs: + X.append(epochs) + else: + X.append(dataset.unit_factor * epochs.get_data()) inv_events = {k: v for v, k in event_id.items()} labels = np.array([inv_events[e] for e in epochs.events[:, -1]]) @@ -155,7 +164,7 @@ def process_raw(self, raw, dataset): metadata = pd.DataFrame(index=range(len(labels))) return X, labels, metadata - def get_data(self, dataset, subjects=None): + def get_data(self, dataset, subjects=None, return_epochs=False): """ Return the data for a list of subject. @@ -197,7 +206,8 @@ def get_data(self, dataset, subjects=None): for subject, sessions in data.items(): for session, runs in sessions.items(): for run, raw in runs.items(): - proc = self.process_raw(raw, dataset) + proc = self.process_raw(raw, dataset, + return_epochs=return_epochs) if proc is None: # this mean the run did not contain any selected event @@ -211,12 +221,15 @@ def get_data(self, dataset, subjects=None): metadata.append(met) # grow X and labels in a memory efficient way. can be slow - if len(X) > 0: - X = np.append(X, x, axis=0) - labels = np.append(labels, lbs, axis=0) + if not return_epochs: + if len(X) > 0: + X = np.append(X, x, axis=0) + labels = np.append(labels, lbs, axis=0) + else: + X = x + labels = lbs else: - X = x - labels = lbs + X.append(x) metadata = pd.concat(metadata, ignore_index=True) return X, labels, metadata diff --git a/moabb/paradigms/p300.py b/moabb/paradigms/p300.py index 4e21c91e3..e6bf29551 100644 --- a/moabb/paradigms/p300.py +++ b/moabb/paradigms/p300.py @@ -79,7 +79,7 @@ def is_valid(self, dataset): def used_events(self, dataset): pass - def process_raw(self, raw, dataset): + def process_raw(self, raw, dataset, return_epochs=False): # find the events, first check stim_channels then annotations stim_channels = mne.utils._get_stim_channel( None, raw.info, raise_error=False) @@ -126,7 +126,10 @@ def process_raw(self, raw, dataset): if self.resample is not None: epochs = epochs.resample(self.resample) # rescale to work with uV - X.append(dataset.unit_factor * epochs.get_data()) + if return_epochs: + X.append(epochs) + else: + X.append(dataset.unit_factor * epochs.get_data()) inv_events = {k: v for v, k in event_id.items()} labels = np.array([inv_events[e] for e in epochs.events[:, -1]])