Skip to content

Commit

Permalink
MNT: Improving the set_download_dir to handle parallel access (#667)
Browse files Browse the repository at this point in the history
* FIX: paying old bills

* FIX: fixing unrelated test
  • Loading branch information
bruAristimunha authored Oct 18, 2024
1 parent 98d041a commit ef22069
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Bugs
- Fix Stieger2021 dataset bugs (:gh:`651` by `Martin Wimpff`_)
- Unpinning major version Scikit-learn and numpy (:gh:`652` by `Bruno Aristimunha`_)
- Replacing the func:`numpy.string_` to func:`numpy.bytes_` (:gh:`665` by `Bruno Aristimunha`_)
- Fixing the set_download_dir that was not working when we tried to set the dir more than 10 times at the same time (:gh:`668` by `Bruno Aristimunha`_)

API changes
~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion moabb/tests/util_braindecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_type_create_from_X_y_vs_transfomer(self, data):
transformer = BraindecodeDatasetLoader()
dataset_trans = transformer.fit(X=X_train, y=y_train).transform(X_train)
assert isinstance(dataset_trans, BaseConcatDataset)
assert isinstance(type(dataset_trans), type(dataset))
assert isinstance(dataset, BaseConcatDataset)

def test_wrong_input(self):
"""Test that an invalid input raises a ValueError."""
Expand Down
91 changes: 90 additions & 1 deletion moabb/tests/util_tests.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os.path as osp
import tempfile
import unittest
from unittest.mock import MagicMock, patch

import pytest
from mne import get_config
from joblib import Parallel, delayed
from mne import get_config, set_config

from moabb.datasets import utils
from moabb.utils import aliases_list, depreciated_alias, set_download_dir, setup_seed
Expand Down Expand Up @@ -209,5 +211,92 @@ def dummy_a(a, b=1):
self.assertEqual(dummy_b.__name__, "dummy_b") # noqa: F821


@pytest.fixture(autouse=True)
def reset_mne_config():
"""Fixture to reset MNE_DATA config before and after each test."""
original_config = get_config("MNE_DATA")
yield
if original_config is not None:
set_config("MNE_DATA", original_config, set_env=True)
else:
# Remove the config if it was not set originally
set_config("MNE_DATA", None, set_env=True)


def test_set_download_dir_none_not_set(capsys):
"""Test setting download directory to None when MNE_DATA is not set."""
# Ensure MNE_DATA is not set
set_config("MNE_DATA", None)

set_download_dir(None)

captured = capsys.readouterr()
expected_path = osp.join(osp.expanduser("~"), "mne_data")
assert "MNE_DATA is not already configured" in captured.out
assert "default location in the home directory" in captured.out
assert "mne_data" in captured.out

assert get_config("MNE_DATA") == expected_path


def test_set_download_dir_none_already_set(capsys):
"""Test setting download directory to None when MNE_DATA is already set."""
predefined_path = "/existing/mne_data_path"
set_config("MNE_DATA", predefined_path)

set_download_dir(None)

captured = capsys.readouterr()
# No print should occur since MNE_DATA is already set
assert captured.out == ""
assert get_config("MNE_DATA") == predefined_path


def test_set_download_dir_existing_path(capsys):
"""Test setting download directory to an existing path."""
with tempfile.TemporaryDirectory() as tmpdir:
set_download_dir(tmpdir)
captured = capsys.readouterr()
# No print should occur since the directory exists
assert captured.out == ""
assert get_config("MNE_DATA") == tmpdir


def test_set_download_dir_nonexistent_path(capsys):
"""Test setting download directory to a non-existent path."""
with tempfile.TemporaryDirectory() as tmpdir:
non_existent_path = osp.join(tmpdir, "new_mne_data")

# Ensure the path does not exist
assert not osp.exists(non_existent_path)

set_download_dir(non_existent_path)

captured = capsys.readouterr()
assert "The path given does not exist, creating it.." in captured.out
assert osp.isdir(non_existent_path)
assert get_config("MNE_DATA") == non_existent_path


@pytest.mark.parametrize("path_exists", [True, False])
def test_set_download_dir_parallel(path_exists, tmp_path, capsys):
"""Test setting download directory in parallel with joblib."""
if path_exists:
path = tmp_path / "existing_dir"
path.mkdir()
else:
path = tmp_path / "non_existing_dir"

def worker(p):
set_download_dir(p)
mne_data_value = get_config("MNE_DATA")
return mne_data_value

results = Parallel(n_jobs=10)(delayed(worker)(path) for _ in range(100))

for mne_data_value in results:
assert mne_data_value == str(path)


if __name__ == "__main__":
unittest.main()
86 changes: 64 additions & 22 deletions moabb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
from mne import get_config, set_config
from mne import set_log_level as sll
from mne.utils import get_config_path


if TYPE_CHECKING:
Expand Down Expand Up @@ -131,34 +132,75 @@ def set_log_level(level="INFO"):
)


# Cross-platform file-locking
if sys.platform.startswith("win"):
import msvcrt

def lock_file(f):
msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1)

def unlock_file(f):
msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)

else:
import fcntl

def lock_file(f):
fcntl.flock(f, fcntl.LOCK_EX)

def unlock_file(f):
fcntl.flock(f, fcntl.LOCK_UN)


def set_download_dir(path):
"""Set the download directory if required to change from default mne path.
"""Set the download directory if required to change from default MNE path.
Parameters
----------
path : None | str
The new storage location, if it does not exist, a warning is raised and the
path is created
If None, and MNE_DATA config does not exist, a warning is raised and the
storage location is set to the MNE default directory
The new storage location. If it does not exist, a warning is raised and the
path is created.
If None, and MNE_DATA config does not exist, a warning is raised and the
storage location is set to the MNE default directory.
"""
if path is None:
if get_config("MNE_DATA") is None:
print(
"MNE_DATA is not already configured. It will be set to "
"default location in the home directory - "
+ osp.join(osp.expanduser("~"), "mne_data")
+ "All datasets will be downloaded to this location, if anything is "
"already downloaded, please move manually to this location"
)

set_config("MNE_DATA", osp.join(osp.expanduser("~"), "mne_data"))
else:
# Check if the path exists, if not, create it
if not osp.isdir(path):
print("The path given does not exist, creating it..")
os.makedirs(path)
set_config("MNE_DATA", path)
config_path = get_config_path()
# Use the config file itself as the lock file
lock_file_path = config_path + ".lock"

# Ensure the config directory exists
config_dir = osp.dirname(config_path)
if not osp.exists(config_dir):
os.makedirs(config_dir)

# Open the lock file
with open(lock_file_path, "w") as lock_file_obj:
# Acquire the lock
lock_file(lock_file_obj)
try:
# Critical section: read and write config
if path is None:
if get_config("MNE_DATA") is None:
print(
"MNE_DATA is not already configured. It will be set to "
"default location in the home directory - "
+ osp.join(osp.expanduser("~"), "mne_data")
+ ". All datasets will be downloaded to this location. "
"If anything is already downloaded, please move manually to this location."
)
default_path = osp.join(osp.expanduser("~"), "mne_data")
set_config("MNE_DATA", default_path)
else:
# Create the directory if it doesn't exist
if not osp.isdir(path):
print("The path given does not exist, creating it..")
os.makedirs(path, exist_ok=True)
# Only set the config if it's different
current_mne_data = get_config("MNE_DATA")
if current_mne_data != path:
set_config("MNE_DATA", path)
finally:
# Release the lock
unlock_file(lock_file_obj)


def make_process_pipelines(
Expand Down

0 comments on commit ef22069

Please sign in to comment.