
Commit

test validation of inputs at eval
pluflou committed Nov 21, 2024
1 parent ab9c2d0 commit ea75f92
Showing 2 changed files with 30 additions and 1 deletion.
13 changes: 13 additions & 0 deletions lume_model/models/torch_model.py
@@ -12,6 +12,7 @@
     InputVariable,
     OutputVariable,
     ScalarInputVariable,
+    TorchTensor
 )
 
 logger = logging.getLogger(__name__)
@@ -239,17 +240,29 @@ def _format_inputs(
         # NOTE: The input variable is only updated if a singular value is given (ambiguous otherwise)
         formatted_inputs = {}
         for var_name, var in input_dict.items():
+
             if isinstance(var, InputVariable):
+                print("validating")
+                var.__pydantic_validator__.validate_assignment(var, 'value', var.value)
                 formatted_inputs[var_name] = torch.tensor(var.value, **self._tkwargs)
                 # self.input_variables[self.input_names.index(var_name)].value = var.value
             elif isinstance(var, float):
+                print("validating")
+                var.__pydantic_validator__.validate_assignment(var, 'value', var.value)
                 formatted_inputs[var_name] = torch.tensor(var, **self._tkwargs)
                 # self.input_variables[self.input_names.index(var_name)].value = var
             elif isinstance(var, torch.Tensor):
+                # print("validating")
+                # var.__pydantic_validator__.validate_assignment(var, 'value', var.value)
                 var = var.double().squeeze().to(self.device)
                 formatted_inputs[var_name] = var
                 # if var.dim() == 0:
                 #     self.input_variables[self.input_names.index(var_name)].value = var.item()
+            elif isinstance(var, TorchTensor):
+                print("validating")
+                var.__pydantic_validator__.validate_assignment(var, 'value', var.value)
+                var = var.value.double().squeeze().to(self.device)
+                formatted_inputs[var_name] = var
             else:
                 TypeError(
                     f"Unknown type {type(var)} passed to evaluate."
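
All of the added branches lean on the same Pydantic v2 mechanism: model.__pydantic_validator__.validate_assignment(model, field, value) re-runs the field validators for that field against the supplied value, so out-of-range inputs are caught at evaluation time rather than silently passed to the model. Below is a minimal sketch of that mechanism outside lume-model, using a hypothetical BoundedInput model as a stand-in for an InputVariable; the range check is purely illustrative.

    from pydantic import BaseModel, field_validator

    class BoundedInput(BaseModel):
        value: float = 0.0

        @field_validator("value")
        def check_range(cls, v):
            # Illustrative constraint; real InputVariable validators differ.
            if not 0.0 <= v <= 1.0:
                raise ValueError("value out of range")
            return v

    var = BoundedInput(value=0.5)
    var.value = 5.0  # plain assignment is not validated by default, so this slips through

    # Re-running validation at evaluation time, as _format_inputs now does per input:
    try:
        var.__pydantic_validator__.validate_assignment(var, "value", var.value)
    except Exception as err:  # pydantic ValidationError listing the failed constraint
        print(err)
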
18 changes: 17 additions & 1 deletion lume_model/variables.py
@@ -7,7 +7,8 @@
 """
 import logging
 from typing import Optional, Generic, TypeVar
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
+import torch
 
 logger = logging.getLogger(__name__)

@@ -99,6 +100,21 @@ class ScalarOutputVariable(OutputVariable[float], ScalarVariable):
     pass
 
 
+class TorchTensor(BaseModel):
+    value: torch.Tensor = Field(...)
+
+    @field_validator("value", mode="before")
+    def validate_tensor(cls, value):
+        if not isinstance(value, torch.Tensor):
+            raise ValueError("The field must be a torch.Tensor")
+        if not isinstance(value.item(), float):
+            raise ValueError("The torch.Tensor items must be floats")
+        return value
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
 # class NumpyNDArray(np.ndarray):
 # """
 # Custom type validator for numpy ndarray.
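
A rough usage sketch of the new wrapper, assuming TorchTensor is imported from lume_model.variables as defined above:

    import torch
    from lume_model.variables import TorchTensor

    t = TorchTensor(value=torch.tensor(1.0))  # accepted: a float-valued tensor
    print(t.value.dtype)                      # torch.float32

    # Both of the following raise a pydantic ValidationError from validate_tensor:
    # TorchTensor(value=2.0)              # rejected: not a torch.Tensor
    # TorchTensor(value=torch.tensor(3))  # rejected: .item() is an int, not a float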
