From aec99a2a55b71163a02366d9570a547d8bd7fd45 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 5 Jan 2025 13:58:55 +0100 Subject: [PATCH] Removed the requirement for a particular version of Python in pre-commit I've realised that this only works if you have that particular version of Python installed locally -- pre-commit doesn't download this for you. This unacceptably raises the bar for making contributions -- folks shouldn't have to modify their global system just to offer a PR against Equinox. Unfortunately this has meant that I've had to remove a couple of dependencies from the `additional_dependencies` list, as they don't support Python 3.13, which is what is sometimes selected. (Alternatives considered: - Install the specified version of Python as part of the pre-commit hook. This would be ideal. Unfortunately `pre-commit` passes the specified `language_version` to `python -m virtualenv` under the hood, and that doesn't seem to offer a way to do this. - Based on the above: something with `uv`? Unfortunately `pre-commit` have also elected *not* to support `uv` (https://github.com/pre-commit/pre-commit/pull/3131, https://github.com/pre-commit/pre-commit/issues/3222), which would have done this automatically. Maybe we just need to wait until the folks at Astral write their own version of pre-commit as well! - Specify a range of versions for Python, so that we use whatever the system Python is, as long as it is below 3.13. Unfortunately pre-commit doesn't seem to support this. - Write our own local hook that does whatever we damn well please: `uv` to installs the right version of Python, downloads pyright, and run it. If this becomes problematic amongst the other repos then I may well do this. ) --- .pre-commit-config.yaml | 4 ---- equinox/_enum.py | 2 +- equinox/_filters.py | 4 ++-- equinox/internal/_onnx.py | 4 ++-- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bfbfb30e..6c43ed90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,8 +11,6 @@ repos: rev: v1.1.379 hooks: - id: pyright - # must match the Python version used in CI - language_version: python3.11 additional_dependencies: [ beartype, @@ -21,7 +19,5 @@ repos: jaxtyping, optax, pytest, - tensorflow, - tf2onnx, typing_extensions, ] diff --git a/equinox/_enum.py b/equinox/_enum.py index 1716a60c..b18adbdc 100644 --- a/equinox/_enum.py +++ b/equinox/_enum.py @@ -140,7 +140,7 @@ def __instancecheck__(cls, value): class EnumerationItem(Module): - _value: Int[Union[Array, np.ndarray], ""] + _value: Int[Union[Array, np.ndarray[Any, np.dtype[np.signedinteger]]], ""] # Should have annotation `"type[Enumeration]"`, but this fails due to beartype bug # #289. _enumeration: Any = field(static=True) diff --git a/equinox/_filters.py b/equinox/_filters.py index 13ee1fb5..6a219a56 100644 --- a/equinox/_filters.py +++ b/equinox/_filters.py @@ -37,7 +37,7 @@ def is_inexact_array(element: Any) -> bool: array. """ if isinstance(element, (np.ndarray, np.generic)): - return np.issubdtype(element.dtype, np.inexact) + return bool(np.issubdtype(element.dtype, np.inexact)) elif isinstance(element, jax.Array): return jnp.issubdtype(element.dtype, jnp.inexact) else: @@ -51,7 +51,7 @@ def is_inexact_array_like(element: Any) -> bool: if hasattr(element, "__jax_array__"): element = element.__jax_array__() if isinstance(element, (np.ndarray, np.generic)): - return np.issubdtype(element.dtype, np.inexact) + return bool(np.issubdtype(element.dtype, np.inexact)) elif isinstance(element, jax.Array): return jnp.issubdtype(element.dtype, jnp.inexact) else: diff --git a/equinox/internal/_onnx.py b/equinox/internal/_onnx.py index ae31263d..7c7d76ca 100644 --- a/equinox/internal/_onnx.py +++ b/equinox/internal/_onnx.py @@ -24,8 +24,8 @@ def f(x, y): ``` """ import jax.experimental.jax2tf as jax2tf - import tensorflow as tf - import tf2onnx + import tensorflow as tf # pyright: ignore[reportMissingImports] + import tf2onnx # pyright: ignore[reportMissingImports] def _to_onnx(*args): finalised_fn = finalise_fn(fn)