Skip to content

Commit

Permalink
Added support for beartype 0.17.0's __instancecheck_str__.
Browse files Browse the repository at this point in the history
Recall that jaxtyping will currently generate rich error messages in precisely one scenario: about the arguments and return types when doing:
```python
@jaxtyped(typechecker=beartype)
def foo(...): ...
```

With this commit we add support for beartype 0.17.0's pseudo-standard `__instancecheck_str__`, which means the following:

1. For those using beartype decorators, the following will *also* generate an informative error message, and moreover it will state exactly why (shape mismatch, dtype mismatch etc):
    ```python
    @jaxtyped(typechecker=None)
    @beartype
    def foo(...): ...
    ```
    (In practice we probably won't recommend the above combination in the docs just to keep things simple.)

2. For those using the beartype import hook together with the jaxtyping import hook, we can probably also check `assert isinstance(x, Float[Array, "foo"])` statements with rich error messages. (#153) We'll need to test + document that though. (@jeezrick interested?)

3. For those using plain `assert isinstance(...)` statements without beartype (#167, tagging @reinerp), then they can *also* get rich error messages by doing
    ```python
    tt = Float[Array, "foo"]
    assert isinstance(x, tt), tt.__instancecheck_str__(x) + "\n" + print_bindings()
    ```
    which is still a bit long-winded right now but is a step in the right direction.

(CC @leycec for interest.)
  • Loading branch information
patrick-kidger committed Feb 17, 2024
1 parent 19b99ca commit 9c96f09
Showing 1 changed file with 40 additions and 29 deletions.
69 changes: 40 additions & 29 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _check_dims(
obj_shape: tuple[int, ...],
single_memo: dict[str, int],
arg_memo: dict[str, Any],
) -> bool:
) -> str:
assert len(cls_dims) == len(obj_shape)
for cls_dim, obj_size in zip(cls_dims, obj_shape):
if cls_dim is _anonymous_dim:
Expand All @@ -124,7 +124,7 @@ def _check_dims(
pass
elif type(cls_dim) is _FixedDim:
if cls_dim.size != obj_size:
return False
return f"the dimension size {obj_size} does not equal {cls_dim.size} as expected by the type hint" # noqa: E501
elif type(cls_dim) is _SymbolicDim:
try:
# Support f-string syntax.
Expand All @@ -141,7 +141,7 @@ def _check_dims(
"arguments."
) from e
if eval_size != obj_size:
return False
return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}" # noqa: E501
else:
assert type(cls_dim) is _NamedDim
if cls_dim.treepath:
Expand All @@ -154,16 +154,19 @@ def _check_dims(
single_memo[name] = obj_size
else:
if cls_size != obj_size:
return False
return True
return f"the size of dimension {cls_dim.name} is {obj_size} which does not equal the existing value of {cls_size}" # noqa: E501
return ""


class _MetaAbstractArray(type):
def __instancecheck__(cls, obj):
def __instancecheck__(cls, obj: Any) -> bool:
return cls.__instancecheck_str__(obj) == ""

def __instancecheck_str__(cls, obj: Any) -> str:
if not isinstance(obj, cls.array_type):
return False
return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501
if get_treeflatten_memo():
return True
return ""

if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
# JAX, numpy
Expand Down Expand Up @@ -193,7 +196,10 @@ def __instancecheck__(cls, obj):
if in_dtypes:
break
if not in_dtypes:
return False
if len(cls.dtypes) == 1:
return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint" # noqa: E501
else:
return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint" # noqa: E501

single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
single_memo_bak = single_memo.copy()
Expand All @@ -207,41 +213,46 @@ def __instancecheck__(cls, obj):
single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
)
raise
if check:
return True
if check == "":
return check
else:
set_shape_memo(
single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
)
return False
return check

def _check_shape(
cls,
obj,
single_memo: dict[str, int],
variadic_memo: dict[str, tuple[bool, tuple[int, ...]]],
arg_memo: dict[str, Any],
):
) -> str:
if cls.index_variadic is None:
if obj.ndim != len(cls.dims):
return False
return f"this array has {obj.ndim} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501
return _check_dims(cls.dims, obj.shape, single_memo, arg_memo)
else:
if obj.ndim < len(cls.dims) - 1:
return False
return f"this array has {obj.ndim} dimensions, which is fewer than {len(cls.dims - 1)} that is the minimum expected by the type hint" # noqa: E501
i = cls.index_variadic
j = -(len(cls.dims) - i - 1)
if j == 0:
j = None
if not _check_dims(cls.dims[:i], obj.shape[:i], single_memo, arg_memo):
return False
if j is not None and not _check_dims(
cls.dims[j:], obj.shape[j:], single_memo, arg_memo
):
return False
prefix_check = _check_dims(
cls.dims[:i], obj.shape[:i], single_memo, arg_memo
)
if prefix_check != "":
return prefix_check
if j is not None:
suffix_check = _check_dims(
cls.dims[j:], obj.shape[j:], single_memo, arg_memo
)
if suffix_check != "":
return suffix_check
variadic_dim = cls.dims[i]
if variadic_dim is _anonymous_variadic_dim:
return True
return ""
else:
assert type(variadic_dim) is _NamedVariadicDim
if variadic_dim.treepath:
Expand All @@ -253,16 +264,16 @@ def _check_shape(
prev_broadcastable, prev_shape = variadic_memo[name]
except KeyError:
variadic_memo[name] = (broadcastable, obj.shape[i:j])
return True
return ""
else:
new_shape = obj.shape[i:j]
if prev_broadcastable:
try:
broadcast_shape = np.broadcast_shapes(new_shape, prev_shape)
except ValueError: # not broadcastable e.g. (3, 4) and (5,)
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501
if not broadcastable and broadcast_shape != new_shape:
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to" # noqa: E501
variadic_memo[name] = (broadcastable, broadcast_shape)
else:
if broadcastable:
Expand All @@ -271,13 +282,13 @@ def _check_shape(
new_shape, prev_shape
)
except ValueError: # not broadcastable e.g. (3, 4) and (5,)
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501
if broadcast_shape != prev_shape:
return False
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}" # noqa: E501
else:
if new_shape != prev_shape:
return False
return True
return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}" # noqa: E501
return ""
assert False


Expand Down

0 comments on commit 9c96f09

Please sign in to comment.