diff --git a/README.md b/README.md index 5da380f..65feb39 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ def accepts_pytree_of_arrays(x: PyTree[Float[Array, "batch c1 c2"]]): pip install jaxtyping ``` -Requires Python 3.9+. +Requires Python 3.10+. JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc. diff --git a/docs/index.md b/docs/index.md index 677ef7e..854518b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -13,7 +13,7 @@ jaxtyping is a library providing type annotations **and runtime type-checking** pip install jaxtyping ``` -Requires Python 3.9+. +Requires Python 3.10+. JAX is an optional dependency, required for a few JAX-specific types. If JAX is not installed then these will not be available, but you may still use jaxtyping to provide shape/dtype annotations for PyTorch/NumPy/TensorFlow/etc. diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index e601154..3848c14 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -22,7 +22,7 @@ import importlib.util import typing import warnings -from typing import Union +from typing import TypeAlias, Union from ._array_types import ( AbstractArray as AbstractArray, @@ -43,8 +43,6 @@ if typing.TYPE_CHECKING: - import typing_extensions - from jax import Array as Array from jax.tree_util import PyTreeDef as PyTreeDef from jax.typing import ArrayLike as ArrayLike, DTypeLike as DTypeLike @@ -90,7 +88,7 @@ ) # Set up to deliberately confuse a static type checker. - PyTree: typing_extensions.TypeAlias = getattr(typing, "foo" + "bar") + PyTree: TypeAlias = getattr(typing, "foo" + "bar") # What's going on with this madness? # # At static-type-checking-time, we want `PyTree` to be a type for which both diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 308f4b8..2d2cbaa 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -33,18 +33,11 @@ get_type_hints, NoReturn, overload, + ParamSpec, TypeVar, Union, ) - -try: - from typing import ParamSpec -except ImportError: - # Python < 3.10 - from typing_extensions import ParamSpec - - from jaxtyping import AbstractArray from ._config import config @@ -847,12 +840,12 @@ def _pformat(x, short_self: bool): # No performance concerns from delayed imports -- this is only used when we're about # to raise an error anyway. try: - # TODO(kidger): this is pretty ugly. We have a circular dependency - # equinox->jaxtyping->equinox. We could consider moving all the pretty-printing - # code from equinox into jaxtyping maybe? Or into some shared dependency? + # If we can, use `eqx.tree_pformat`, which wraps `wadler_lindig.pformat` with + # understanding of a few other JAX-specific things. import equinox as eqx pformat = eqx.tree_pformat + if short_self: try: self = x["self"] @@ -862,9 +855,23 @@ def _pformat(x, short_self: bool): is_self = lambda y: y is self pformat = ft.partial(pformat, truncate_leaf=is_self) except Exception: - import pprint + # Failing that fall back to `wadler_lindig.pformat` directly. + import wadler_lindig + + pformat = wadler_lindig.pformat + + if short_self: + try: + self = x["self"] + except KeyError: + pass + else: + + def custom(obj): + if obj is self: + return wadler_lindig.TextDoc(f"{type(obj).__name__}(...)") - pformat = ft.partial(pprint.pformat, indent=2, compact=True) + pformat = ft.partial(pformat, custom=custom) try: return pformat(x) except Exception: diff --git a/pyproject.toml b/pyproject.toml index a0fa8c6..a935e86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "jaxtyping" version = "0.2.36" description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays." readme = "README.md" -requires-python =">=3.9" +requires-python =">=3.10" license = {file = "LICENSE"} authors = [ {name = "Patrick Kidger", email = "contact@kidger.site"}, @@ -23,9 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/google/jaxtyping" } -dependencies = [ - "typing_extensions; python_version < '3.10'" -] +dependencies = ["wadler_lindig>=0.1.0"] entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}} [build-system]