diff --git a/extension_helpers/_utils.py b/extension_helpers/_utils.py index f618baa..42d1632 100644 --- a/extension_helpers/_utils.py +++ b/extension_helpers/_utils.py @@ -4,6 +4,7 @@ import sys from importlib import machinery as import_machinery from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path __all__ = ["write_if_different", "import_file"] @@ -86,23 +87,23 @@ def write_if_different(filename, data): Parameters ---------- - filename : str + filename : str or `pathlib.Path` The file name to be written to. data : bytes The data to be written to ``filename``. """ + filepath = Path(filename) + assert isinstance(data, bytes) - if os.path.exists(filename): - with open(filename, "rb") as fd: - original_data = fd.read() + if filepath.exists(): + original_data = filepath.read_bytes() else: original_data = None if original_data != data: - with open(filename, "wb") as fd: - fd.write(data) + filepath.write_bytes(data) def import_file(filename, name=None): @@ -123,15 +124,16 @@ def import_file(filename, name=None): # be unique, and it doesn't really matter because the name isn't # used directly here anyway. + filepath = Path(filename) + if name is None: - basename = os.path.splitext(filename)[0] - name = "_".join(os.path.abspath(basename).split(os.sep)[1:]) + name = "_".join(filepath.resolve().with_suffix("").parts[1:]) - if not os.path.exists(filename): - raise ImportError(f"Could not import file {filename}") + if not filepath.exists(): + raise ImportError(f"Could not import file {filepath}") - loader = import_machinery.SourceFileLoader(name, filename) - spec = spec_from_file_location(name, filename) + loader = import_machinery.SourceFileLoader(name, str(filepath)) + spec = spec_from_file_location(name, str(filepath)) mod = module_from_spec(spec) loader.exec_module(mod) diff --git a/extension_helpers/tests/test_utils.py b/extension_helpers/tests/test_utils.py index 983edc0..6d9aa2e 100644 --- a/extension_helpers/tests/test_utils.py +++ b/extension_helpers/tests/test_utils.py @@ -1,26 +1,34 @@ import os import time +import pytest + from .._utils import import_file, write_if_different -def test_import_file(tmp_path): - filename = str(tmp_path / "spam.py") - with open(filename, "w") as f: +@pytest.mark.parametrize("path_type", ("str", "path")) +def test_import_file(tmp_path, path_type): + filepath = tmp_path / "spam.py" + if path_type == "str": + filepath = str(filepath) + with open(filepath, "w") as f: f.write("magic = 12345") - module = import_file(filename) + module = import_file(filepath) assert module.magic == 12345 -def test_write_if_different(tmp_path): - filename = str(tmp_path / "test.txt") - write_if_different(filename, b"abc") - time1 = os.path.getmtime(filename) +@pytest.mark.parametrize("path_type", ("str", "path")) +def test_write_if_different(tmp_path, path_type): + filepath = tmp_path / "test.txt" + if path_type == "str": + filepath = str(filepath) + write_if_different(filepath, b"abc") + time1 = os.path.getmtime(filepath) time.sleep(0.01) - write_if_different(filename, b"abc") - time2 = os.path.getmtime(filename) + write_if_different(filepath, b"abc") + time2 = os.path.getmtime(filepath) assert time2 == time1 time.sleep(0.01) - write_if_different(filename, b"abcd") - time3 = os.path.getmtime(filename) + write_if_different(filepath, b"abcd") + time3 = os.path.getmtime(filepath) assert time3 > time1