Skip to content

Commit

Permalink
sketch of compute_tfr() API [ci skip]
Browse files Browse the repository at this point in the history
refactor method_kw checking

cleaner imports

refactor _get_instance_type_string

WIP

many changes; raw.compute_tfr(multitaper, freqs) works!

add verbose to tfr_array_stockwell

add get_data() method

DRY / fixes for method=stockwell

partially handle ITC; comments

WIP [ci skip]
  • Loading branch information
drammock committed Oct 31, 2023
1 parent 3e0e543 commit db5fd57
Show file tree
Hide file tree
Showing 13 changed files with 1,501 additions and 69 deletions.
58 changes: 58 additions & 0 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from .fixes import rng_uniform
from .html_templates import _get_html_template
from .parallel import parallel_func
from .time_frequency.spectrogram import EpochsTFR
from .time_frequency.spectrum import EpochsSpectrum, SpectrumMixin, _validate_method
from .utils import (
ExtendedTimeMixin,
Expand Down Expand Up @@ -2419,6 +2420,63 @@ def compute_psd(
**method_kw,
)

@verbose
def compute_tfr(
self,
method,
freqs,
*,
picks=None,
proj=False,
return_itc=False,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
"""Compute a time-frequency representation of epoched data.
Parameters
----------
%(method_tfr)s
%(freqs_tfr)s
%(picks_good_data_noref)s
%(proj_psd)s
return_itc : bool
Whether to return inter-trial coherence (ITC) as well as power estimates.
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Returns
-------
tfr : instance of EpochsTFR
The time-frequency-resolved power estimates of the data.
itc : instance of EpochsTFR
The inter-trial coherence (ITC). Only returned if ``return_itc=True``.
Notes
-----
.. versionadded:: 1.6
References
----------
.. footbibliography::
"""
return EpochsTFR(
self,
method=method,
freqs=freqs,
picks=picks,
proj=proj,
return_itc=return_itc,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

@verbose
def plot_psd(
self,
Expand Down
44 changes: 44 additions & 0 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,50 @@ def compute_psd(
**method_kw,
)

@verbose
def compute_tfr(
self,
method,
freqs,
*,
tmin=None,
tmax=None,
picks=None,
proj=False,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
"""Compute a time-frequency representation of evoked data.
Parameters
----------
%(method_tfr)s
%(freqs_tfr)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Returns
-------
tfr : instance of AverageTFR
The time-frequency-resolved power estimates of the data.
Notes
-----
.. versionadded:: 1.3
References
----------
.. footbibliography::
"""
pass

@verbose
def plot_psd(
self,
Expand Down
54 changes: 54 additions & 0 deletions mne/html_templates/repr/tfr.html.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<table class="table table-hover table-striped table-sm table-responsive small">
<tr>
<th>Data type</th>
<td>{{ tfr._data_type }}</td>
</tr>
{%- for unit in units %}
<tr>
{%- if loop.index == 1 %}
<th rowspan={{ units | length }}>Units</th>
{%- endif %}
<td class="justify">{{ unit }}</td>
</tr>
{%- endfor %}
<tr>
<th>Data source</th>
<td>{{ inst_type }}</td>
</tr>
{%- if inst_type == "Epochs" %}
<tr>
<th>Number of epochs</th>
<td>{{ tfr.shape[0] }}</td>
</tr>
{% endif -%}
<tr>
<th>Dims</th>
<td>{{ tfr._dims | join(", ") }}</td>
</tr>
<tr>
<th>Estimation method</th>
<td>{{ tfr.method }}</td>
</tr>
{% if "taper" in tfr._dims %}
<tr>
<th>Number of tapers</th>
<td>{{ tfr._mt_weights.size }}</td>
</tr>
{% endif %}
<tr>
<th>Number of channels</th>
<td>{{ tfr.ch_names|length }}</td>
</tr>
<tr>
<th>Number of timepoints</th>
<td>{{ tfr.times|length }}</td>
</tr>
<tr>
<th>Number of frequency bins</th>
<td>{{ tfr.freqs|length }}</td>
</tr>
<tr>
<th>Frequency range</th>
<td>{{ '%.2f'|format(tfr.freqs[0]) }} – {{ '%.2f'|format(tfr.freqs[-1]) }} Hz</td>
</tr>
</table>
60 changes: 60 additions & 0 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
)
from ..html_templates import _get_html_template
from ..parallel import parallel_func
from ..time_frequency.spectrogram import RawTFR
from ..time_frequency.spectrum import Spectrum, SpectrumMixin, _validate_method
from ..utils import (
SizeMixin,
Expand Down Expand Up @@ -2195,6 +2196,65 @@ def compute_psd(
**method_kw,
)

@verbose
def compute_tfr(
self,
method,
freqs,
*,
tmin=None,
tmax=None,
picks=None,
proj=False,
reject_by_annotation=True,
decim=1,
n_jobs=None,
verbose=None,
**method_kw,
):
"""Compute a time-frequency representation of sensor data.
Parameters
----------
%(method_tfr)s
%(freqs_tfr)s
%(tmin_tmax_psd)s
%(picks_good_data_noref)s
%(proj_psd)s
%(reject_by_annotation_tfr)s
%(decim_tfr)s
%(n_jobs)s
%(verbose)s
%(method_kw_tfr)s
Returns
-------
tfr : instance of RawTFR
The time-frequency-resolved power estimates of the data.
Notes
-----
.. versionadded:: 1.3
References
----------
.. footbibliography::
"""
return RawTFR(
self,
method=method,
freqs=freqs,
tmin=tmin,
tmax=tmax,
picks=picks,
proj=proj,
reject_by_annotation=reject_by_annotation,
decim=decim,
n_jobs=n_jobs,
verbose=verbose,
**method_kw,
)

@verbose
def to_data_frame(
self,
Expand Down
34 changes: 21 additions & 13 deletions mne/time_frequency/_stockwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .._fiff.pick import _pick_data_channels, pick_info
from ..parallel import parallel_func
from ..utils import _validate_type, fill_doc, logger, verbose
from ..utils import _validate_type, logger, verbose
from .tfr import AverageTFR, _get_data


Expand Down Expand Up @@ -104,7 +104,22 @@ def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
return psd, itc


@fill_doc
def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
from scipy.fft import fftfreq

freqs = fftfreq(n_fft, 1.0 / sfreq)
if fmin is None:
fmin = freqs[freqs > 0][0]
if fmax is None:
fmax = freqs.max()

start_f = np.abs(freqs - fmin).argmin()
stop_f = np.abs(freqs - fmax).argmin()
freqs = freqs[start_f:stop_f]
return start_f, stop_f, freqs


@verbose
def tfr_array_stockwell(
data,
sfreq,
Expand All @@ -115,6 +130,7 @@ def tfr_array_stockwell(
decim=1,
return_itc=False,
n_jobs=None,
verbose=None,
):
"""Compute power and intertrial coherence using Stockwell (S) transform.
Expand Down Expand Up @@ -147,6 +163,7 @@ def tfr_array_stockwell(
return_itc : bool
Return intertrial coherence (ITC) as well as averaged power.
%(n_jobs)s
%(verbose)s
Returns
-------
Expand Down Expand Up @@ -179,23 +196,14 @@ def tfr_array_stockwell(
n_epochs, n_channels = data.shape[:2]
n_out = data.shape[2] // decim + bool(data.shape[-1] % decim)
data, n_fft_, zero_pad = _check_input_st(data, n_fft)

freqs = fftfreq(n_fft_, 1.0 / sfreq)
if fmin is None:
fmin = freqs[freqs > 0][0]
if fmax is None:
fmax = freqs.max()

start_f = np.abs(freqs - fmin).argmin()
stop_f = np.abs(freqs - fmax).argmin()
freqs = freqs[start_f:stop_f]
start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)

W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
n_freq = stop_f - start_f
psd = np.empty((n_channels, n_freq, n_out))
itc = np.empty((n_channels, n_freq, n_out)) if return_itc else None

parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs)
parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
tfrs = parallel(
my_st(data[:, c, :], start_f, return_itc, zero_pad, decim, W)
for c in range(n_channels)
Expand Down
7 changes: 7 additions & 0 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def tfr_array_multitaper(
output="complex",
n_jobs=None,
*,
return_mt_weights=False,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
Expand Down Expand Up @@ -506,6 +507,11 @@ def tfr_array_multitaper(
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.
%(n_jobs)s
The number of epochs to process at the same time. The parallelization
is implemented across channels. Defaults to 1.
return_mt_weights : bool
Whether to return taper weights alongside the complex taper
coefficients. Ignored if output is not ``'complex'``.
%(verbose)s
Returns
Expand Down Expand Up @@ -553,5 +559,6 @@ def tfr_array_multitaper(
decim=decim,
output=output,
n_jobs=n_jobs,
return_mt_weights=return_mt_weights,
verbose=verbose,
)
Loading

0 comments on commit db5fd57

Please sign in to comment.