From e6c7b6ae0c0b95cdfaee4b9bf4b6f449d6f0400e Mon Sep 17 00:00:00 2001 From: "Sara A. Miskovich" Date: Wed, 18 Dec 2024 13:15:47 -0800 Subject: [PATCH] simplify validation config --- examples/torch_model_extended.ipynb | 144 ++++++++++++++++------------ lume_model/base.py | 16 ++-- lume_model/variables.py | 28 ++++-- 3 files changed, 108 insertions(+), 80 deletions(-) diff --git a/examples/torch_model_extended.ipynb b/examples/torch_model_extended.ipynb index 27123ee..92e4ce8 100644 --- a/examples/torch_model_extended.ipynb +++ b/examples/torch_model_extended.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "56725817-2b21-4bea-98b0-151dea959f77", "metadata": {}, "outputs": [], @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "f96d9863-269c-49d8-9671-cc73a783bcbc", "metadata": {}, "outputs": [], @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "1af1f09a-f7f1-4eb3-9c26-3a8eba56a4db", "metadata": {}, "outputs": [ @@ -98,7 +98,7 @@ "), input_transformers=[AffineInputTransform(), AffineInputTransform()], output_transformers=[AffineInputTransform(), AffineInputTransform()], output_format='tensor', device='cpu', fixed_model=True, precision='double')" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "9f975b19-ef04-4283-b2a8-2d0f499950b0", "metadata": {}, "outputs": [ @@ -183,7 +183,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "1f959ec3-438d-4a7f-9639-8d0caf09578a", "metadata": {}, "outputs": [ @@ -200,7 +200,7 @@ "{'output': tensor(6.2820)}" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "26ce23c7-fdf1-42d9-bfe4-73832149f14b", "metadata": {}, "outputs": [ @@ -240,7 +240,7 @@ ")" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -252,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "39e18d1c-0836-43fa-933c-17a2f0aaaa30", "metadata": {}, "outputs": [ @@ -262,7 +262,7 @@ "tensor(6.2820)" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -281,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "95a5c89e-41b5-44f3-b6e5-6085b04d9906", "metadata": {}, "outputs": [ @@ -292,7 +292,7 @@ " ('base_model.linear.bias', tensor([0.5000]))])" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -303,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "8fe440fd-a402-4444-8b0d-5839e67477e7", "metadata": {}, "outputs": [ @@ -322,7 +322,7 @@ " ('output_transformers_1._offset', tensor([1.]))])" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -334,7 +334,9 @@ { "cell_type": "markdown", "id": "b0e8db76-6a89-4b18-8f3e-3d7b60782957", - "metadata": {}, + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, "source": [ "# Saving and loading" ] @@ -659,7 +661,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 10, "id": "17032ff0-7a0a-42a8-a9e7-bb7902e9a4de", "metadata": {}, "outputs": [ @@ -669,7 +671,7 @@ "'double'" ] }, - "execution_count": 15, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -680,7 +682,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 11, "id": "f2fa804a-051c-4a3b-a18e-37197a831f57", "metadata": {}, "outputs": [ @@ -719,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "id": "722d7b5c-4ad7-427d-a473-86cae6170bea", "metadata": {}, "outputs": [ @@ -730,7 +732,7 @@ " ScalarVariable(name='input2', default_value=0.2, value_range=(0.0, 1.0), is_constant=False, unit=None)]" ] }, - "execution_count": 17, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -741,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "id": "4cef4328-cd1b-41bc-bd92-e3157ca6084f", "metadata": {}, "outputs": [ @@ -751,7 +753,7 @@ "[ScalarVariable(name='output', default_value=None, value_range=(-inf, inf), is_constant=False, unit=None)]" ] }, - "execution_count": 18, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -762,7 +764,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 14, "id": "73af0569-a614-47ee-be5f-910182307f0e", "metadata": {}, "outputs": [], @@ -775,7 +777,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "id": "873d2095-4c41-4389-90de-d8ee6bdcc417", "metadata": {}, "outputs": [ @@ -826,7 +828,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 17, "id": "307bef74-6d02-44a3-9579-bdcda5b3a011", "metadata": {}, "outputs": [ @@ -834,23 +836,23 @@ "data": { "text/plain": [ "tensor([[6.1250],\n", - " [6.1050]], dtype=torch.float64)" + " [6.1050]])" ] }, - "execution_count": 21, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "input_data = torch.tensor([[0.5, 0.5], [0.1, 0.5]], dtype=torch.double)\n", + "input_data = torch.tensor([[0.5, 0.5], [0.1, 0.5]], dtype=torch.float)\n", "# PyTorch model\n", "model(input_data)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 18, "id": "22b64d76-4c14-47ab-8aa6-5881f56dad54", "metadata": {}, "outputs": [ @@ -860,7 +862,7 @@ "tensor([6.1250, 6.1050])" ] }, - "execution_count": 22, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -872,7 +874,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 19, "id": "f8cc9ddf-8851-431c-b3a2-1a70bd0a447b", "metadata": {}, "outputs": [ @@ -882,7 +884,7 @@ "{'output': tensor([6.1250, 6.1050])}" ] }, - "execution_count": 23, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -902,17 +904,15 @@ "\n", "The range for each variable can be optionally defined when instantiating the ScalarVariable class (or in the YAML). If it is not defined, it will be set as (-inf, inf).\n", "\n", - "The LUMEBaseModel class has `input_validation_config` and `output_validation_config` attributes that the user can define. For example: \n", + "The LUMEBaseModel class has `input_validation_config` and `output_validation_config` attributes that the user can set to specify whether to check if the value is within the variable's and with what strictness. For example: \n", "```\n", "input_validation_config = {\n", - " \"input1\" : {\n", - " \"value_range\": True,\n", - " \"strict\": False\n", - " }\n", + " \"input1\" : \"error\",\n", + " \"input2: \"none\"\n", "}\n", "```\n", "\n", - "If `value_range` is set to `True`, the value will be checked at each iteration to make sure it's within the provided range. If `strict` is set to `True`, an error will be raised if it's out of range. If it's `False`, then a warning is printed to the console.\n", + "If a variable's config is set to `\"warn\"` or `\"error\"`, the value will be checked at each iteration to make sure it's within the provided range. If it's set to `\"error\"`, an error will be raised if it's out of range, and if it's set to `\"warn\"`, a warning is printed to the console. A value of `\"none\"` will skip range checking for that variable.\n", "\n", "Note that while output variables do not typically require a range, if one is provided along with a `output_validation_config` that defines strict range checking, an error will \n", "be raised when the output is outside of the provided range." @@ -920,7 +920,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 20, "id": "b49c5bbb-e6c9-4721-8444-193cbd00fd9f", "metadata": {}, "outputs": [ @@ -931,7 +931,7 @@ " ScalarVariable(name='input2', default_value=0.2, value_range=(0.0, 1.0), is_constant=False, unit=None)]" ] }, - "execution_count": 24, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -942,18 +942,17 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 21, "id": "91d6fc04-e9e7-4dc0-842d-421b9d6c49dc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'input1': {'value_range': True, 'strict': False},\n", - " 'input2': {'value_range': True, 'strict': False}}" + "{'input1': 'warn', 'input2': 'warn'}" ] }, - "execution_count": 25, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -965,7 +964,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 22, "id": "bb957be2-f4a8-46ba-b2ab-3bded7e12a0d", "metadata": {}, "outputs": [ @@ -982,7 +981,7 @@ "{'output': tensor([6.1600, 6.1050])}" ] }, - "execution_count": 26, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -995,7 +994,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 24, "id": "27058d35-ff00-4887-9b4d-98a65afa0460", "metadata": {}, "outputs": [ @@ -1006,21 +1005,21 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[27], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# strict range checking, raises ValueError\u001b[39;00m\n\u001b[1;32m 8\u001b[0m input_dict \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput1\u001b[39m\u001b[38;5;124m'\u001b[39m: torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.2\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput2\u001b[39m\u001b[38;5;124m'\u001b[39m: torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0.5\u001b[39m, \u001b[38;5;241m.5\u001b[39m])}\n\u001b[0;32m----> 9\u001b[0m \u001b[43mexample_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dict\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[24], line 9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# strict range checking, raises ValueError\u001b[39;00m\n\u001b[1;32m 8\u001b[0m input_dict \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput1\u001b[39m\u001b[38;5;124m'\u001b[39m: torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.2\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput2\u001b[39m\u001b[38;5;124m'\u001b[39m: torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0.5\u001b[39m, \u001b[38;5;241m.5\u001b[39m])}\n\u001b[0;32m----> 9\u001b[0m \u001b[43mexample_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dict\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/base.py:308\u001b[0m, in \u001b[0;36mLUMEBaseModel.evaluate\u001b[0;34m(self, input_dict)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_dict: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, Any]:\n\u001b[1;32m 307\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Main evaluation function, child classes must implement the _evaluate method.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 308\u001b[0m validated_input_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_validation\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_dict\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 309\u001b[0m output_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluate(validated_input_dict)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_validation(output_dict)\n", "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/models/torch_model.py:166\u001b[0m, in \u001b[0;36mTorchModel.input_validation\u001b[0;34m(self, input_dict)\u001b[0m\n\u001b[1;32m 164\u001b[0m ele \u001b[38;5;241m=\u001b[39m InputDictModel(input_dict\u001b[38;5;241m=\u001b[39mele)\u001b[38;5;241m.\u001b[39minput_dict\n\u001b[1;32m 165\u001b[0m \u001b[38;5;66;03m# validate each value based on its var class and config\u001b[39;00m\n\u001b[0;32m--> 166\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_validation\u001b[49m\u001b[43m(\u001b[49m\u001b[43mele\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;66;03m# return the validated input dict for consistency w/ casting ints to floats\u001b[39;00m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m([\u001b[38;5;28misinstance\u001b[39m(value, torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;28;01mfor\u001b[39;00m value \u001b[38;5;129;01min\u001b[39;00m validated_input\u001b[38;5;241m.\u001b[39mvalues()]):\n", "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/base.py:321\u001b[0m, in \u001b[0;36mLUMEBaseModel.input_validation\u001b[0;34m(self, input_dict)\u001b[0m\n\u001b[1;32m 319\u001b[0m _config \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_validation_config \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_validation_config\u001b[38;5;241m.\u001b[39mget(name)\n\u001b[1;32m 320\u001b[0m var \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_variables[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_names\u001b[38;5;241m.\u001b[39mindex(name)]\n\u001b[0;32m--> 321\u001b[0m \u001b[43mvar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidate_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_config\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m input_dict\n", - "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/variables.py:88\u001b[0m, in \u001b[0;36mScalarVariable.validate_value\u001b[0;34m(self, value, config)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;66;03m# optional validation\u001b[39;00m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalue_range\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_value_is_within_range\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_config\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/variables.py:104\u001b[0m, in \u001b[0;36mScalarVariable._validate_value_is_within_range\u001b[0;34m(self, value, config)\u001b[0m\n\u001b[1;32m 102\u001b[0m error_message \u001b[38;5;241m=\u001b[39m error_message[:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m ([\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m,\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m]).\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalue_range)\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstrict\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[0;32m--> 104\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(error_message)\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWarning: \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m error_message)\n", + "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/variables.py:96\u001b[0m, in \u001b[0;36mScalarVariable.validate_value\u001b[0;34m(self, value, config)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[38;5;66;03m# optional validation\u001b[39;00m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m config \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m---> 96\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_value_is_within_range\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_config\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/SLAC/AD/ad-lume/lume-model/lume_model/variables.py:114\u001b[0m, in \u001b[0;36mScalarVariable._validate_value_is_within_range\u001b[0;34m(self, value, config)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWarning: \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m error_message)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(error_message)\n", "\u001b[0;31mValueError\u001b[0m: Value (1.2000000476837158) of 'input1' is out of valid range ([0.0,1.0])." ] } ], "source": [ - "# change input1's strict value to True\n", + "# change input1's config to error\n", "example_model.input_validation_config = {\n", - " \"input1\": {'value_range': True, 'strict': True},\n", - " \"input2\": {'value_range': True, 'strict': False}\n", + " \"input1\": \"error\",\n", + " \"input2\": \"warn\"\n", "}\n", "\n", "# strict range checking, raises ValueError\n", @@ -1028,6 +1027,35 @@ "example_model.evaluate(input_dict)" ] }, + { + "cell_type": "code", + "execution_count": 25, + "id": "7656c461-f66c-4d72-a8f4-8c2585dd7c0d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output': tensor([6.1600, 6.1050])}" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# change input1's config value to none\n", + "example_model.input_validation_config = {\n", + " \"input1\": \"none\",\n", + " \"input2\": \"warn\"\n", + "}\n", + "\n", + "# nothing is printed/raised\n", + "input_dict = {'input1': torch.tensor([1.2, 0.1]), 'input2': torch.tensor([0.5, .5])}\n", + "example_model.evaluate(input_dict)" + ] + }, { "cell_type": "markdown", "id": "6e941ff9-0990-4532-b0e8-93d57105aef9", @@ -1244,14 +1272,6 @@ "input_dict = {'input1': torch.tensor([0.5]), 'input2': torch.tensor([0.5])}\n", "example_model.evaluate(input_dict)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91e47420-0f96-483b-be6f-10dc3718a4b6", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/lume_model/base.py b/lume_model/base.py index d357ab7..37ff43a 100644 --- a/lume_model/base.py +++ b/lume_model/base.py @@ -10,14 +10,14 @@ import numpy as np from pydantic import BaseModel, ConfigDict, field_validator -from lume_model.variables import ScalarVariable, get_variable +from lume_model.variables import ScalarVariable, get_variable, ConfigEnum from lume_model.utils import ( try_import_module, verify_unique_variable_names, serialize_variables, deserialize_variables, variables_from_dict, - replace_relative_paths, + replace_relative_paths ) logger = logging.getLogger(__name__) @@ -225,14 +225,14 @@ class LUMEBaseModel(BaseModel, ABC): input_variables: List defining the input variables and their order. output_variables: List defining the output variables and their order. input_validation_config: Determines the behavior during input validation by specifying the validation - config for each input variable: {var_name: var_config}. + config for each input variable: {var_name: value}. Value can be "warn", "error", or None. output_validation_config: Determines the behavior during output validation by specifying the validation - config for each output variable: {var_name: var_config}. + config for each output variable: {var_name: value}. Value can be "warn", "error", or None. """ input_variables: list[ScalarVariable] output_variables: list[ScalarVariable] - input_validation_config: Optional[dict[str, dict[str, bool]]] = None - output_validation_config: Optional[dict[str, dict[str, bool]]] = None + input_validation_config: Optional[dict[str, ConfigEnum]] = None + output_validation_config: Optional[dict[str, ConfigEnum]] = None model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) @@ -294,12 +294,12 @@ def output_names(self) -> list[str]: return [var.name for var in self.output_variables] @property - def default_input_validation_config(self) -> dict[str, dict[str, bool]]: + def default_input_validation_config(self) -> dict[str, ConfigEnum]: """Determines default behavior during input validation (if input_validation_config is None).""" return {var.name: var.default_validation_config for var in self.input_variables} @property - def default_output_validation_config(self) -> dict[str, dict[str, bool]]: + def default_output_validation_config(self) -> dict[str, ConfigEnum]: """Determines default behavior during output validation (if output_validation_config is None).""" return {var.name: var.default_validation_config for var in self.output_variables} diff --git a/lume_model/variables.py b/lume_model/variables.py index cc2d559..2d98d4a 100644 --- a/lume_model/variables.py +++ b/lume_model/variables.py @@ -7,11 +7,19 @@ """ from abc import ABC, abstractmethod from typing import Any, Optional, Type +from enum import Enum import numpy as np from pydantic import BaseModel, field_validator, model_validator, ConfigDict +class ConfigEnum(str, Enum): + """Enum for configuration options during validation.""" + NULL = "none" + WARN = "warn" + ERROR = "error" + + class Variable(BaseModel, ABC): """Abstract variable base class. @@ -22,9 +30,9 @@ class Variable(BaseModel, ABC): @property @abstractmethod - def default_validation_config(self) -> dict[str, bool]: + def default_validation_config(self) -> ConfigEnum: """Determines default behavior during validation.""" - return {} + return None @abstractmethod def validate_value(self, value: Any, config: dict[str, bool] = None): @@ -73,10 +81,10 @@ def validate_default_value(self): return self @property - def default_validation_config(self) -> dict[str, bool]: - return {"value_range": True, "strict": False} + def default_validation_config(self) -> ConfigEnum: + return "warn" - def validate_value(self, value: float, config: dict[str, bool] = None): + def validate_value(self, value: float, config: ConfigEnum = None): _config = self.default_validation_config if config is None else config # mandatory validation self._validate_value_type(value) @@ -84,7 +92,7 @@ def validate_value(self, value: float, config: dict[str, bool] = None): if self.is_constant: self._validate_constant_value(value) # optional validation - if _config["value_range"]: + if config != "none": self._validate_value_is_within_range(value, config=_config) @staticmethod @@ -95,15 +103,15 @@ def _validate_value_type(value: float): f"but received {type(value)}." ) - def _validate_value_is_within_range(self, value: float, config: dict[str, bool] = None): + def _validate_value_is_within_range(self, value: float, config: ConfigEnum = None): if not self._value_is_within_range(value): error_message = "Value ({}) of '{}' is out of valid range.".format(value, self.name) if self.value_range is not None: error_message = error_message[:-1] + " ([{},{}]).".format(*self.value_range) - if config["strict"]: - raise ValueError(error_message) - else: + if config == "warn": print("Warning: " + error_message) + else: + raise ValueError(error_message) def _value_is_within_range(self, value) -> bool: self.value_range = self.value_range or (-np.inf, np.inf)