Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add input and output validation #104

Merged
merged 34 commits into from
Dec 20, 2024
Merged

Add input and output validation #104

merged 34 commits into from
Dec 20, 2024

Conversation

pluflou
Copy link
Collaborator

@pluflou pluflou commented Dec 18, 2024

  • Extends @t-bz 's work in Add validation #102
  • Makes default_value required for input_variables in LUMEBaseModel
  • Adds precision attribute to TorchModel class
  • Removes value_range_tolerance attribute from Variable class (might be a temporary change until more discussion happens)
  • Extends input validation to thoroughly validate input dictionary, type casting, and values in torch tensors within dictionary, before any other method is called
  • Re-introduces is_constant attribute to Variable class
  • Adds strictness config setting for input/output validation
  • Updates example notebooks
  • Updates tests

To do:

  • Finalize documentation update
  • Add more unit and integration tests to cover any new code

@pluflou pluflou requested a review from roussel-ryan December 18, 2024 06:32
Copy link
Collaborator

@roussel-ryan roussel-ryan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looks good to me. I have one question on top of the comment I left, is there some testing for evaluating models on non-scalar tensors for the input variables?
For example, in `torch_model.ipynb' does the following work:

input_dict = torch_model.random_input(n_samples=5)
torch_model.evaluate(input_dict)

@@ -292,6 +277,14 @@ def unique_variable_names(cls, value):
verify_unique_variable_names(value)
return value

@field_validator("input_variables")
def verify_input_default_value(cls, value):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is what differentiates input variables from output variables correct? if so I want to go back on what we discussed earlier and maybe have different input/output class types such that validation errors happen during variable definition instead of model validation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's correct. We either set the attribute as optional and do it this way, or make it required and set the output variable default as a nan in the base class, but the latter was a little messy. We can revisit having separate child classes.

@pluflou
Copy link
Collaborator Author

pluflou commented Dec 18, 2024

For example, in `torch_model.ipynb' does the following work:

input_dict = torch_model.random_input(n_samples=5)
torch_model.evaluate(input_dict)

Yes the above code will work. It supports non-scalar tensors but not non-scalars of other types (e.g. a list of scalars).

Copy link
Collaborator

@roussel-ryan roussel-ryan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM except for one comment

def evaluate(
def _set_precision(self, value: torch.dtype):
"""Sets the precision of the model."""
torch.set_default_dtype(value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably an overreach, ie. it extends beyond this class/object. Do we need to set the default type for torch here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, that is not needed since we set the model and transformers. I removed it.

@roussel-ryan
Copy link
Collaborator

LGTM, you can go ahead and merge

@pluflou pluflou merged commit dfde898 into slaclab:main Dec 20, 2024
4 checks passed
@pluflou pluflou deleted the validation branch December 20, 2024 18:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants