Add input and output validation #104
Conversation
…icitly addressed in model_dump()
Mostly looks good to me. I have one question on top of the comment I left: is there any testing for evaluating models on non-scalar tensors for the input variables?
For example, in `torch_model.ipynb`, does the following work?

```python
input_dict = torch_model.random_input(n_samples=5)
torch_model.evaluate(input_dict)
```
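For context, here is a minimal hedged sketch of what such a batched call might look like. `random_input` and `evaluate` mirror the names used above, but the implementations below are stand-ins for illustration, not lume-model's actual code.

```python
import torch

# Stand-in for torch_model.random_input: each input variable maps to a
# 1-D tensor of n_samples values rather than a scalar.
def random_input(n_samples: int) -> dict:
    return {"x1": torch.rand(n_samples), "x2": torch.rand(n_samples)}

# Stand-in for torch_model.evaluate: stacks the per-variable tensors into
# a (n_samples, n_features) batch and runs a dummy forward pass.
def evaluate(input_dict: dict) -> dict:
    batch = torch.stack(list(input_dict.values()), dim=-1)
    return {"y": batch.sum(dim=-1)}  # dummy model: sum of features

input_dict = random_input(n_samples=5)
output_dict = evaluate(input_dict)
```

Each output value then has shape `(n_samples,)`, one entry per sample in the batch.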
```diff
@@ -292,6 +277,14 @@ def unique_variable_names(cls, value):
         verify_unique_variable_names(value)
         return value

+    @field_validator("input_variables")
+    def verify_input_default_value(cls, value):
```
This is what differentiates input variables from output variables, correct? If so, I want to go back on what we discussed earlier: maybe have different input/output class types, so that validation errors happen during variable definition instead of model validation.
Yes, that's correct. We either set the attribute as optional and do it this way, or make it required and set the output variable default to NaN in the base class, but the latter was a little messy. We can revisit having separate child classes.
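To illustrate the separate-class idea, here is a hedged sketch using plain dataclasses (the project itself uses Pydantic models, and the class names below are hypothetical): an input variable requires a default at definition time, an output variable does not, so the error surfaces as soon as the variable is defined.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class InputVariable:
    name: str
    default_value: float  # required: inputs must carry a default

    def __post_init__(self) -> None:
        # fail at variable definition, not at model validation
        if self.default_value is None:
            raise ValueError(f"input variable '{self.name}' requires a default value")

@dataclass
class OutputVariable:
    name: str
    default_value: Optional[float] = None  # outputs need no default
```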
Yes, the above code will work. It supports non-scalar tensors, but not non-scalars of other types (e.g. a list of scalars).
LGTM except for one comment
lume_model/models/torch_model.py
Outdated
```python
def evaluate(

def _set_precision(self, value: torch.dtype):
    """Sets the precision of the model."""
    torch.set_default_dtype(value)
```
This is probably an overreach, i.e. it extends beyond this class/object. Do we need to set the default dtype for torch here?
You're right, that is not needed since we set the model and transformers. I removed it.
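The scoped alternative might look like the hedged sketch below: cast the module itself (and, in the real code, any transformer modules) to the requested dtype, leaving torch's global default untouched. Only the `_set_precision` name comes from the diff above; the rest is illustrative.

```python
import torch

class TinyModel(torch.nn.Module):
    """Minimal stand-in for the wrapped model."""
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def _set_precision(self, dtype: torch.dtype) -> None:
        # Cast only this module's parameters and buffers; the
        # global torch default dtype is left unchanged.
        self.to(dtype)

model = TinyModel()
model._set_precision(torch.float64)
```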
LGTM, you can go ahead and merge.
To do: