From 74eeb705b205cc2cc921f54535328847bc8abfdd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 23 Dec 2024 23:16:50 +0100 Subject: [PATCH] Now statically working with dataclasses correctly. --- jaxtyping/_decorator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index fadbb37..308f4b8 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -34,6 +34,7 @@ NoReturn, overload, TypeVar, + Union, ) @@ -53,6 +54,10 @@ _Params = ParamSpec("_Params") _Return = TypeVar("_Return") +_T = TypeVar("_T") +# Not `TypeVar(..., type, Callable)` as else the output type of our first overload is +# just `type`, and not the particular class that is decorated. +_TypeOrCallable = TypeVar("_TypeOrCallable", bound=Union[type, Callable]) class _Sentinel: @@ -77,7 +82,11 @@ def _apply_typechecker(typechecker, fn): def jaxtyped( *, typechecker=_sentinel, -) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]: ... +) -> Callable[[_TypeOrCallable], _TypeOrCallable]: ... + + +@overload +def jaxtyped(fn: type[_T], *, typechecker=_sentinel) -> type[_T]: ... @overload