Equinox Module methods not recognized by beartype #584
The final note seems to break a bunch of stuff in [...]. It doesn't actually work -- a function's [...]:
import equinox as eqx

class MyModule(eqx.Module):
    a: int

    def f(self, b):
        return self.a + b

x = MyModule(1)
print(x.f)
f = MyModule.__dict__['f']
print(f.__get__(x, MyModule))
The code outputs:
Hmm, interesting!
class Foo(eqx.Module):
    @beartype
    def bar(self): ...
I'm aware that this isn't what beartype recommends, but in reality this handles 99% of all cases. (And decorating the method like this is actually what I always do as well.) Tagging @leycec, as this has got me thinking. I'd like to propose that beartype recommend decorating methods instead. On the two points raised in the docs above:
The benefits would be that:
WDYT?
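A stdlib-only toy can illustrate why decorating the method directly sidesteps the problem. This is a sketch, not Equinox or beartype: `Module`, `MethodWrapper`, `check_int_args`, and `decorate_class` are hypothetical names, and the sketch assumes a base class that wraps each method in a non-function object when the subclass is created, roughly as this issue describes `eqx.Module` doing with `_wrap_method`. A method-level decorator runs inside the class body, before that wrapping happens, so it still sees a plain function.

```python
import functools
import inspect


class MethodWrapper:
    """Hypothetical stand-in for a framework wrapper such as Equinox's
    private _wrap_method; instances are NOT functions."""

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return functools.partial(self.method, instance)


class Module:
    """Toy base class that wraps every plain function of each subclass
    at subclass-creation time (roughly what eqx.Module is described doing)."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for key, value in list(vars(cls).items()):
            if inspect.isfunction(value):
                setattr(cls, key, MethodWrapper(value))


def check_int_args(func):
    """Hypothetical toy checker: every positional argument after self must be an int."""
    @functools.wraps(func)
    def wrapper(self, *args):
        for arg in args:
            if not isinstance(arg, int):
                raise TypeError(f"{func.__qualname__}() got non-int {arg!r}")
        return func(self, *args)
    return wrapper


def decorate_class(cls):
    """Class-level decoration: only plain functions found in __dict__ are checked."""
    for key, value in list(vars(cls).items()):
        if inspect.isfunction(value):
            setattr(cls, key, check_int_args(value))
    return cls


@decorate_class
class ClassDecorated(Module):
    def bar(self, b):
        return b


class MethodDecorated(Module):
    @check_int_args
    def bar(self, b):
        return b


print(ClassDecorated().bar("oops"))  # oops: the check never ran
try:
    MethodDecorated().bar("oops")
except TypeError as exc:
    print(exc)  # MethodDecorated.bar() got non-int 'oops'
```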
Hah! I'll show you, @EtaoinWu. Friendship is magic. @beartype is definitely prepared to support Big Boss @patrick-kidger and all of his associated JAX madness in whatever way he needs – including rolling out internal support for Equinox-specific wrappers. After all, @beartype already internally supports non-standard third-party NumPy and Pandera type hints. Equinox is just yet another non-standard third-party thing that @beartype would support. Ideally, of course, @beartype would provide some sort of plugin API for this sort of thing. But... ain't nobody got that kind of time. I'll just hack it instead. 😄
In response, @beartype would like to propose this Berserk meme.
Hah! @beartype showed you all. Indeed, @beartype type-checks [...]. Moreover, doing so enables @beartype to additionally type-check [...]:
from dataclasses import dataclass
from typing import Self  # Python >= 3.11

from beartype import beartype

@beartype
@dataclass
class BadExampleIsBad(object):
    myself: Self  # <-- Yup. @beartype can type-check this.

    @staticmethod
    def make_myself() -> Self:  # <-- Yup. @beartype can type-check this, too.
        return BadExampleIsBad()

    @classmethod
    def ignore_myself(cls: type[Self]) -> Self:  # <-- Still fine.
        return cls.make_myself()
But there are actually many, many reasons apart from PEP 673 (i.e.,
Relevant commentary in the [...]:
# If the decorated callable is nested (rather than global) and thus
# *MAY* have a non-empty local nested scope...
if bear_call.func_wrappee_is_nested:
    # Attempt to...
    try:
        # Local scope of the decorated callable, localized to improve
        # readability and negligible efficiency when accessed below.
        func_locals = get_func_locals(
            func=func,
            # Ignore all lexical scopes in the fully-qualified name of
            # the decorated callable corresponding to parent classes
            # lexically nesting the current decorated class containing
            # that callable (including that class). Why? Because these
            # classes are *ALL* currently being decorated and thus have
            # yet to be encapsulated by new stack frames on the call
            # stack. If these lexical scopes are *NOT* ignored, this
            # call to get_func_locals() will fail to find the parent
            # lexical scope of the decorated callable and then raise an
            # unexpected exception.
            #
            # Consider, for example, this nested class decoration of a
            # fully-qualified "muh_package.Outer" class:
            #     @beartype
            #     class Outer(object):
            #         class Middle(object):
            #             class Inner(object):
            #                 def muh_method(self) -> str:
            #                     return 'Painful API is painful.'
            #
            # When @beartype finally recurses into decorating the nested
            # muh_package.Outer.Middle.Inner.muh_method() method, this
            # call to get_func_locals() if *NOT* passed this parameter
            # would naively assume that the parent lexical scope of the
            # current muh_method() method on the call stack is named
            # "Inner". Instead, the parent lexical scope of that method
            # on the call stack is named "muh_package" -- the first
            # lexical scope enclosing that method that exists on the
            # call stack. The non-existent "Outer", "Middle", and
            # "Inner" lexical scopes must *ALL* be silently ignored.
            func_scope_names_ignore=(
                0 if cls_stack is None else len(cls_stack)),
            #FIXME: Consider dynamically calculating exactly how many
            #additional @beartype-specific frames are ignorable on the first
            #call to this function, caching that number, and then reusing
            #that cached number on all subsequent calls to this function.
            #The current approach employed below of naively hard-coding a
            #number of frames to ignore was incredibly fragile and had to be
            #effectively disabled, which hampers runtime efficiency.
            # Ignore additional frames on the call stack embodying:
            # * The current call to this function.
            #
            # Note that, for safety, we currently avoid ignoring
            # additional frames that we could technically ignore. These
            # include:
            # * The call to the parent
            #   beartype._check.checkcall.BeartypeCall.reinit() method.
            # * The call to the parent @beartype.beartype() decorator.
            #
            # Why? Because the @beartype codebase has been sufficiently
            # refactored so as to render any such attempts non-trivial,
            # fragile, and frankly dangerous.
            func_stack_frames_ignore=1,
            exception_cls=exception_cls,
        )
        # If this local scope cannot be found (i.e., if this getter found
        # the lexical scope of the module declaring the decorated callable
        # *before* that of the parent callable or class declaring that
        # callable), then this resolve_hint() function was called *AFTER*
        # rather than *DURING* the declaration of the decorated callable.
        # This implies that that callable is not, in fact, currently being
        # decorated. Instead, that callable was *NEVER* decorated by
        # @beartype but has instead subsequently been passed to this
        # resolve_hint() function after its initial declaration -- typically
        # due to an external caller passing that callable to our public
        # beartype.peps.resolve_pep563() function.
        #
        # In this case, the call stack frame providing this local scope has
        # (almost certainly) already been deleted and is no longer
        # accessible. We have no recourse but to default this local scope to
        # the empty dictionary -- which might be subsequently modified and
        # *CANNOT* thus default to the singleton empty dictionary
        # "DICT_EMPTY" (unlike below).
    except _BeartypeUtilCallableScopeNotFoundException:
        func_locals = {}
    # If the decorated callable is a method transitively defined by a
    # root decorated class, add a pair of local attributes exposing:
    #
    # * The unqualified basename of the root decorated class. Why?
    #   Because this class may be recursively referenced in postponed
    #   type hints and *MUST* thus be exposed to *ALL* postponed type
    #   hints. However, this class is currently being decorated and thus
    #   has yet to be defined in either:
    #   * If this class is module-scoped, the global attribute
    #     dictionary of that module and thus the "func_globals"
    #     dictionary.
    #   * If this class is closure-scoped, the local attribute
    #     dictionary of that closure and thus the "func_locals"
    #     dictionary.
    # * The unqualified basename of the current decorated class. Why?
    #   For similar reasons. Since the current decorated class may be
    #   lexically nested in the root decorated class, the current
    #   decorated class is *NOT* already accessible as either a global
    #   or local. Exposing the current decorated class to a stringified
    #   type hint referencing that class thus requires adding a local
    #   attribute exposing that class.
    #
    # Note that:
    # * *ALL* intermediary classes (i.e., excluding the root decorated
    #   class) lexically nesting the current decorated class are
    #   irrelevant. Intermediary classes are neither module-scoped nor
    #   closure-scoped and thus inaccessible as either globals or locals
    #   in the nested lexical scope of the current decorated class:
    #   e.g.,
    #       # This raises a parser error and is thus *NOT* fine:
    #       # NameError: name 'muh_type' is not defined
    #       class Outer(object):
    #           class Middle(object):
    #               muh_type = str
    #
    #               class Inner(object):
    #                   def muh_method(self) -> muh_type:
    #                       return 'Dumpster fires are all I see.'
    # * This implicitly overrides any previously declared locals of the
    #   same name. Although non-ideal, this constitutes syntactically
    #   valid Python and is thus *NOT* worth emitting even a non-fatal
    #   warning over: e.g.,
    #       # This is fine... technically.
    #       from beartype import beartype
    #       def muh_closure() -> None:
    #           MuhClass = 'This is horrible, yet fine.'
    #
    #           @beartype
    #           class MuhClass(object):
    #               def muh_method(self) -> str:
    #                   return 'Look away and cringe, everyone!'
    if cls_stack:
        # Root and current decorated classes.
        cls_root = cls_stack[0]
        cls_curr = cls_stack[-1]
The Emoji cat cries for Equinox! 😿
Gah! Curse you, [...]
from dataclasses import dataclass, field

@dataclass
class C:
    # This is a lie. "mylist" is an instance of "dataclasses.Field" rather than "list".
    mylist: list[int] = field(default_factory=list)
Since [...]
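The "lie" in that comment can be observed with the standard library alone. In this minimal sketch (assuming Python 3.9+ for `list[int]`), the class-body value is a `dataclasses.Field` object, and `@dataclass` afterwards deletes that class attribute because the field has a `default_factory` but no plain default, so anything walking `__dict__` at either moment sees something other than the annotated type:

```python
from dataclasses import Field, dataclass, field, fields


class Raw:
    # No @dataclass here, so the class attribute stays a Field object.
    mylist: list[int] = field(default_factory=list)


@dataclass
class C:
    mylist: list[int] = field(default_factory=list)


print(isinstance(vars(Raw)["mylist"], Field))  # True
# @dataclass deletes a Field class attribute that has no plain default.
print("mylist" in vars(C))                     # False
print(fields(C)[0].default_factory is list)    # True
print(C().mylist)                              # []
```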
I like the way you think, Dr. Kidger. Yes! Let's do that! Let's do the class-level sneakery thing. Just let me know somewhere how you'd like @beartype to eventually:
We'll make this despicable magic happen yet, boys. 💪 🐻
Ech, that sounds complicated! Okay, I think we can make this work with just some small tweaks. IIUC, beartype is morally doing something like this:
for key, value in cls.__dict__.items():
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))
I think it should be enough to change things to:
for key in cls.__dict__.keys():
    value = getattr(cls, key)  # call __get__ to get an actual function, not a function-wrapper
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))
...that is, once I've merged #587, which adds support for such monkey-patching, by adding a [...]
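The difference between those two loops can be sketched with plain Python. `WrapMethod` is a hypothetical stand-in for Equinox's `_wrap_method`, and `toy_check` a hypothetical stand-in for `beartype`; the sketch assumes that accessing the wrapper on the class returns the underlying function, which is how this thread describes #587:

```python
import functools
import inspect


class WrapMethod:
    """Hypothetical stand-in for Equinox's private _wrap_method.

    Assumption: class-level access (instance is None) returns the raw function.
    """

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner=None):
        if instance is None:
            return self.method  # class access unwraps to the raw function
        return functools.partial(self.method, instance)  # crude bound method


def toy_check(func):
    """Hypothetical stand-in for beartype: wraps func and marks the wrapper."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    wrapper.checked = True
    return wrapper


def decorate_naive(cls):
    # Inspects raw __dict__ values, so the WrapMethod object is skipped.
    for key, value in list(vars(cls).items()):
        if inspect.isfunction(value):
            setattr(cls, key, toy_check(value))
    return cls


def decorate_via_getattr(cls):
    # getattr() triggers __get__(None, cls), exposing the real function.
    for key in list(vars(cls)):
        value = getattr(cls, key)
        if inspect.isfunction(value):
            setattr(cls, key, toy_check(value))
    return cls


def make_class():
    # Fresh class (and fresh function object) on every call.
    class Demo:
        def fn(self, y):
            return y
        fn = WrapMethod(fn)
    return Demo


A = decorate_naive(make_class())
B = decorate_via_getattr(make_class())
print(getattr(A.fn, "checked", False))  # False: the wrapper hid fn
print(getattr(B.fn, "checked", False))  # True: fn was found and decorated
```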
The more you know, the more you know you don't wanna know. @beartype: it's like quantum mechanics that way.
🤣
This is both clever and obscene. I feel approval. I also feel trepidation. @beartype does already handle C-based builtin method descriptors wrapping pure-Python unbound methods. This includes [...]. I kinda intuit that your proposed resolution will satisfy the specific use case of Equinox while yet failing the general use case of Python's standard descriptors. Do I really know what I am talking about? The answer may shock you. But probably it won't.
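The trepidation is easy to reproduce with builtins alone. In this sketch, `toy_check` is a hypothetical no-op stand-in for `beartype`: running the `getattr()`-based loop over a class with a `staticmethod` unwraps it to a plain function and stores that back, silently turning it into an instance method, while a `classmethod` survives only because `getattr()` hands back a bound method rather than a function:

```python
import inspect


def toy_check(func):
    """Hypothetical stand-in for beartype; returns func unchanged."""
    return func


def decorate_via_getattr(cls):
    """The getattr()-based loop proposed above, reproduced as-is."""
    for key in list(vars(cls)):
        value = getattr(cls, key)
        if inspect.isfunction(value):
            setattr(cls, key, toy_check(value))
    return cls


class Before:
    @staticmethod
    def make(x):
        return x * 2


@decorate_via_getattr
class After:
    @staticmethod
    def make(x):
        return x * 2

    @classmethod
    def build(cls):
        return cls.__name__


print(Before().make(3))                               # 6
# getattr() unwrapped the staticmethod, so a plain function was stored back.
print(inspect.isfunction(vars(After)["make"]))        # True
# The classmethod survives: getattr() returned a bound method, not a function.
print(isinstance(vars(After)["build"], classmethod))  # True
try:
    After().make(3)  # now called as make(instance, 3)
except TypeError as exc:
    print("staticmethod broken:", exc)
```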
Hmm. Good point. Perhaps:
for key, value in cls.__dict__.items():
    if not is_classmethod_or_whatever(value):
        value = getattr(cls, key)
    if inspect.isfunction(value):
        setattr(cls, key, beartype(value))
? That's obviously kind of a hack that happens to work for the builtins and happens to work for Equinox. But I think the above should hit not just [...]
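A runnable version of that loop, under stated assumptions: `is_builtin_descriptor()` is a hypothetical helper filling the role of `is_classmethod_or_whatever()`, `WrapMethod` is a hypothetical Equinox-style wrapper, and `toy_check` stands in for `beartype`. Builtin descriptors are kept as-is, while anything else is resolved through `getattr()`:

```python
import functools
import inspect


class WrapMethod:
    """Hypothetical Equinox-style wrapper; class access returns the raw function."""

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner=None):
        if instance is None:
            return self.method
        return functools.partial(self.method, instance)


def is_builtin_descriptor(value):
    """Hypothetical helper playing the role of is_classmethod_or_whatever()."""
    return isinstance(value, (classmethod, staticmethod, property))


def toy_check(func):
    """Hypothetical stand-in for beartype: wraps func and marks the wrapper."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    wrapper.checked = True
    return wrapper


def decorate(cls):
    for key, value in list(vars(cls).items()):
        if not is_builtin_descriptor(value):
            # Unknown wrappers are unwrapped via __get__(None, cls);
            # builtin descriptors are left in place untouched.
            value = getattr(cls, key)
        if inspect.isfunction(value):
            setattr(cls, key, toy_check(value))
    return cls


@decorate
class Demo:
    @staticmethod
    def make(x):
        return x * 2

    def plain(self, y):
        return y

    def wrapped(self, y):
        return y + 1

    wrapped = WrapMethod(wrapped)


print(isinstance(vars(Demo)["make"], staticmethod))  # True: left alone
print(Demo.plain.checked, Demo.wrapped.checked)     # True True
print(Demo().make(3), Demo().wrapped(1))            # 6 2
```

In this toy, `make` is simply left unchecked; per the comment above, @beartype already handles builtin descriptors such as `staticmethod` through its own separate path.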
...heh. Brilliant minds GitHub alike. Coincidentally, that's exactly what I concocted in my bald head while hiking through the frigid wastes of Canada this morning:
The really nice thing about the Equinox-specific thing is that it should also generalize to arbitrary other third-party packages resembling Equinox. @beartype will then "just work" out-of-the-box without @beartype [...] ...pretty sure this means me needing to explicitly support [...]
This is win. Thankfully, you were even kind enough to promptly merge #587. Therefore, this is @beartype's roadmap to payback:
Thanks for being so supportive, Dr. Kidger. The bear will howl on the equinox! 🐻 🌚
Possibly resolved by beartype/beartype@58219ba02be8a. In theory, this now works. In practice, nothing is tested. More importantly... OMFG!!!!!! This was so shockingly insane. It turns out the brute-force approach outlined above fails to suffice. Reality is a harsh mistress and so is the moon. Doing this uncarefully induces INFINITE FRIGGIN' RECURSION on standard types, including:
Needless to say, I remain both shocked and appalled that @beartype has been uselessly attempting to decorate the [...]. For additional protection against madness, I've also prohibited dunder attributes like [...]. Nonetheless, we're shipping this. If @beartype 0.17.0 blows up PyTorch yet again, I can only hang my head and blame @patrick-kidger.
This commit is the first in a commit chain officially adding support for @patrick-kidger (Patrick Kidger)'s third-party Equinox JAX-driven ML framework, en-route to resolving issue patrick-kidger/equinox#584 kindly submitted by friendly magical ML unicorn @EtaoinWu (Yue Wu). Specifically, this commit *very* carefully crafts general-purpose support for dynamically unwrapping non-standard function wrappers implemented by third-party packages -- including Equinox. Naturally, nothing is tested; everything is suspect. Trust no one, @beartype! (*Some piano etude on a window pane is no winsome pain!*)
:D
Maybe go the other way, prohibit all magic methods except those on an explicit allow-list?
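A default-deny filter along those lines might look like the following sketch. The contents of `DUNDER_ALLOW_LIST` are purely hypothetical; which magic methods are safe to decorate is exactly the open question here:

```python
import inspect

# Hypothetical allow-list: the only dunder methods this sketch will decorate.
DUNDER_ALLOW_LIST = frozenset({"__init__", "__call__"})


def is_dunder(name):
    return len(name) > 4 and name.startswith("__") and name.endswith("__")


def names_to_decorate(cls):
    """Names of plain functions in cls.__dict__ that a decorator would visit."""
    selected = []
    for key, value in vars(cls).items():
        if is_dunder(key) and key not in DUNDER_ALLOW_LIST:
            continue  # default-deny: every other magic name is skipped
        if inspect.isfunction(value):
            selected.append(key)
    return selected


class Example:
    def __init__(self, x):
        self.x = x

    def __repr__(self):
        return f"Example({self.x!r})"

    def __call__(self, y):
        return self.x + y

    def method(self):
        return self.x


print(names_to_decorate(Example))  # ['__init__', '__call__', 'method']
```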
This commit is the next in a commit chain officially adding support for @patrick-kidger (Patrick Kidger)'s third-party Equinox JAX-driven ML framework, en-route to resolving issue patrick-kidger/equinox#584 kindly submitted by friendly magical ML unicorn @EtaoinWu (Yue Wu). This commit dramatically improves the general-purpose support, implemented by the prior commit, for dynamically unwrapping non-standard function wrappers implemented by third-party packages (including Equinox). Specifically, this commit:
* Dramatically streamlines this support with robust protection against unexpected recursion. Previously, @beartype recursively decorated class variables whose values are types (e.g., `class_var = type`). Although doing so was typically safe, common edge cases like user-defined `enum.Enum` subclasses exposed substantial weaknesses in this assumption. "Typically safe" is just another way of saying "Guaranteed to explode all over your lap, bro."
* Exhaustively unit tests this support. Although Equinox itself has yet to be tested, @beartype itself now tests that it:
  * Implicitly decorates nested classes as expected.
  * No longer decorates class variables whose values are types.
(*Telegraph the paragraph: bolo gun or parabola lobbed by nuns!?*)
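One way to tell a genuinely nested class apart from a class variable that merely refers to an existing type is to compare `__qualname__`. This is a hypothetical heuristic for illustration, not necessarily the rule @beartype adopted in that commit:

```python
import enum


def nested_classes(cls):
    """Return names of classes lexically defined inside cls's body.

    Hypothetical heuristic: a nested class has the __qualname__
    "<owner qualname>.<attribute name>", whereas a class variable that
    merely refers to an existing type (e.g. ``class_var = type``) does not.
    """
    prefix = cls.__qualname__ + "."
    found = []
    for key, value in vars(cls).items():
        if isinstance(value, type) and value.__qualname__ == prefix + key:
            found.append(key)
    return found


class Color(enum.Enum):
    RED = 1


class Outer:
    class_var = type   # refers to a builtin type: not nested
    palette = Color    # refers to a user-defined Enum: not nested

    class Inner:       # defined in the body: genuinely nested
        pass


Outer.me = Outer       # self-reference: would recurse forever if followed

print(nested_classes(Outer))        # ['Inner']
print(nested_classes(Outer.Inner))  # []
```

Because a self-reference such as `Outer.me` fails the `__qualname__` comparison, a recursive decorator using this filter would never follow it back into the class being decorated.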
😇 Rejoice! Microsoft just hired Sam Altman, restoring balance to the ML Force... wait. What were we talking about again? Is this Reddit? Where even am I? Oh, right. GitHub. It's still happening. So. Gentlemen and scuba divers alike, I have solved all the recursion complaints. @beartype is now robust yet again against pernicious edge cases that are too shameful for me to publicly exhibit here. Although I have yet to explicitly test against Equinox, everything should now work as expected. If anyone with more free time than is healthy would like to disprove these lies I am telling you, please test this for me:
pip install git+https://github.com/beartype/beartype.git@9faf1ecfc0fc3f26ab0de9eab710354476990cb4
To further harden this feature against "Surprise. It's Johnnnny!", I'll also be adding Equinox-specific unit tests to @beartype over the next several days. We inch closer to officially solving everything. 🥳
Santa Bear Claws has come to town. In other words, this long-standing issue was resolved by beartype/beartype@9faf1ecfc0fc3 a month and a half ago. This is the first spare moment I've had to genuinely test that commit against Equinox. After all, Xenoblade Chronicles: Definitive Edition doesn't play itself. Thankfully, all is now full of worky. Behold! 🪄
import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import (
    Array,
    Float,
)


@beartype
class MyClass(eqx.Module):
    x: Float[Array, ""]

    def fn(self, y: bool) -> Float[Array, ""]:
        return self.x + 1


MyClass(jnp.array(1.)).fn('not bool')
...which now raises the expected type-checking violation:
Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 18, in <module>
    MyClass(jnp.array(1.)).fn('not bool')
  File "/home/leycec/py/conda/envs/ionyou_dev/lib/python3.11/site-packages/equinox/_module.py", line 875, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<@beartype(__main__.MyClass.fn) at 0x7fa9bcdfb420>", line 22, in fn
beartype.roar.BeartypeCallHintParamViolation: Method __main__.MyClass.fn() parameter y='not bool' violates type hint <class 'bool'>, as str 'not bool' not instance of bool.
As the traceback suggests, @beartype now defers to Equinox's expert opinion on the matter. For everyone's safety, I've added an integration test to @beartype's test suite exercising this against regressions. This is the way the AI was won. @patrick-kidger: Thanks again for all the 💘 from the Google Open Source Programs Office. This issue can now be safely closed. Happy New Year from the frigid wastelands of [...]
This commit is the last in a commit chain officially adding support for @patrick-kidger (Patrick Kidger)'s third-party Equinox JAX-driven ML framework, resolving issue patrick-kidger/equinox#584 kindly submitted by friendly magical ML unicorn @EtaoinWu (Yue Wu). This commit dramatically improves the general-purpose support for dynamically unwrapping non-standard function wrappers implemented by third-party packages – including Equinox. Specifically, this commit exhaustively tests this support with an integration test shamelessly lifted from @EtaoinWu's initial exhibition of this issue. Praise be to the Unicorn Lord in 2024. (*Simply pimply!*)
In the following code:
You can actually call `MyClass(jnp.array(1.)).fn('not bool')` without getting a roar from the bear. The reason is that beartype, when decorating a class, iterates through each attribute of its `__dict__`. In our case, `MyClass.__dict__['fn']` (different from `MyClass.fn`!) is an `equinox._module._wrap_method`, and is not beartype-able.
Current workaround
I use a dirty hack to add `beartype` to `eqx.Module`s.
Potential fix
I cannot think of a perfect way to fix this. Here are some thoughts.
* [...] `beartype` would bother to add support for Equinox.
* [...] `_wrap_method` to lie about its `__class__` to fool `isinstance` checks, but we still need to pretend to be a function (with `.__code__`!) to fool `beartype`.
* `_wrap_method` could return an actual function: [...] Not sure if this would break other parts of Equinox.
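The two ideas above can be compared in a small sketch. `LyingWrapper` and `as_real_function` are hypothetical names. `isinstance()` (and therefore `inspect.isfunction()`) consults an object's `__class__` attribute when the real type does not match, so the lie fools the check, but anything that reads function internals such as `__code__` still fails; a genuine function built with `functools.wraps` passes both:

```python
import functools
import inspect
import types


class LyingWrapper:
    """Hypothetical wrapper that lies about its __class__ (the first idea)."""

    # A property named __class__ shadows the real type for isinstance().
    __class__ = property(lambda self: types.FunctionType)

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def target(x: int) -> int:
    return x


lying = LyingWrapper(target)
print(inspect.isfunction(lying))   # True: isinstance() trusts __class__
print(hasattr(lying, "__code__"))  # False: no real function internals


def as_real_function(func):
    """Hypothetical sketch of the second idea: hand out a genuine function."""
    @functools.wraps(func)
    def method(*args, **kwargs):
        return func(*args, **kwargs)
    return method


real = as_real_function(target)
print(inspect.isfunction(real), hasattr(real, "__code__"))  # True True
print(real.__wrapped__ is target)                            # True
```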