From 19b99ca3c9d56e479be8cfb161e9e31182e10d58 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 12 Feb 2024 00:03:51 +0000 Subject: [PATCH] Moved print_bindings into storage.py --- jaxtyping/__init__.py | 3 ++- jaxtyping/_decorator.py | 60 +++-------------------------------------- jaxtyping/_storage.py | 56 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 57 deletions(-) diff --git a/jaxtyping/__init__.py b/jaxtyping/__init__.py index 1bce0ff..0c9574d 100644 --- a/jaxtyping/__init__.py +++ b/jaxtyping/__init__.py @@ -30,13 +30,14 @@ set_array_name_format as set_array_name_format, ) from ._config import config as config -from ._decorator import jaxtyped as jaxtyped, print_bindings as print_bindings +from ._decorator import jaxtyped as jaxtyped from ._errors import ( AnnotationError as AnnotationError, TypeCheckError as TypeCheckError, ) from ._import_hook import install_import_hook as install_import_hook from ._ipython_extension import load_ipython_extension as load_ipython_extension +from ._storage import print_bindings as print_bindings # Now import Array and ArrayLike diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index f4ef9ac..9a09861 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -36,7 +36,7 @@ from ._config import config from ._errors import AnnotationError, TypeCheckError -from ._storage import get_shape_memo, pop_shape_memo, push_shape_memo +from ._storage import pop_shape_memo, push_shape_memo, shape_str class _Sentinel: @@ -319,7 +319,7 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore return fn(*args, **kwargs) except Exception as e: if sys.version_info >= (3, 11) and _no_jaxtyping_note(e): - shape_info = _exc_shape_info(memos) + shape_info = shape_str(memos) if shape_info != "": msg = ( "The preceding error occurred within the scope of a " @@ -411,7 +411,7 @@ def wrapped_fn(*args, **kwargs): "----------------------\n" f"Called with parameters: {param_values}\n" f"Parameter annotations: {param_hints}.\n" - + _exc_shape_info(memos) + + shape_str(memos) ) if config.jaxtyping_remove_typechecker_stack: raise TypeCheckError(msg) from None @@ -464,7 +464,7 @@ def wrapped_fn(*args, **kwargs): "----------------------\n" f"Called with parameters: {param_values}\n" f"Parameter annotations: {param_hints}.\n" - + _exc_shape_info(memos) + + shape_str(memos) ) if config.jaxtyping_remove_typechecker_stack: raise TypeCheckError(msg) from None @@ -756,40 +756,6 @@ def _pformat(x, short_self: bool): return pformat(x) -def _exc_shape_info(memos) -> str: - """Gives debug information on the current state of jaxtyping's internal memos. - Used in type-checking error messages. - """ - single_memo, variadic_memo, pytree_memo, _ = memos - single_memo = { - name: size - for name, size in single_memo.items() - if not name.startswith("~~delete~~") - } - variadic_memo = { - name: shape - for name, (_, shape) in variadic_memo.items() - if not name.startswith("~~delete~~") - } - pieces = [] - if len(single_memo) > 0 or len(variadic_memo) > 0: - pieces.append( - "The current values for each jaxtyping axis annotation are as follows." - ) - for name, size in single_memo.items(): - pieces.append(f"{name}={size}") - for name, shape in variadic_memo.items(): - pieces.append(f"{name}={shape}") - if len(pytree_memo) > 0: - pieces.append( - "The current values for each jaxtyping PyTree structure annotation are as " - "follows." - ) - for name, structure in pytree_memo.items(): - pieces.append(f"{name}={structure}") - return "\n".join(pieces) - - class _jaxtyping_note_str(str): """Used with `_no_jaxtyping_note` to flag that a note came from jaxtyping.""" @@ -808,21 +774,3 @@ def _no_jaxtyping_note(e: Exception) -> bool: _spacer = "--------------------\n" - - -def print_bindings(): - """Prints the values of the current jaxtyping axis bindings. Intended for debugging. - - That is, whilst doing runtime type checking, so that e.g. the `foo` and `bar` of - `Float[Array, "foo bar"]` are assigned values -- this function will print out those - values. - - **Arguments:** - - Nothing. - - **Returns:** - - Nothing. - """ - print(_exc_shape_info(get_shape_memo())) diff --git a/jaxtyping/_storage.py b/jaxtyping/_storage.py index 7cd1a30..706de99 100644 --- a/jaxtyping/_storage.py +++ b/jaxtyping/_storage.py @@ -71,6 +71,62 @@ def pop_shape_memo() -> None: _shape_storage.memo_stack.pop() +def shape_str(memos) -> str: + """Gives debug information on the current state of jaxtyping's internal memos. + Used in type-checking error messages. + + **Arguments:** + + - `memos`: as returned by `get_shape_memo` or `push_shape_memo`. + """ + single_memo, variadic_memo, pytree_memo, _ = memos + single_memo = { + name: size + for name, size in single_memo.items() + if not name.startswith("~~delete~~") + } + variadic_memo = { + name: shape + for name, (_, shape) in variadic_memo.items() + if not name.startswith("~~delete~~") + } + pieces = [] + if len(single_memo) > 0 or len(variadic_memo) > 0: + pieces.append( + "The current values for each jaxtyping axis annotation are as follows." + ) + for name, size in single_memo.items(): + pieces.append(f"{name}={size}") + for name, shape in variadic_memo.items(): + pieces.append(f"{name}={shape}") + if len(pytree_memo) > 0: + pieces.append( + "The current values for each jaxtyping PyTree structure annotation are as " + "follows." + ) + for name, structure in pytree_memo.items(): + pieces.append(f"{name}={structure}") + return "\n".join(pieces) + + +def print_bindings(): + """Prints the values of the current jaxtyping axis bindings. Intended for debugging. + + That is, whilst doing runtime type checking, so that e.g. the `foo` and `bar` of + `Float[Array, "foo bar"]` are assigned values -- this function will print out those + values. + + **Arguments:** + + Nothing. + + **Returns:** + + Nothing. + """ + print(shape_str(get_shape_memo())) + + _treepath_storage = threading.local()