diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py
index 417dc69..9f18045 100644
--- a/jaxtyping/_array_types.py
+++ b/jaxtyping/_array_types.py
@@ -115,7 +115,7 @@ def _check_dims(
     obj_shape: tuple[int, ...],
     single_memo: dict[str, int],
     arg_memo: dict[str, Any],
-) -> bool:
+) -> str:
     assert len(cls_dims) == len(obj_shape)
     for cls_dim, obj_size in zip(cls_dims, obj_shape):
         if cls_dim is _anonymous_dim:
@@ -124,7 +124,7 @@ def _check_dims(
             pass
         elif type(cls_dim) is _FixedDim:
             if cls_dim.size != obj_size:
-                return False
+                return f"the dimension size {obj_size} does not equal {cls_dim.size} as expected by the type hint"  # noqa: E501
         elif type(cls_dim) is _SymbolicDim:
             try:
                 # Support f-string syntax.
@@ -141,7 +141,7 @@ def _check_dims(
                     "arguments."
                 ) from e
             if eval_size != obj_size:
-                return False
+                return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}"  # noqa: E501
         else:
             assert type(cls_dim) is _NamedDim
             if cls_dim.treepath:
@@ -154,16 +154,19 @@ def _check_dims(
                 single_memo[name] = obj_size
             else:
                 if cls_size != obj_size:
-                    return False
-    return True
+                    return f"the size of dimension {cls_dim.name} is {obj_size}, which does not equal the existing value of {cls_size}"  # noqa: E501
+    return ""


 class _MetaAbstractArray(type):
-    def __instancecheck__(cls, obj):
+    def __instancecheck__(cls, obj: Any) -> bool:
+        return cls.__instancecheck_str__(obj) == ""
+
+    def __instancecheck_str__(cls, obj: Any) -> str:
         if not isinstance(obj, cls.array_type):
-            return False
+            return f"this value is not an instance of the underlying array type {cls.array_type}"  # noqa: E501
         if get_treeflatten_memo():
-            return True
+            return ""

         if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"):
             # JAX, numpy
@@ -193,7 +196,10 @@ def __instancecheck__(cls, obj):
             if in_dtypes:
                 break
         if not in_dtypes:
-            return False
+            if len(cls.dtypes) == 1:
+                return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint"  # noqa: E501
+            else:
+                return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint"  # noqa: E501

         single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo()
         single_memo_bak = single_memo.copy()
@@ -207,13 +213,13 @@ def __instancecheck__(cls, obj):
                 single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
             )
             raise
-        if check:
-            return True
+        if check == "":
+            return check
         else:
             set_shape_memo(
                 single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak
             )
-            return False
+            return check

     def _check_shape(
         cls,
@@ -221,27 +227,32 @@ def _check_shape(
         single_memo: dict[str, int],
         variadic_memo: dict[str, tuple[bool, tuple[int, ...]]],
         arg_memo: dict[str, Any],
-    ):
+    ) -> str:
         if cls.index_variadic is None:
             if obj.ndim != len(cls.dims):
-                return False
+                return f"this array has {obj.ndim} dimensions, not the {len(cls.dims)} expected by the type hint"  # noqa: E501
             return _check_dims(cls.dims, obj.shape, single_memo, arg_memo)
         else:
             if obj.ndim < len(cls.dims) - 1:
-                return False
+                return f"this array has {obj.ndim} dimensions, which is fewer than the minimum of {len(cls.dims) - 1} expected by the type hint"  # noqa: E501
             i = cls.index_variadic
             j = -(len(cls.dims) - i - 1)
             if j == 0:
                 j = None
-            if not _check_dims(cls.dims[:i], obj.shape[:i], single_memo, arg_memo):
-                return False
-            if j is not None and not _check_dims(
-                cls.dims[j:], obj.shape[j:], single_memo, arg_memo
-            ):
-                return False
+            prefix_check = _check_dims(
+                cls.dims[:i], obj.shape[:i], single_memo, arg_memo
+            )
+            if prefix_check != "":
+                return prefix_check
+            if j is not None:
+                suffix_check = _check_dims(
+                    cls.dims[j:], obj.shape[j:], single_memo, arg_memo
+                )
+                if suffix_check != "":
+                    return suffix_check
             variadic_dim = cls.dims[i]
             if variadic_dim is _anonymous_variadic_dim:
-                return True
+                return ""
             else:
                 assert type(variadic_dim) is _NamedVariadicDim
                 if variadic_dim.treepath:
@@ -253,16 +264,16 @@ def _check_shape(
                     prev_broadcastable, prev_shape = variadic_memo[name]
                 except KeyError:
                     variadic_memo[name] = (broadcastable, obj.shape[i:j])
-                    return True
+                    return ""
                 else:
                     new_shape = obj.shape[i:j]
                     if prev_broadcastable:
                         try:
                             broadcast_shape = np.broadcast_shapes(new_shape, prev_shape)
                         except ValueError:  # not broadcastable e.g. (3, 4) and (5,)
-                            return False
+                            return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}"  # noqa: E501
                         if not broadcastable and broadcast_shape != new_shape:
-                            return False
+                            return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to"  # noqa: E501
                         variadic_memo[name] = (broadcastable, broadcast_shape)
                     else:
                         if broadcastable:
@@ -271,13 +282,13 @@ def _check_shape(
                                     new_shape, prev_shape
                                 )
                             except ValueError:  # not broadcastable e.g. (3, 4) and (5,)
-                                return False
+                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}"  # noqa: E501
                             if broadcast_shape != prev_shape:
-                                return False
+                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}"  # noqa: E501
                         else:
                             if new_shape != prev_shape:
-                                return False
-                    return True
+                                return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}"  # noqa: E501
+                    return ""

         assert False
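
Usage sketch (not part of the patch; illustration only). It uses jaxtyping's public Float[np.ndarray, ...] hint syntax; the printed message is just an example of the strings added above and its exact wording may differ:

    import numpy as np
    from jaxtyping import Float

    hint = Float[np.ndarray, "batch channels features"]  # hint expects 3 dimensions
    x = np.zeros((3, 4))                                 # array has only 2 dimensions

    # isinstance() still returns a plain bool; it is now implemented on top of
    # __instancecheck_str__, where an empty string means the check passed.
    assert not isinstance(x, hint)

    # The new metaclass method additionally reports *why* the check failed,
    # e.g. "this array has 2 dimensions, not the 3 expected by the type hint".
    print(hint.__instancecheck_str__(x))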