diff --git a/tests/models/test_torch_model.py b/tests/models/test_torch_model.py index 0d4b4b5..b932f9e 100644 --- a/tests/models/test_torch_model.py +++ b/tests/models/test_torch_model.py @@ -81,6 +81,8 @@ def test_precision(self, california_model): assert california_model.dtype == torch.double california_model.precision = "single" assert california_model.dtype == torch.float + # set back to double + california_model.precision = "double" def test_model_evaluate_single_sample(self, california_test_input_dict: dict, california_model): results = california_model.evaluate(california_test_input_dict)