Skip to content

Commit

Permalink
Merge branch 'main' into fix_annotations_orig_time
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored Feb 10, 2025
2 parents 924bebf + 7b316cb commit 7cceed6
Show file tree
Hide file tree
Showing 6 changed files with 331 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13044.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :meth:`mne.Evoked.interpolate_to` to allow interpolating EEG data to other montages, by :newcontrib:`Antoine Collas`.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
.. _Anna Padee: https://github.com/apadee/
.. _Annalisa Pascarella: https://www.iac.cnr.it/personale/annalisa-pascarella
.. _Anne-Sophie Dubarry: https://github.com/annesodub
.. _Antoine Collas: https://www.antoinecollas.fr
.. _Antoine Gauthier: https://github.com/Okamille
.. _Antti Rantala: https://github.com/Odingod
.. _Apoorva Karekal: https://github.com/apoorva6262
Expand Down
8 changes: 8 additions & 0 deletions doc/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2514,3 +2514,11 @@ @article{OyamaEtAl2015
year = {2015},
pages = {24--36},
}

@inproceedings{MellotEtAl2024,
title = {Physics-informed and Unsupervised Riemannian Domain Adaptation for Machine Learning on Heterogeneous EEG Datasets},
author = {Mellot, Apolline and Collas, Antoine and Chevallier, Sylvain and Engemann, Denis and Gramfort, Alexandre},
booktitle = {Proceedings of the 32nd European Signal Processing Conference (EUSIPCO)},
year = {2024},
address = {Lyon, France}
}
81 changes: 81 additions & 0 deletions examples/preprocessing/interpolate_to.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
.. _ex-interpolate-to-any-montage:
======================================================
Interpolate EEG data to any montage
======================================================
This example demonstrates how to interpolate EEG channels to match a given montage.
This can be useful for standardizing
EEG channel layouts across different datasets (see :footcite:`MellotEtAl2024`).
- Using the field interpolation for EEG data.
- Using the target montage "biosemi16".
In this example, the data from the original EEG channels will be
interpolated onto the positions defined by the "biosemi16" montage.
"""

# Authors: Antoine Collas <[email protected]>
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import matplotlib.pyplot as plt

import mne
from mne.channels import make_standard_montage
from mne.datasets import sample

print(__doc__)
ylim = (-10, 10)

# %%
# Load EEG data
data_path = sample.data_path()
eeg_file_path = data_path / "MEG" / "sample" / "sample_audvis-ave.fif"
evoked = mne.read_evokeds(eeg_file_path, condition="Left Auditory", baseline=(None, 0))

# Select only EEG channels
evoked.pick("eeg")

# Plot the original EEG layout
evoked.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Define the target montage
standard_montage = make_standard_montage("biosemi16")

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_spline = evoked.copy().interpolate_to(
standard_montage, method="spline"
)

# Plot the interpolated EEG layout
evoked_interpolated_spline.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Use interpolate_to to project EEG data to the standard montage
evoked_interpolated_mne = evoked.copy().interpolate_to(standard_montage, method="MNE")

# Plot the interpolated EEG layout
evoked_interpolated_mne.plot(exclude=[], picks="eeg", ylim=dict(eeg=ylim))

# %%
# Comparing before and after interpolation
fig, axs = plt.subplots(3, 1, figsize=(8, 6), constrained_layout=True)
evoked.plot(exclude=[], picks="eeg", axes=axs[0], show=False, ylim=dict(eeg=ylim))
axs[0].set_title("Original EEG Layout")
evoked_interpolated_spline.plot(
exclude=[], picks="eeg", axes=axs[1], show=False, ylim=dict(eeg=ylim)
)
axs[1].set_title("Interpolated to Standard 1020 Montage using spline interpolation")
evoked_interpolated_mne.plot(
exclude=[], picks="eeg", axes=axs[2], show=False, ylim=dict(eeg=ylim)
)
axs[2].set_title("Interpolated to Standard 1020 Montage using MNE interpolation")

# %%
# References
# ----------
# .. footbibliography::
158 changes: 157 additions & 1 deletion mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
pick_info,
pick_types,
)
from .._fiff.proj import setup_proj
from .._fiff.proj import _has_eeg_average_ref_proj, setup_proj
from .._fiff.reference import add_reference_channels, set_eeg_reference
from .._fiff.tag import _rename_list
from ..bem import _check_origin
Expand Down Expand Up @@ -960,6 +960,162 @@ def interpolate_bads(

return self

def interpolate_to(self, sensors, origin="auto", method="spline", reg=0.0):
"""Interpolate EEG data onto a new montage.
.. warning::
Be careful, only EEG channels are interpolated. Other channel types are
not interpolated.
Parameters
----------
sensors : DigMontage
The target montage containing channel positions to interpolate onto.
origin : array-like, shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
method : str
Method to use for EEG channels.
Supported methods are 'spline' (default) and 'MNE'.
reg : float
The regularization parameter for the interpolation method
(only used when the method is 'spline').
Returns
-------
inst : instance of Raw, Epochs, or Evoked
The instance with updated channel locations and data.
Notes
-----
This method is useful for standardizing EEG layouts across datasets.
However, some attributes may be lost after interpolation.
.. versionadded:: 1.10.0
"""
from ..epochs import BaseEpochs, EpochsArray
from ..evoked import Evoked, EvokedArray
from ..forward._field_interpolation import _map_meg_or_eeg_channels
from ..io import RawArray
from ..io.base import BaseRaw
from .interpolation import _make_interpolation_matrix
from .montage import DigMontage

# Check that the method option is valid.
_check_option("method", method, ["spline", "MNE"])
_validate_type(sensors, DigMontage, "sensors")

# Get target positions from the montage
ch_pos = sensors.get_positions().get("ch_pos", {})
target_ch_names = list(ch_pos.keys())
if not target_ch_names:
raise ValueError(
"The provided sensors configuration has no channel positions."
)

# Get original channel order
orig_names = self.info["ch_names"]

# Identify EEG channel
picks_good_eeg = pick_types(self.info, meg=False, eeg=True, exclude="bads")
if len(picks_good_eeg) == 0:
raise ValueError("No good EEG channels available for interpolation.")
# Also get the full list of EEG channel indices (including bad channels)
picks_remove_eeg = pick_types(self.info, meg=False, eeg=True, exclude=[])
eeg_names_orig = [orig_names[i] for i in picks_remove_eeg]

# Identify non-EEG channels in original order
non_eeg_names_ordered = [ch for ch in orig_names if ch not in eeg_names_orig]

# Create destination info for new EEG channels
sfreq = self.info["sfreq"]
info_interp = create_info(
ch_names=target_ch_names,
sfreq=sfreq,
ch_types=["eeg"] * len(target_ch_names),
)
info_interp.set_montage(sensors)
info_interp["bads"] = [ch for ch in self.info["bads"] if ch in target_ch_names]
# Do not assign "projs" directly.

# Compute the interpolation mapping
if method == "spline":
origin_val = _check_origin(origin, self.info)
pos_from = self.info._get_channel_positions(picks_good_eeg) - origin_val
pos_to = np.stack(list(ch_pos.values()), axis=0)

def _check_pos_sphere(pos):
d = np.linalg.norm(pos, axis=-1)
d_norm = np.mean(d / np.mean(d))
if np.abs(1.0 - d_norm) > 0.1:
warn("Your spherical fit is poor; interpolation may be inaccurate.")

_check_pos_sphere(pos_from)
_check_pos_sphere(pos_to)
mapping = _make_interpolation_matrix(pos_from, pos_to, alpha=reg)

else:
assert method == "MNE"
info_eeg = pick_info(self.info, picks_good_eeg)
# If the original info has an average EEG reference projector but
# the destination info does not,
# update info_interp via a temporary RawArray.
if _has_eeg_average_ref_proj(self.info) and not _has_eeg_average_ref_proj(
info_interp
):
# Create dummy data: shape (n_channels, 1)
temp_data = np.zeros((len(info_interp["ch_names"]), 1))
temp_raw = RawArray(temp_data, info_interp, first_samp=0)
# Using the public API, add an average reference projector.
temp_raw.set_eeg_reference(
ref_channels="average", projection=True, verbose=False
)
# Extract the updated info.
info_interp = temp_raw.info
mapping = _map_meg_or_eeg_channels(
info_eeg, info_interp, mode="accurate", origin=origin
)

# Interpolate EEG data
data_good = self.get_data(picks=picks_good_eeg)
data_interp = mapping @ data_good

# Create a new instance for the interpolated EEG channels
# TODO: Creating a new instance leads to a loss of information.
# We should consider updating the existing instance in the future
# by 1) drop channels, 2) add channels, 3) re-order channels.
if isinstance(self, BaseRaw):
inst_interp = RawArray(data_interp, info_interp, first_samp=self.first_samp)
elif isinstance(self, BaseEpochs):
inst_interp = EpochsArray(data_interp, info_interp)
else:
assert isinstance(self, Evoked)
inst_interp = EvokedArray(data_interp, info_interp)

# Merge only if non-EEG channels exist
if not non_eeg_names_ordered:
return inst_interp

inst_non_eeg = self.copy().pick(non_eeg_names_ordered).load_data()
inst_out = inst_non_eeg.add_channels([inst_interp], force_update_info=True)

# Reorder channels
# Insert the entire new EEG block at the position of the first EEG channel.
orig_names_arr = np.array(orig_names)
mask_eeg = np.isin(orig_names_arr, eeg_names_orig)
if mask_eeg.any():
first_eeg_index = np.where(mask_eeg)[0][0]
pre = orig_names_arr[:first_eeg_index]
new_eeg = np.array(info_interp["ch_names"])
post = orig_names_arr[first_eeg_index:]
post = post[~np.isin(orig_names_arr[first_eeg_index:], eeg_names_orig)]
new_order = np.concatenate((pre, new_eeg, post)).tolist()
else:
new_order = orig_names
inst_out.reorder_channels(new_order)
return inst_out


@verbose
def rename_channels(info, mapping, allow_duplicates=False, *, verbose=None):
Expand Down
84 changes: 83 additions & 1 deletion mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mne import Epochs, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage
from mne.channels import make_dig_montage, make_standard_montage
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
Expand Down Expand Up @@ -439,3 +439,85 @@ def test_method_str():
raw.interpolate_bads(method="spline")
raw.pick("eeg", exclude=())
raw.interpolate_bads(method="spline")


@pytest.mark.parametrize("montage_name", ["biosemi16", "standard_1020"])
@pytest.mark.parametrize("method", ["spline", "MNE"])
@pytest.mark.parametrize("data_type", ["raw", "epochs", "evoked"])
def test_interpolate_to_eeg(montage_name, method, data_type):
"""Test the interpolate_to method for EEG for raw, epochs, and evoked."""
# Load EEG data
raw, epochs_eeg = _load_data("eeg")
epochs_eeg = epochs_eeg.copy()

# Load data for raw
raw.load_data()

# Create a target montage
montage = make_standard_montage(montage_name)

# Prepare data to interpolate to
if data_type == "raw":
inst = raw.copy()
elif data_type == "epochs":
inst = epochs_eeg.copy()
elif data_type == "evoked":
inst = epochs_eeg.average()
shape = list(inst._data.shape)
orig_total = len(inst.info["ch_names"])
n_eeg_orig = len(pick_types(inst.info, eeg=True))

# Assert first and last channels are not EEG
if data_type == "raw":
ch_types = inst.get_channel_types()
assert ch_types[0] != "eeg"
assert ch_types[-1] != "eeg"

# Record the names and data of the first and last channels.
if data_type == "raw":
first_name = inst.info["ch_names"][0]
last_name = inst.info["ch_names"][-1]
data_first = inst._data[..., 0, :].copy()
data_last = inst._data[..., -1, :].copy()

# Interpolate the EEG channels.
inst_interp = inst.copy().interpolate_to(montage, method=method)

# Check that the new channel names include the montage channels.
assert set(montage.ch_names).issubset(set(inst_interp.info["ch_names"]))
# Check that the overall channel order is changed.
assert inst.info["ch_names"] != inst_interp.info["ch_names"]

# Check that the data shape is as expected.
new_nchan_expected = orig_total - n_eeg_orig + len(montage.ch_names)
expected_shape = (new_nchan_expected, shape[-1])
if len(shape) == 3:
expected_shape = (shape[0],) + expected_shape
assert inst_interp._data.shape == expected_shape

# Verify that the first and last channels retain their positions.
if data_type == "raw":
assert inst_interp.info["ch_names"][0] == first_name
assert inst_interp.info["ch_names"][-1] == last_name

# Verify that the data for the first and last channels is unchanged.
if data_type == "raw":
np.testing.assert_allclose(
inst_interp._data[..., 0, :],
data_first,
err_msg="Data for the first non-EEG channel has changed.",
)
np.testing.assert_allclose(
inst_interp._data[..., -1, :],
data_last,
err_msg="Data for the last non-EEG channel has changed.",
)

# Validate that bad channels are carried over.
# Mark the first non eeg channel as bad
all_ch = inst_interp.info["ch_names"]
eeg_ch = [all_ch[i] for i in pick_types(inst_interp.info, eeg=True)]
bads = [ch for ch in all_ch if ch not in eeg_ch][:1]
inst.info["bads"] = bads
inst_interp = inst.copy().interpolate_to(montage, method=method)
assert inst_interp.info["bads"] == bads

0 comments on commit 7cceed6

Please sign in to comment.