Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into spectrogram-class
Browse files Browse the repository at this point in the history
  • Loading branch information
drammock committed Mar 20, 2024
2 parents 5ea9cc7 + cf6e13c commit 714c14b
Show file tree
Hide file tree
Showing 99 changed files with 664 additions and 723 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ concurrency:
cancel-in-progress: true
on: # yamllint disable-line rule:truthy
push:
branches:
- '*'
branches: ["main", "maint/*"]
pull_request:
branches:
- '*'
branches: ["main", "maint/*"]

permissions:
contents: read
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
# Ruff mne
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.2
rev: v0.3.3
hooks:
- id: ruff
name: ruff lint mne
Expand Down
4 changes: 3 additions & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ stages:
variables:
MNE_LOGGING_LEVEL: 'warning'
MNE_FORCE_SERIAL: 'true'
OPENBLAS_NUM_THREADS: 2
OPENBLAS_NUM_THREADS: '1' # deal with OpenBLAS conflicts safely on Windows
OMP_DYNAMIC: 'false'
PYTHONUNBUFFERED: 1
PYTHONIOENCODING: 'utf-8'
Expand Down Expand Up @@ -274,6 +274,8 @@ stages:
displayName: 'Print config'
- script: python -c "import numpy; numpy.show_config()"
displayName: Print NumPy config
- script: python -c "import numpy; import scipy.linalg; import sklearn.neighbors; from threadpoolctl import threadpool_info; from pprint import pprint; pprint(threadpool_info())"
displayName: Print threadpoolctl info
- bash: source tools/get_testing_version.sh
displayName: 'Get testing version'
- task: Cache@2
Expand Down
2 changes: 2 additions & 0 deletions doc/changes/devel/12464.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Replacing percent format with f-strings format specifiers , by :newcontrib:`Hasrat Ali Arzoo`.

5 changes: 5 additions & 0 deletions doc/changes/devel/12507.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Fix bug where using ``phase="minimum"`` in filtering functions like
:meth:`mne.io.Raw.filter` constructed a filter half the desired length with
compromised attenuation. Now ``phase="minimum"`` has the same length and comparable
suppression as ``phase="zero"``, and the old (incorrect) behavior can be achieved
with ``phase="minimum-half"``, by `Eric Larson`_.
2 changes: 2 additions & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@

.. _Hari Bharadwaj: https://github.com/haribharadwaj

.. _Hasrat Ali Arzoo: https://github.com/hasrat17

.. _Henrich Kolkhorst: https://github.com/hekolk

.. _Hongjiang Ye: https://github.com/hongjiang-ye
Expand Down
4 changes: 3 additions & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,9 @@ def append_attr_meth_examples(app, what, name, obj, options, lines):
if what in ("attribute", "method"):
size = os.path.getsize(
os.path.join(
os.path.dirname(__file__), "generated", "%s.examples" % (name,)
os.path.dirname(__file__),
"generated",
f"{name}.examples",
)
)
if size > 0:
Expand Down
14 changes: 5 additions & 9 deletions doc/sphinxext/flow_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,14 @@
sensor_color = "#7bbeca"
source_color = "#ff6347"

legend = """
<<FONT POINT-SIZE="%s">
legend = f"""
<<FONT POINT-SIZE="{edge_size}">
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="4" CELLPADDING="4">
<TR><TD BGCOLOR="%s"> </TD><TD ALIGN="left">
<TR><TD BGCOLOR="{sensor_color}"> </TD><TD ALIGN="left">
Sensor (M/EEG) space</TD></TR>
<TR><TD BGCOLOR="%s"> </TD><TD ALIGN="left">
<TR><TD BGCOLOR="{source_color}"> </TD><TD ALIGN="left">
Source (brain) space</TD></TR>
</TABLE></FONT>>""" % (
edge_size,
sensor_color,
source_color,
)
</TABLE></FONT>>"""
legend = "".join(legend.split("\n"))

nodes = dict(
Expand Down
12 changes: 4 additions & 8 deletions doc/sphinxext/mne_substitutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@ def run(self, **kwargs): # noqa: D102
):
keys.append(key)
rst = "- " + "\n- ".join(
"``%r``: **%s** (scaled by %g to plot in *%s*)"
% (
key,
DEFAULTS["titles"][key],
DEFAULTS["scalings"][key],
DEFAULTS["units"][key],
)
f"``{repr(key)}``: **{DEFAULTS['titles'][key]}** "
f"(scaled by {DEFAULTS['scalings'][key]} to "
f"plot in *{DEFAULTS['units'][key]}*)"
for key in keys
)
else:
raise self.error(
"MNE directive unknown in %s: %r"
"MNE directive unknown in %s: %r" # noqa: UP031
% (
env.doc2path(env.docname, base=None),
self.arguments[0],
Expand Down
4 changes: 2 additions & 2 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3995,8 +3995,8 @@ def _is_good(
bad_names = [ch_names[idx[i]] for i in idx_deltas]
if not has_printed:
logger.info(
" Rejecting %s epoch based on %s : "
"%s" % (t, name, bad_names)
f" Rejecting {t} epoch based on {name} : "
f"{bad_names}"
)
has_printed = True
if not full_report:
Expand Down
73 changes: 8 additions & 65 deletions mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_setup_cuda_fft_resample,
_smart_pad,
)
from .fixes import minimum_phase
from .parallel import parallel_func
from .utils import (
_check_option,
Expand Down Expand Up @@ -307,39 +308,7 @@ def _overlap_add_filter(
copy=True,
pad="reflect_limited",
):
"""Filter the signal x using h with overlap-add FFTs.
Parameters
----------
x : array, shape (n_signals, n_times)
Signals to filter.
h : 1d array
Filter impulse response (FIR filter coefficients). Must be odd length
if ``phase='linear'``.
n_fft : int
Length of the FFT. If None, the best size is determined automatically.
phase : str
If ``'zero'``, the delay for the filter is compensated (and it must be
an odd-length symmetric filter). If ``'linear'``, the response is
uncompensated. If ``'zero-double'``, the filter is applied in the
forward and reverse directions. If 'minimum', a minimum-phase
filter will be used.
picks : list | None
See calling functions.
n_jobs : int | str
Number of jobs to run in parallel. Can be ``'cuda'`` if ``cupy``
is installed properly.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
pad : str
Padding type for ``_smart_pad``.
Returns
-------
x : array, shape (n_signals, n_times)
x filtered.
"""
"""Filter the signal x using h with overlap-add FFTs."""
# set up array for filtering, reshape to 2D, operate on last axis
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
# Extend the signal by mirroring the edges to reduce transient filter
Expand Down Expand Up @@ -526,34 +495,6 @@ def _construct_fir_filter(
(windowing is a smoothing in frequency domain).
If x is multi-dimensional, this operates along the last dimension.
Parameters
----------
sfreq : float
Sampling rate in Hz.
freq : 1d array
Frequency sampling points in Hz.
gain : 1d array
Filter gain at frequency sampling points.
Must be all 0 and 1 for fir_design=="firwin".
filter_length : int
Length of the filter to use. Must be odd length if phase == "zero".
phase : str
If 'zero', the delay for the filter is compensated (and it must be
an odd-length symmetric filter). If 'linear', the response is
uncompensated. If 'zero-double', the filter is applied in the
forward and reverse directions. If 'minimum', a minimum-phase
filter will be used.
fir_window : str
The window to use in FIR design, can be "hamming" (default),
"hann", or "blackman".
fir_design : str
Can be "firwin2" or "firwin".
Returns
-------
h : array
Filter coefficients.
"""
assert freq[0] == 0
if fir_design == "firwin2":
Expand All @@ -562,7 +503,7 @@ def _construct_fir_filter(
assert fir_design == "firwin"
fir_design = partial(_firwin_design, sfreq=sfreq)
# issue a warning if attenuation is less than this
min_att_db = 12 if phase == "minimum" else 20
min_att_db = 12 if phase == "minimum-half" else 20

# normalize frequencies
freq = np.array(freq) / (sfreq / 2.0)
Expand All @@ -575,11 +516,13 @@ def _construct_fir_filter(
# Use overlap-add filter with a fixed length
N = _check_zero_phase_length(filter_length, phase, gain[-1])
# construct symmetric (linear phase) filter
if phase == "minimum":
if phase == "minimum-half":
h = fir_design(N * 2 - 1, freq, gain, window=fir_window)
h = signal.minimum_phase(h)
h = minimum_phase(h)
else:
h = fir_design(N, freq, gain, window=fir_window)
if phase == "minimum":
h = minimum_phase(h, half=False)
assert h.size == N
att_db, att_freq = _filter_attenuation(h, freq, gain)
if phase == "zero-double":
Expand Down Expand Up @@ -2162,7 +2105,7 @@ def detrend(x, order=1, axis=-1):
"blackman": dict(name="Blackman", ripple=0.0017, attenuation=74),
}
_known_fir_windows = tuple(sorted(_fir_window_dict.keys()))
_known_phases_fir = ("linear", "zero", "zero-double", "minimum")
_known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half")
_known_phases_iir = ("zero", "zero-double", "forward")
_known_fir_designs = ("firwin", "firwin2")
_fir_design_dict = {
Expand Down
71 changes: 63 additions & 8 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _safe_svd(A, **kwargs):
except np.linalg.LinAlgError as exp:
from .utils import warn

warn("SVD error (%s), attempting to use GESVD instead of GESDD" % (exp,))
warn(f"SVD error ({exp}), attempting to use GESVD instead of GESDD")
return linalg.svd(A, lapack_driver="gesvd", **kwargs)


Expand Down Expand Up @@ -192,8 +192,8 @@ def _get_param_names(cls):
"scikit-learn estimators should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
" %s with constructor %s doesn't "
" follow this convention." % (cls, init_signature)
f" {cls} with constructor {init_signature} doesn't "
" follow this convention."
)
# Extract and sort argument names excluding 'self'
return sorted([p.name for p in parameters])
Expand Down Expand Up @@ -264,20 +264,20 @@ def set_params(self, **params):
name, sub_name = split
if name not in valid_params:
raise ValueError(
"Invalid parameter %s for estimator %s. "
f"Invalid parameter {name} for estimator {self}. "
"Check the list of available parameters "
"with `estimator.get_params().keys()`." % (name, self)
"with `estimator.get_params().keys()`."
)
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})
else:
# simple objects case
if key not in valid_params:
raise ValueError(
"Invalid parameter %s for estimator %s. "
f"Invalid parameter {key} for estimator "
f"{self.__class__.__name__}. "
"Check the list of available parameters "
"with `estimator.get_params().keys()`."
% (key, self.__class__.__name__)
)
setattr(self, key, value)
return self
Expand All @@ -287,7 +287,7 @@ def __repr__(self): # noqa: D105
pprint(self.get_params(deep=False), params)
params.seek(0)
class_name = self.__class__.__name__
return "%s(%s)" % (class_name, params.read().strip())
return f"{class_name}({params.read().strip()})"

# __getstate__ and __setstate__ are omitted because they only contain
# conditionals that are not satisfied by our objects (e.g.,
Expand Down Expand Up @@ -889,3 +889,58 @@ def _numpy_h5py_dep():
"ignore", "`product` is deprecated.*", DeprecationWarning
)
yield


def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True):
"""Wrap scipy.signal.minimum_phase with half option."""
# Can be removed once
from scipy.fft import fft, ifft
from scipy.signal import minimum_phase as sp_minimum_phase

assert isinstance(method, str) and method == "homomorphic"

if "half" in inspect.getfullargspec(sp_minimum_phase).kwonlyargs:
return sp_minimum_phase(h, method=method, n_fft=n_fft, half=half)
h = np.asarray(h)
if np.iscomplexobj(h):
raise ValueError("Complex filters not supported")
if h.ndim != 1 or h.size <= 2:
raise ValueError("h must be 1-D and at least 2 samples long")
n_half = len(h) // 2
if not np.allclose(h[-n_half:][::-1], h[:n_half]):
warnings.warn(
"h does not appear to by symmetric, conversion may fail",
RuntimeWarning,
stacklevel=2,
)
if n_fft is None:
n_fft = 2 ** int(np.ceil(np.log2(2 * (len(h) - 1) / 0.01)))
n_fft = int(n_fft)
if n_fft < len(h):
raise ValueError("n_fft must be at least len(h)==%s" % len(h))

# zero-pad; calculate the DFT
h_temp = np.abs(fft(h, n_fft))
# take 0.25*log(|H|**2) = 0.5*log(|H|)
h_temp += 1e-7 * h_temp[h_temp > 0].min() # don't let log blow up
np.log(h_temp, out=h_temp)
if half: # halving of magnitude spectrum optional
h_temp *= 0.5
# IDFT
h_temp = ifft(h_temp).real
# multiply pointwise by the homomorphic filter
# lmin[n] = 2u[n] - d[n]
# i.e., double the positive frequencies and zero out the negative ones;
# Oppenheim+Shafer 3rd ed p991 eq13.42b and p1004 fig13.7
win = np.zeros(n_fft)
win[0] = 1
stop = n_fft // 2
win[1:stop] = 2
if n_fft % 2:
win[stop] = 1
h_temp *= win
h_temp = ifft(np.exp(fft(h_temp)))
h_minimum = h_temp.real

n_out = (n_half + len(h) % 2) if half else len(h)
return h_minimum[:n_out]
2 changes: 1 addition & 1 deletion mne/forward/_compute_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def _magnetic_dipole_field_vec(rrs, coils, too_close="raise"):
rmags, cosmags, ws, bins = _triage_coils(coils)
fwd, min_dist = _compute_mdfv(rrs, rmags, cosmags, ws, bins, too_close)
if min_dist < _MIN_DIST_LIMIT:
msg = "Coil too close (dist = %g mm)" % (min_dist * 1000,)
msg = f"Coil too close (dist = {min_dist * 1000:g} mm)"
if too_close == "raise":
raise RuntimeError(msg)
func = warn if too_close == "warning" else logger.info
Expand Down
4 changes: 2 additions & 2 deletions mne/forward/_lead_dots.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def _get_legen_table(
# Updated due to API change (GH 1167)
os.makedirs(fname)
if ch_type == "meg":
fname = op.join(fname, "legder_%s_%s.bin" % (n_coeff, n_interp))
fname = op.join(fname, f"legder_{n_coeff}_{n_interp}.bin")
leg_fun = _get_legen_der
extra_str = " derivative"
lut_shape = (n_interp + 1, n_coeff, 3)
else: # 'eeg'
fname = op.join(fname, "legval_%s_%s.bin" % (n_coeff, n_interp))
fname = op.join(fname, f"legval_{n_coeff}_{n_interp}.bin")
leg_fun = _get_legen
extra_str = ""
lut_shape = (n_interp + 1, n_coeff)
Expand Down
Loading

0 comments on commit 714c14b

Please sign in to comment.