diff --git a/doc/changes/devel/13044.newfeature.rst b/doc/changes/devel/13044.newfeature.rst new file mode 100644 index 00000000000..9633aba66b9 --- /dev/null +++ b/doc/changes/devel/13044.newfeature.rst @@ -0,0 +1 @@ +Add :meth:`mne.Evoked.interpolate_to` to allow interpolating EEG data to other montages, by :newcontrib:`Antoine Collas`. diff --git a/doc/changes/names.inc b/doc/changes/names.inc index 18753a0c872..282fa8341a0 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -24,6 +24,7 @@ .. _Anna Padee: https://github.com/apadee/ .. _Annalisa Pascarella: https://www.iac.cnr.it/personale/annalisa-pascarella .. _Anne-Sophie Dubarry: https://github.com/annesodub +.. _Antoine Collas: https://www.antoinecollas.fr .. _Antoine Gauthier: https://github.com/Okamille .. _Antti Rantala: https://github.com/Odingod .. _Apoorva Karekal: https://github.com/apoorva6262 diff --git a/doc/references.bib b/doc/references.bib index e2578ed18f2..f0addb5f3b2 100644 --- a/doc/references.bib +++ b/doc/references.bib @@ -2514,3 +2514,11 @@ @article{OyamaEtAl2015 year = {2015}, pages = {24--36}, } + +@inproceedings{MellotEtAl2024, + title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets}, + author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre}, + booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)}, + year = {2024}, + address = {Lyon, France} +} diff --git a/examples/preprocessing/interpolate_to.py b/examples/preprocessing/interpolate_to.py new file mode 100644 index 00000000000..b97a7251cbb --- /dev/null +++ b/examples/preprocessing/interpolate_to.py @@ -0,0 +1,81 @@ +""" +.. _ex-interpolate-to-any-montage: + +====================================================== +Interpolate EEG data to any montage +====================================================== + +This example demonstrates how to interpolate EEG channels to match a given montage. +This can be useful for standardizing +EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`). + +- Using the field interpolation for EEG data. +- Using the target montage "biosemi16". + +In this example, the data from the original EEG channels will be +interpolated onto the positions defined by the "biosemi16" montage. +""" + +# Authors: Antoine Collas +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import matplotlib.pyplot as plt + +import mne +from mne.channels import make_standard_montage +from mne.datasets import sample + +print(__doc__) +ylim = (-10, 10) + +# %% +# Load EEG data +data_path = sample.data_path() +eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif" +evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0)) + +# Select only EEG channels +evoked.pick("eeg") + +# Plot the original EEG layout +evoked.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) + +# %% +# Define the target montage +standard_montage = make_standard_montage("biosemi16") + +# %% +# Use interpolate_to to project EEG data to the standard montage +evoked_interpolated_spline = evoked.copy().interpolate_to( + standard_montage, method="spline" +) + +# Plot the interpolated EEG layout +evoked_interpolated_spline.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) + +# %% +# Use interpolate_to to project EEG data to the standard montage +evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE") + +# Plot the interpolated EEG layout +evoked_interpolated_mne.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim)) + +# %% +# Comparing before and after interpolation +fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True) +evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False, ylim=dict(eeg=ylim)) +axs[0].set_title("Original EEG Layout") +evoked_interpolated_spline.plot( + exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim) +) +axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation") +evoked_interpolated_mne.plot( + exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim) +) +axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation") + +# %% +# References +# ---------- +# .. footbibliography:: diff --git a/mne/channels/channels.py b/mne/channels/channels.py index bf9e58f2819..d0e57eecb5f 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -41,7 +41,7 @@ pick_info, pick_types, ) -from .._fiff.proj import setup_proj +from .._fiff.proj import _has_eeg_average_ref_proj, setup_proj from .._fiff.reference import add_reference_channels, set_eeg_reference from .._fiff.tag import _rename_list from ..bem import _check_origin @@ -960,6 +960,162 @@ def interpolate_bads( return self + def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0): + """Interpolate EEG data onto a new montage. + + .. warning:: + Be careful, only EEG channels are interpolated. Other channel types are + not interpolated. + + Parameters + ---------- + sensors : DigMontage + The target montage containing channel positions to interpolate onto. + origin : array-like, shape (3,) | str + Origin of the sphere in the head coordinate frame and in meters. + Can be ``'auto'`` (default), which means a head-digitization-based + origin fit. + method : str + Method to use for EEG channels. + Supported methods are 'spline' (default) and 'MNE'. + reg : float + The regularization parameter for the interpolation method + (only used when the method is 'spline'). + + Returns + ------- + inst : instance of Raw, Epochs, or Evoked + The instance with updated channel locations and data. + + Notes + ----- + This method is useful for standardizing EEG layouts across datasets. + However, some attributes may be lost after interpolation. + + .. versionadded:: 1.10.0 + """ + from ..epochs import BaseEpochs, EpochsArray + from ..evoked import Evoked, EvokedArray + from ..forward._field_interpolation import _map_meg_or_eeg_channels + from ..io import RawArray + from ..io.base import BaseRaw + from .interpolation import _make_interpolation_matrix + from .montage import DigMontage + + # Check that the method option is valid. + _check_option("method", method, ["spline", "MNE"]) + _validate_type(sensors, DigMontage, "sensors") + + # Get target positions from the montage + ch_pos = sensors.get_positions().get("ch_pos", {}) + target_ch_names = list(ch_pos.keys()) + if not target_ch_names: + raise ValueError( + "The provided sensors configuration has no channel positions." + ) + + # Get original channel order + orig_names = self.info["ch_names"] + + # Identify EEG channel + picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads") + if len(picks_good_eeg) == 0: + raise ValueError("No good EEG channels available for interpolation.") + # Also get the full list of EEG channel indices (including bad channels) + picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[]) + eeg_names_orig = [orig_names[i] for i in picks_remove_eeg] + + # Identify non-EEG channels in original order + non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig] + + # Create destination info for new EEG channels + sfreq = self.info["sfreq"] + info_interp = create_info( + ch_names=target_ch_names, + sfreq=sfreq, + ch_types=["eeg"] * len(target_ch_names), + ) + info_interp.set_montage(sensors) + info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names] + # Do not assign "projs" directly. + + # Compute the interpolation mapping + if method == "spline": + origin_val = _check_origin(origin, self.info) + pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val + pos_to = np.stack(list(ch_pos.values()), axis=0) + + def _check_pos_sphere(pos): + d = np.linalg.norm(pos, axis=-1) + d_norm = np.mean(d / np.mean(d)) + if np.abs(1.0 - d_norm) > 0.1: + warn("Your spherical fit is poor; interpolation may be inaccurate.") + + _check_pos_sphere(pos_from) + _check_pos_sphere(pos_to) + mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg) + + else: + assert method == "MNE" + info_eeg = pick_info(self.info, picks_good_eeg) + # If the original info has an average EEG reference projector but + # the destination info does not, + # update info_interp via a temporary RawArray. + if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj( + info_interp + ): + # Create dummy data: shape (n_channels, 1) + temp_data = np.zeros((len(info_interp["ch_names"]), 1)) + temp_raw = RawArray(temp_data, info_interp, first_samp=0) + # Using the public API, add an average reference projector. + temp_raw.set_eeg_reference( + ref_channels="average", projection=True, verbose=False + ) + # Extract the updated info. + info_interp = temp_raw.info + mapping = _map_meg_or_eeg_channels( + info_eeg, info_interp, mode="accurate", origin=origin + ) + + # Interpolate EEG data + data_good = self.get_data(picks=picks_good_eeg) + data_interp = mapping @ data_good + + # Create a new instance for the interpolated EEG channels + # TODO: Creating a new instance leads to a loss of information. + # We should consider updating the existing instance in the future + # by 1) drop channels, 2) add channels, 3) re-order channels. + if isinstance(self, BaseRaw): + inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp) + elif isinstance(self, BaseEpochs): + inst_interp = EpochsArray(data_interp, info_interp) + else: + assert isinstance(self, Evoked) + inst_interp = EvokedArray(data_interp, info_interp) + + # Merge only if non-EEG channels exist + if not non_eeg_names_ordered: + return inst_interp + + inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data() + inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True) + + # Reorder channels + # Insert the entire new EEG block at the position of the first EEG channel. + orig_names_arr = np.array(orig_names) + mask_eeg = np.isin(orig_names_arr, eeg_names_orig) + if mask_eeg.any(): + first_eeg_index = np.where(mask_eeg)[0][0] + pre = orig_names_arr[:first_eeg_index] + new_eeg = np.array(info_interp["ch_names"]) + post = orig_names_arr[first_eeg_index:] + post = post[~np.isin(orig_names_arr[first_eeg_index:], eeg_names_orig)] + new_order = np.concatenate((pre, new_eeg, post)).tolist() + else: + new_order = orig_names + inst_out.reorder_channels(new_order) + return inst_out + @verbose def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None): diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index a881a41edcc..62c7d79e3eb 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -12,7 +12,7 @@ from mne import Epochs, pick_channels, pick_types, read_events from mne._fiff.constants import FIFF from mne._fiff.proj import _has_eeg_average_ref_proj -from mne.channels import make_dig_montage +from mne.channels import make_dig_montage, make_standard_montage from mne.channels.interpolation import _make_interpolation_matrix from mne.datasets import testing from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx @@ -439,3 +439,85 @@ def test_method_str(): raw.interpolate_bads(method="spline") raw.pick("eeg", exclude=()) raw.interpolate_bads(method="spline") + + +@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"]) +@pytest.mark.parametrize("method", ["spline", "MNE"]) +@pytest.mark.parametrize("data_type", ["raw", "epochs", "evoked"]) +def test_interpolate_to_eeg(montage_name, method, data_type): + """Test the interpolate_to method for EEG for raw, epochs, and evoked.""" + # Load EEG data + raw, epochs_eeg = _load_data("eeg") + epochs_eeg = epochs_eeg.copy() + + # Load data for raw + raw.load_data() + + # Create a target montage + montage = make_standard_montage(montage_name) + + # Prepare data to interpolate to + if data_type == "raw": + inst = raw.copy() + elif data_type == "epochs": + inst = epochs_eeg.copy() + elif data_type == "evoked": + inst = epochs_eeg.average() + shape = list(inst._data.shape) + orig_total = len(inst.info["ch_names"]) + n_eeg_orig = len(pick_types(inst.info, eeg=True)) + + # Assert first and last channels are not EEG + if data_type == "raw": + ch_types = inst.get_channel_types() + assert ch_types[0] != "eeg" + assert ch_types[-1] != "eeg" + + # Record the names and data of the first and last channels. + if data_type == "raw": + first_name = inst.info["ch_names"][0] + last_name = inst.info["ch_names"][-1] + data_first = inst._data[..., 0, :].copy() + data_last = inst._data[..., -1, :].copy() + + # Interpolate the EEG channels. + inst_interp = inst.copy().interpolate_to(montage, method=method) + + # Check that the new channel names include the montage channels. + assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"])) + # Check that the overall channel order is changed. + assert inst.info["ch_names"] != inst_interp.info["ch_names"] + + # Check that the data shape is as expected. + new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names) + expected_shape = (new_nchan_expected, shape[-1]) + if len(shape) == 3: + expected_shape = (shape[0],) + expected_shape + assert inst_interp._data.shape == expected_shape + + # Verify that the first and last channels retain their positions. + if data_type == "raw": + assert inst_interp.info["ch_names"][0] == first_name + assert inst_interp.info["ch_names"][-1] == last_name + + # Verify that the data for the first and last channels is unchanged. + if data_type == "raw": + np.testing.assert_allclose( + inst_interp._data[..., 0, :], + data_first, + err_msg="Data for the first non-EEG channel has changed.", + ) + np.testing.assert_allclose( + inst_interp._data[..., -1, :], + data_last, + err_msg="Data for the last non-EEG channel has changed.", + ) + + # Validate that bad channels are carried over. + # Mark the first non eeg channel as bad + all_ch = inst_interp.info["ch_names"] + eeg_ch = [all_ch[i] for i in pick_types(inst_interp.info, eeg=True)] + bads = [ch for ch in all_ch if ch not in eeg_ch][:1] + inst.info["bads"] = bads + inst_interp = inst.copy().interpolate_to(montage, method=method) + assert inst_interp.info["bads"] == bads