-
-
Notifications
You must be signed in to change notification settings - Fork 983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Type annotate pyro.nn.module
#3337
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! This must have been a challenging puzzle; that PyroModule
metaclass stuff is quite intricate.
Just a couple minor comments.
pyro/nn/module.py
Outdated
|
||
|
||
class PyroSample(namedtuple("PyroSample", ("prior",))): | ||
@dataclass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use @dataclass(slots=True, frozen=True)
here, to remain closer to namedtuple? Of course we'll then need to use object._setattribute__(self, "foo", bar)
rather than self.foo = bar
in __post_init__
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
slots
seems to be new in Python 3.10
pyro/nn/module.py
Outdated
def __setattr__( | ||
self, | ||
name: str, | ||
value: Union[torch.Tensor, torch.nn.Module, PyroParam, PyroSample], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value
should have type Any
, since we fall back to super.__setattr__(name, value)
at the last line of this method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Is that why mypy thinks torch.nn.Module
s attributes are either Tensor
or Module
because they annotated them as such 🤔 ? That always confused me.
[mypy-pyro.nn.*] | ||
ignore_errors = True | ||
warn_unused_ignores = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎉
No description provided.