Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add reproducibility_tolerance field in weight descriptions #659

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,21 @@ def _convert(
)


class Tolerance(Node):
relative: Annotated[float, Interval(gt=0, le=0.01)] = 1e-4
"""Relative tolerance when comparing a regenerated model output (including pre- and
postprocessing) to its corresponding known output test tensor. See **rtol** argument to
[numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
Specify a single floating point value for all tensors or individual values for
(a subset of) individual output tensors."""

absolute: Annotated[float, Gt(0)] = 1.5e-4
"""Absolute tolerance when comparing a regenerated model output (including pre- and
postprocessing) to its corresponding known output test tensor. See **atol** argument to
[numpy.testing.assert_allclose](https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html#numpy.testing.assert_allclose).
"""


class WeightsEntryDescrBase(FileDescr):
type: ClassVar[WeightsFormat]
weights_format_name: ClassVar[str] # human readable
Expand All @@ -1907,6 +1922,14 @@ class WeightsEntryDescrBase(FileDescr):
All weight entries except one (the initial set of weights resulting from training the model),
need to have this field."""

reproducibility_tolerance: Union[Tolerance, Mapping[TensorId, Tolerance]] = Field(
default_factory=Tolerance
)
"""Even with seeding and selecting deterministic algorithms, model output may vary
across DL framework versions, OS and hardware. Models suseptable to these
variations may specify `reproducibility_tolerance` to compare regenerated outputs
less strictly to the known output test tensors."""

@model_validator(mode="after")
def check_parent_is_not_self(self) -> Self:
if self.type == self.parent:
Expand Down
Loading