Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.nn.module #3337

Merged
merged 4 commits into from
Mar 17, 2024
Merged

Type annotate pyro.nn.module #3337

merged 4 commits into from
Mar 17, 2024

Conversation

ordabayevy
Copy link
Member

No description provided.

@ordabayevy ordabayevy mentioned this pull request Mar 9, 2024
23 tasks
@ordabayevy ordabayevy requested a review from fritzo March 9, 2024 19:24
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! This must have been a challenging puzzle; that PyroModule metaclass stuff is quite intricate.

Just a couple minor comments.



class PyroSample(namedtuple("PyroSample", ("prior",))):
@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use @dataclass(slots=True, frozen=True) here, to remain closer to namedtuple? Of course we'll then need to use object._setattribute__(self, "foo", bar) rather than self.foo = bar in __post_init__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slots seems to be new in Python 3.10

pyro/nn/module.py Outdated Show resolved Hide resolved
def __setattr__(
self,
name: str,
value: Union[torch.Tensor, torch.nn.Module, PyroParam, PyroSample],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value should have type Any, since we fall back to super.__setattr__(name, value) at the last line of this method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Is that why mypy thinks torch.nn.Modules attributes are either Tensor or Module because they annotated them as such 🤔 ? That always confused me.

Comment on lines -66 to -68
[mypy-pyro.nn.*]
ignore_errors = True
warn_unused_ignores = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

@fritzo fritzo merged commit 01e340e into dev Mar 17, 2024
9 checks passed
@ordabayevy ordabayevy deleted the type-nn branch March 17, 2024 16:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants