will runtime type checking go beyond function parameters and return type? #153
So right now this will work if you use a manual `assert`:

```python
@jaxtyped(typechecker=beartype)
def forward(self, x: Patch_embed) -> Patch_embed:
    x: Mlp_mid = self.act(self.lin1(x))
    assert isinstance(x, Mlp_mid)
    return self.lin2(x)
```

but I agree that inserting this automatically would be an awesome feature to have. I can see two possible ways we might add this:
Tagging @leycec -- I know you've tackled this problem before; do you have any thoughts on the matter?
If you're using JAX, then it shouldn't slow runtime at all. There will only be a negligible amount of work at compile time, just to actually check that the shapes and dtypes of the arrays are what they say they are. If you're using PyTorch/etc., then the amount of overhead hasn't really been benchmarked, so I'm less sure :)
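Here is a minimal sketch of what "inserting this automatically" could look like. It uses only the standard-library `ast` module to rewrite each `name: T = value` statement so that it is followed by `assert isinstance(name, T)`, which is exactly the manual check above. `AnnAssignChecker` and `compile_checked` are hypothetical names, not part of jaxtyping or beartype. A real implementation would also have to wire this into module loading through an import hook.

```python
import ast
import copy


class AnnAssignChecker(ast.NodeTransformer):
    """Hypothetical transformer: follow every `name: T = value` statement
    with `assert isinstance(name, T)`."""

    def visit_AnnAssign(self, node: ast.AnnAssign):
        # Only rewrite simple `name: T = value`; leave bare annotations
        # (no value) and attribute/subscript targets untouched.
        if node.value is None or not isinstance(node.target, ast.Name):
            return node
        check = ast.Assert(
            test=ast.Call(
                func=ast.Name(id="isinstance", ctx=ast.Load()),
                args=[
                    ast.Name(id=node.target.id, ctx=ast.Load()),
                    copy.deepcopy(node.annotation),  # don't share AST nodes
                ],
                keywords=[],
            ),
            msg=None,
        )
        # Returning a list replaces one statement with two.
        return [node, ast.copy_location(check, node)]


def compile_checked(source: str, filename: str = "<checked>"):
    """Parse, rewrite, and compile `source` into a code object."""
    tree = AnnAssignChecker().visit(ast.parse(source))
    ast.fix_missing_locations(tree)  # give the new nodes line numbers
    return compile(tree, filename, "exec")


# Usage: the annotated local is now checked each time the function runs.
namespace = {}
exec(compile_checked("def f(v):\n    x: int = v\n    return x\n"), namespace)
namespace["f"](3)  # passes
# namespace["f"]("3") would raise AssertionError
```

This sketch has two caveats. First, asserts are stripped when Python runs with `-O`. Second, the annotation is now evaluated at every assignment, whereas a plain local-variable annotation is never evaluated at runtime.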
...heh. The age-old PEP 526-compliant Annotated Variable Assignment Runtime Type-checking Problem, huh? That one just never gets old. So, thanks for pinging me on this! I love this sort of hacky Python wrangling. Sadly, you already know everything. This is like when Luke in Empire Strikes Back finally realizes that the little wizened green lizard blob thing is actually the most powerful surviving Jedi in the entire universe. You are that blob thing. Wait. That... no longer sounds complimentary. Let's awkwardly start over.

Two Roads: One That Sucks and One That Sucks a Bit Less

As you astutely surmise, two sucky roads lie before you:

Relatedly, you almost certainly want to

Let's Make a Deal: A Match Made in GitHub

Tangentially, would you like @beartype to automate your import hooks for you? @beartype would be delighted to silently piggyback

Specifically, @beartype would automate away everything by:
```python
# In our private "beartype.claw._importlib._clawimpload" submodule:
...

# Attempt to...
try:
    # Defer optional dependency imports.
    from jaxtyping._import_hook import JaxtypingTransformer

    # AST transformer decorating typed callables and classes by "jaxtyping".
    #
    # Note that we intentionally pass *NO* third-party "typechecker" to avoid
    # redundantly applying the @beartype.beartype decorator to the same callables
    # and classes twice.
    ast_jaxtyper = JaxtypingTransformer(typechecker=None)

    # Abstract syntax tree (AST) modified by this transformer.
    module_ast = ast_jaxtyper.visit(module_ast)
# If "jaxtyping" is currently unimportable, silently pretend everything is well.
except ImportError:
    pass

# AST transformer decorating typed callables and classes by @beartype.
ast_beartyper = BeartypeNodeTransformer(
    conf_beartype=self._module_conf_beartype)

# Abstract syntax tree (AST) modified by this transformer.
module_ast_beartyped = ast_beartyper.visit(module_ast)
```

Of course, you don't need to actually depend upon or require @beartype in any way. No changes are needed on your end. Actually, one little change would improve sanity: if you wouldn't mind publicizing

That's it. Super-easy, honestly. Everyone currently getting @beartype would then get

Equinox: You Are Now Good to Go

Oh – and @beartype now officially supports Equinox. 2024: dis goin' be gud.
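The excerpt above runs the jaxtyping transformer and then the beartype transformer over the same module AST. The stdlib-only sketch below chains two hypothetical transformers that each prepend a decorator. `AddDecorator`, `jaxtyped_stub`, and `beartype_stub` are placeholder names, not real APIs. The sketch shows why order matters: because each transformer prepends, the one that runs last ends up as the outermost decorator.

```python
import ast


class AddDecorator(ast.NodeTransformer):
    """Hypothetical transformer: prepend `@<decorator_name>` to every def."""

    def __init__(self, decorator_name: str) -> None:
        self.decorator_name = decorator_name

    def visit_FunctionDef(self, node: ast.FunctionDef):
        self.generic_visit(node)  # also decorate nested functions
        node.decorator_list.insert(
            0, ast.Name(id=self.decorator_name, ctx=ast.Load())
        )
        return node


def transform(source: str) -> ast.Module:
    """Apply the transformers in sequence, "jaxtyping" first, as in the excerpt."""
    tree = ast.parse(source)
    for transformer in (AddDecorator("jaxtyped_stub"), AddDecorator("beartype_stub")):
        tree = transformer.visit(tree)
    return ast.fix_missing_locations(tree)


# The resulting function is decorated as:
#   @beartype_stub
#   @jaxtyped_stub
#   def f(x): ...
tree = transform("def f(x):\n    return x\n")
```

Decorators are applied bottom-up, so in this sketch `jaxtyped_stub` wraps the function first and `beartype_stub` wraps the result.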
Thank you @leycec! Okay, looks like we're stuck with these two options. I'll have a think about what implementing these looks like. I'm not wild about either approach, they're both pretty magic... :D The beartype implementation in particular is a very useful reference.

As for having the AST transformers / import hooks work together... hmm, I think I'd need to understand this better still. I think right now what we've got isn't really directly compatible. We might need some more changes elsewhere before this is doable.

To explain: jaxtyping doesn't really want you to use

As of fairly recently, we now do some fairly evil things under the hood, corresponding to passing

```python
@jaxtyped(typechecker=beartype)
def foo(x: Ann) -> Ret:
    ...  # do stuff

out = foo(bar)
```

into

```python
@beartype
def foo_args(x: Ann):
    pass

try:
    foo_args(bar)
except Exception as e:
    raise ValueError("A helpful error message for the arguments") from e

out = foo(bar)

@beartype
def foo_ret(x: Ann) -> Ret:
    return out

try:
    foo_ret(bar)
except Exception as e:
    raise ValueError("A helpful error message for the return value") from e
```

where

The reason for this is so that we can finally give those nice error messages about shapes and dtypes and what-not when something goes wrong. What this means is that if jaxtyping is being used, it's basically assuming that it's "in charge": it gets to decide how to report error messages, not beartype! (Although you can scroll up a bit in the traceback to see the underlying error message that was caught and attached as

This is actually pretty nice from the jaxtyping point-of-view. We don't have to worry about whether you're using beartype or typeguard or anything else: if the decorator raises an exception, we can use it.

So, what does this mean at this point? Frankly, I'm not 100% sure. :D It's definitely not optimal for jaxtyping to have to reimplement the magic required for checking statement annotations. Perhaps we could factor out all the import hook business into a shared third library? Possibly that leads to fire and explosions.

Equinox: hurrah! That's awesome. What an excellent Christmas gift.
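The argument/return pattern above can be sketched as a single stdlib-only decorator. `check_hints` is a hypothetical stand-in, not jaxtyping's actual code. It checks arguments with a plain `isinstance` (so, unlike beartype, it only understands plain classes as hints), calls the function exactly once, and then checks the return value. Any failure is re-raised as a `ValueError` chained to the original error with `from e`, which is what keeps the underlying message visible in the traceback.

```python
import functools
import inspect


def check_hints(fn):
    """Hypothetical decorator: check args, call once, check the return value."""
    sig = inspect.signature(fn)
    arg_hints = {
        name: param.annotation
        for name, param in sig.parameters.items()
        if param.annotation is not inspect.Parameter.empty
    }
    ret_hint = sig.return_annotation

    def _check(value, hint, what):
        # Assumes `hint` is a plain class; typing constructs are not handled.
        if not isinstance(value, hint):
            raise TypeError(f"{what}: expected {hint.__name__}, got {type(value).__name__}")

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # Step 1: check the arguments, re-raising with a friendlier message.
        try:
            for name, value in bound.arguments.items():
                if name in arg_hints:
                    _check(value, arg_hints[name], f"argument {name!r}")
        except TypeError as e:
            raise ValueError(f"bad arguments to {fn.__name__}()") from e
        # Step 2: call the wrapped function exactly once.
        out = fn(*args, **kwargs)
        # Step 3: check the return value the same way.
        if ret_hint is not inspect.Signature.empty:
            try:
                _check(out, ret_hint, "return value")
            except TypeError as e:
                raise ValueError(f"bad return value from {fn.__name__}()") from e
        return out

    return wrapper


@check_hints
def double(x: int) -> int:
    return x * 2
```

Calling `double("a")` raises `ValueError("bad arguments to double()")`, and its `__cause__` is the `TypeError` that explains the mismatch.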
Woooooooah. Indeed, I see terrifying – yet ultimately justifiable – shenanigans that sadden me. In the absence of a standardized plugin system for runtime-static type-checkers generically supported by both @beartype and @beartype and My only wish for 2024 is to stop failing

Oh. Oh. But I Just Realized That...

The
@beartype is no different. We're just taking a longer and more circuitous route to get to the same place. The above list will definitely grow over time. Indeed, the mere existence of the above list suggests that nobody should call type-checking decorators in 2024. They're obsolete. They're insufficient. And they'll probably be deprecated by both @beartype and

Oh, Boy. It Comes to This.

So. I grok the So. @beartype's own import hooks can still detect and apply your All of this would be a whole lot easier if @beartype just hurried up already and provided a public API for Ultimately, @beartype, There's much to ponder. Yet, the will to code big ol' plugin architectures is weak. 😮💨

Crazy Idea Is Crazy

One crazy idea would be to just fold portions of
Consider it! Could be fun. Or... it could be a living Hell. 😈 😮
Wait. I'd still love to add you as a collaborator to @beartype, @patrick-kidger – but I've realized the painfully obvious while cross-country skiing the muddy and rock-strewn trails of backwoods Ontario. With great snow comes great enlightenment. @beartype 0.17.0 will support the plugin API that When @beartype 0.17.0 does this:
In short,
Okay, lots of interesting ideas here! Settle in, I have a wall of text of my own.

Scope

I think we all agree that the ideal future would be for jaxtyping only to provide: (a) the type hints; and leave all type-checking and error-reporting to beartype/typeguard.

Future changes in jaxtyping

It's great that the custom error message reporting is now available in beartype! In that case, I think I see the same easy-peasy future for us as you do. I'll implement the

```python
@jaxtyped(typechecker=None)
@beartype
def foo(...):
    ...
```

approach to things should just work.

Import hooks

As a practical matter, I expect we should be able to arrange to add both decorators using two separate import hooks. In particular, neither package is in the driver's seat. We're just both adding our own decorators via import hooks. Actually, I realise beartype will be doing something slightly different to just adding a decorator -- you'll be adding

Why I like this

(a) This means that we're not coupling beartype and jaxtyping together in any meaningful way. The only contact point is the

(b) As a practical matter, jaxtyping+v2-of-typeguard is actually a very popular combination, and this approach means that we won't be risking breaking that either. (For the sake of such approaches, I'm afraid I'll still keep the perfidious evil of

(c) It fixes up #92, as you note!

(d) It makes it possible to use jaxtyping with PEP 526. The original purpose of this issue, lest we forget... :D

Questions for you
Excellence! beartype/beartype@6b3aadfff7f9e4ef1ccde is the first step on this tumultuous voyage into the unknown. Your questions are, of course, apropos and almost certainly highlight deficiencies in my worldview. To wit:
...heh. Let's pretend it is. Actually, the hope is that this will eventually metastasize into an actual PEP standard. In the meanwhile, I hope that somebody who is not me will market this to agronholm himself at the Injecting the suspiciously @beartype-specific substring
(╯°□°)╯︵ ┻━┻
you're not wrong
you're not wrong
you're not wrong
you're not wrong

Wait... I'm beginning to detect a deterministically repeatable pattern here. Allow me to now try but fail to explain:

Assuming

```python
# foo/__init__.py
import beartype.claw
import jaxtyping

with jaxtyping.install_import_hook("foo"), beartype.claw.beartyping():
    from . import bar
    from . import baz
```

This definitely should work, too:

```python
# foo/__init__.py
import beartype.claw
import jaxtyping

jaxtyping.jaxtype_this_package()  # <-- dis iz h0t (proposed here, not an existing jaxtyping API)
beartype.claw.beartype_this_package()

from . import bar
from . import baz
```
Yes! So much, "Yes!" I actually implemented a Long story short: "Nobody did nuffin'." 😮💨
Given that @beartype still fails to deeply type-check most standard container types like That said, does this actually intersect with

Uhm... Err...

Oh. Wait. I never actually implemented support for Previously, I'd assumed that I'm not even necessarily clear what "trace time" is, frankly. My wife and I are currently sloooowly migrating our data science pipeline from Ye Ol' Mostly Single-threaded NumPy and SciPy World to Speedy Gonzalez JAX World. I must confess that I am dumb, in short.
Oh, fascinating. I just love me some JIT + DAG action. It looks like @beartype is in good paws here.
...heh. I was wondering when you'd catch that. This is actually intentional, because the current implementation of

```python
from contextlib import contextmanager

@contextmanager
def beartyping(conf: BeartypeConf = BeartypeConf()) -> Iterator[None]:
    try:
        beartype_all(conf=conf)
        yield
    finally:
        undo_beartype_all()
```

That's... it. You are thinking: But I Hate
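The `beartyping()` context manager above is just an install/undo pair wrapped in `try`/`finally`. One common way to install an import hook is to put a finder on `sys.meta_path`. The stdlib-only sketch below follows the same lifecycle with hypothetical names (`RecordingFinder`, `recording_imports`). This finder only records module names and returns `None`, so it defers to the normal import machinery. A real hook, such as beartype's or jaxtyping's, would presumably also supply a loader that rewrites each module's AST.

```python
import importlib
import importlib.abc
import sys
from contextlib import contextmanager
from typing import Iterator, List


class RecordingFinder(importlib.abc.MetaPathFinder):
    """Hypothetical finder that records which modules the import system asks about."""

    def __init__(self) -> None:
        self.seen: List[str] = []

    def find_spec(self, fullname, path, target=None):
        self.seen.append(fullname)
        return None  # defer to the next finder on sys.meta_path


@contextmanager
def recording_imports() -> Iterator[RecordingFinder]:
    finder = RecordingFinder()
    sys.meta_path.insert(0, finder)  # install: consulted before the default finders
    try:
        yield finder
    finally:
        sys.meta_path.remove(finder)  # uninstall, even if the body raised


# Usage: modules already in sys.modules never reach the finders, so this
# imports a name that does not exist.
with recording_imports() as finder:
    try:
        importlib.import_module("no_such_module_for_this_demo")
    except ModuleNotFoundError:
        pass
```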
Recall that jaxtyping will currently generate rich error messages in precisely one scenario: about the arguments and return types when doing:

```python
@jaxtyped(typechecker=beartype)
def foo(...): ...
```

With this commit we add support for beartype 0.17.0's pseudo-standard `__instancecheck_str__`, which means the following:

1. For those using beartype decorators, the following will *also* generate an informative error message, and moreover it will state exactly why (shape mismatch, dtype mismatch, etc.):

   ```python
   @jaxtyped(typechecker=None)
   @beartype
   def foo(...): ...
   ```

   (In practice we probably won't recommend the above combination in the docs just to keep things simple.)

2. For those using the beartype import hook together with the jaxtyping import hook, we can probably also check `assert isinstance(x, Float[Array, "foo"])` statements with rich error messages. (#153) We'll need to test + document that though. (@jeezrick interested?)

3. For those using plain `assert isinstance(...)` statements without beartype (#167, tagging @reinerp), then they can *also* get rich error messages by doing

   ```python
   tt = Float[Array, "foo"]
   assert isinstance(x, tt), tt.__instancecheck_str__(x) + "\n" + print_bindings()
   ```

   which is still a bit long-winded right now but is a step in the right direction.

(CC @leycec for interest.)
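To illustrate the `__instancecheck_str__` idea from point 3 without JAX, here is a toy hint whose metaclass implements both `__instancecheck__` and `__instancecheck_str__`. `ListOfLength` is hypothetical and stands in for something like `Float[Array, "foo"]`. Returning an empty string on success is this sketch's own convention, not necessarily beartype's or jaxtyping's. Note that the assert message is only built when the check fails.

```python
class _LengthHintMeta(type):
    """Metaclass for a toy hint: "a list with exactly `length` elements"."""

    def __instancecheck__(cls, obj) -> bool:
        return cls.__instancecheck_str__(obj) == ""

    def __instancecheck_str__(cls, obj) -> str:
        # Empty string means "obj matches"; otherwise, explain the mismatch.
        if not isinstance(obj, list):
            return f"expected a list, got {type(obj).__name__}"
        if len(obj) != cls.length:
            return f"expected length {cls.length}, got length {len(obj)}"
        return ""


def ListOfLength(n: int) -> type:
    """Hypothetical hint factory, standing in for e.g. Float[Array, "foo"]."""
    return _LengthHintMeta(f"ListOfLength[{n}]", (), {"length": n})


# Usage, mirroring the `assert isinstance(x, tt), tt.__instancecheck_str__(x)` idiom:
hint = ListOfLength(3)
x = [1, 2]
try:
    assert isinstance(x, hint), hint.__instancecheck_str__(x)
except AssertionError as e:
    print(e)  # expected length 3, got length 2
```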
Hi! https://gist.github.com/avolchek/6ac328c435e2bd584c2722ce058de824
@avolchek: That's hilarious, horrifying, and hellish: The Three Holy H's. That's also guaranteed to break, because:
Seriously, though. That's gonna break. In fact, that probably already broke. I'm on the cusp of releasing @beartype 0.18.0. GitHub only knows where I have now excoriated you, @avolchek.

Now allow me to praise you! Yes, praise! Because that's actually an ingenious reverse engineering of two excruciatingly non-trivial codebases. Thanks to your solemn sacrifice and soon-to-be-broken codebase, I can happily accept that integrating @beartype +

```python
import jaxtyping

def __beartype_wrapper(...):
    jaxtyping._storage.push_shape_memo({})
    ...  # <-- *BEARTYPE MAGIC HAPPENS HERE*
    __beartype_pith_0 = ...  # <-- *MORE MAGICAL UNICORNS ERUPT*
    jaxtyping._storage.pop_shape_memo()
    return __beartype_pith_0
```

Clearly, that's trivial. Clearly, @patrick-kidger is also grinding his teeth into stubs. Don't worry! I won't do anything without your consent. For one, I'm lazy. For another, I'd have to violate privacy encapsulation. For a final one, there's no guarantee any of this dark magic will continue to behave itself in perpetuity without everyone's explicit consent and continual agreement.

How do you feel about this sort of horror, @patrick-kidger? I know you. You're like me – only stronger, fitter, and more likely to survive the collapse of Canada's rickety maple syrup market. You really want to remain "in the driver's seat." But @beartype can probably solve all your problems with just a trivial amount of integration, glue, sputum, white-knuckle grips on the keyboard, and eyes-wide-shut commits guaranteed to blow up. Should @beartype make overtures towards doing something like this or should I just back away slowly from the keyboard before anybody gets hurt and feels bad?

@patrick-kidger: Unrelatedly, Google now claims that the title for the
For sure :)
I'm not sure that you need to add

```python
def foo():
    kek: Float32[torch.Tensor, "a b c"] = ...

    def bar(x: Float32[torch.Tensor, f"b w h c"]):
        ...

foo()
...
```
Haha, @avolchek, I'm impressed! Although, with the latest jaxtyping and beartype releases, in theory this should "just work", just by applying both of our import hooks. Indeed I would like to update the documentation to reflect this, however... it seems like this doesn't work!

```python
# entry_point.py
from beartype.claw import beartype_package
from jaxtyping import install_import_hook

beartype_package("foo")
with install_import_hook("foo", typechecker=None):
    import foo
```

```python
# foo.py
import jax
from jaxtyping import Array, Float

def foo():
    x: Float[Array, "size"] = jax.numpy.ones(3)
    y: Float[Array, "size"] = jax.numpy.ones(4)  # this is an error!

foo()  # but this does not raise an error!
```

I've tried both orders for installing the import hooks. I'm not sure what's going on here right now. (Notably, however, something like

I think if we're to direct our efforts, it should be to make sure that this approach is usable! This would allow both packages to do what they do best, without needing to interface with each other at all.
Hmmm, I think this is a Google bug. Searching for any of my GitHub repositories seems to grab something random off the front page of that site. I have no idea what's going on :D
Yeah, I tried it too and it didn't work. As far as I remember, the jaxtyping hook actually doesn't do anything if the beartype hook is used at the same time. So it doesn't add
As far as I remember, it's also because of a lack of calls to the memo-manipulation routines. You need to call 'push' at least once for shape-variable checks to work.
Great project, helps me understand DL code a lot.
I used it like this:
But it turns out it doesn't do a runtime type check on this line:

```python
x: Mlp_mid = self.act(self.lin1(x))
```

And this makes me feel insecure. So, my question is: will this feature be added in the future? Or is it in conflict with some design intention?

BTW, I mainly use it when I am trying to understand other people's code. But can I include it in production? How much does it slow down training and inference?