Skip to content

Commit

Permalink
Make compatible with n-dims
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Feb 6, 2025
1 parent afd69af commit a395d76
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4291,19 +4291,20 @@ def _tfr_from_mt(x_mt, weights):
Parameters
----------
x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times)
x_mt : array, shape (..., n_tapers, n_freqs, n_times)
The complex-valued multitaper coefficients.
weights : array, shape (n_tapers, n_freqs)
The weights to use to combine the tapered estimates.
Returns
-------
tfr : array, shape (n_channels, n_freqs, n_times)
tfr : array, shape (..., n_freqs, n_times)
The time-frequency power estimates.
"""
weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims
# add singleton dim for time and any dims preceding the tapers
weights = np.expand_dims(weights, axis=(*range(x_mt.ndim - 3), -1))
tfr = weights * x_mt
tfr *= tfr.conj()
tfr = tfr.real.sum(axis=1)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1)
tfr = tfr.real.sum(axis=-3)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=-3)
return tfr

0 comments on commit a395d76

Please sign in to comment.