Skip to content

Commit

Permalink
Merge pull request #84 from astrofrog/path-args
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog authored Sep 5, 2024
2 parents f01a852 + c5c51ff commit 6fa2009
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 24 deletions.
26 changes: 14 additions & 12 deletions extension_helpers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from importlib import machinery as import_machinery
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

__all__ = ["write_if_different", "import_file"]

Expand Down Expand Up @@ -86,23 +87,23 @@ def write_if_different(filename, data):
Parameters
----------
filename : str
filename : str or `pathlib.Path`
The file name to be written to.
data : bytes
The data to be written to ``filename``.
"""

filepath = Path(filename)

assert isinstance(data, bytes)

if os.path.exists(filename):
with open(filename, "rb") as fd:
original_data = fd.read()
if filepath.exists():
original_data = filepath.read_bytes()
else:
original_data = None

if original_data != data:
with open(filename, "wb") as fd:
fd.write(data)
filepath.write_bytes(data)


def import_file(filename, name=None):
Expand All @@ -123,15 +124,16 @@ def import_file(filename, name=None):
# be unique, and it doesn't really matter because the name isn't
# used directly here anyway.

filepath = Path(filename)

if name is None:
basename = os.path.splitext(filename)[0]
name = "_".join(os.path.abspath(basename).split(os.sep)[1:])
name = "_".join(filepath.resolve().with_suffix("").parts[1:])

if not os.path.exists(filename):
raise ImportError(f"Could not import file {filename}")
if not filepath.exists():
raise ImportError(f"Could not import file {filepath}")

loader = import_machinery.SourceFileLoader(name, filename)
spec = spec_from_file_location(name, filename)
loader = import_machinery.SourceFileLoader(name, str(filepath))
spec = spec_from_file_location(name, str(filepath))
mod = module_from_spec(spec)
loader.exec_module(mod)

Expand Down
32 changes: 20 additions & 12 deletions extension_helpers/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
import os
import time

import pytest

from .._utils import import_file, write_if_different


def test_import_file(tmp_path):
filename = str(tmp_path / "spam.py")
with open(filename, "w") as f:
@pytest.mark.parametrize("path_type", ("str", "path"))
def test_import_file(tmp_path, path_type):
filepath = tmp_path / "spam.py"
if path_type == "str":
filepath = str(filepath)
with open(filepath, "w") as f:
f.write("magic = 12345")
module = import_file(filename)
module = import_file(filepath)
assert module.magic == 12345


def test_write_if_different(tmp_path):
filename = str(tmp_path / "test.txt")
write_if_different(filename, b"abc")
time1 = os.path.getmtime(filename)
@pytest.mark.parametrize("path_type", ("str", "path"))
def test_write_if_different(tmp_path, path_type):
filepath = tmp_path / "test.txt"
if path_type == "str":
filepath = str(filepath)
write_if_different(filepath, b"abc")
time1 = os.path.getmtime(filepath)
time.sleep(0.01)
write_if_different(filename, b"abc")
time2 = os.path.getmtime(filename)
write_if_different(filepath, b"abc")
time2 = os.path.getmtime(filepath)
assert time2 == time1
time.sleep(0.01)
write_if_different(filename, b"abcd")
time3 = os.path.getmtime(filename)
write_if_different(filepath, b"abcd")
time3 = os.path.getmtime(filepath)
assert time3 > time1

0 comments on commit 6fa2009

Please sign in to comment.