From 68924b4c029f9e3c828d1133135b9f6f4477e766 Mon Sep 17 00:00:00 2001 From: Tobias Boltz Date: Wed, 4 Dec 2024 13:10:03 -0800 Subject: [PATCH 1/4] Allow transformers to be torch.nn.Linear modules --- lume_model/models/torch_model.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/lume_model/models/torch_model.py b/lume_model/models/torch_model.py index 9dabf46..e58d7b2 100644 --- a/lume_model/models/torch_model.py +++ b/lume_model/models/torch_model.py @@ -35,8 +35,8 @@ class TorchModel(LUMEBaseModel): computation is deactivated. """ model: torch.nn.Module - input_transformers: list[ReversibleInputTransform] = [] - output_transformers: list[ReversibleInputTransform] = [] + input_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None + output_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None output_format: str = "tensor" device: Union[torch.device, str] = "cpu" fixed_model: bool = True @@ -50,6 +50,8 @@ def __init__(self, *args, **kwargs): **kwargs: See class attributes. """ super().__init__(*args, **kwargs) + self.input_transformers = [] if self.input_transformers is None else self.input_transformers + self.output_transformers = [] if self.output_transformers is None else self.output_transformers # set precision self.model.to(dtype=self.dtype) @@ -85,7 +87,7 @@ def validate_torch_model(cls, v): return v @field_validator("input_transformers", "output_transformers", mode="before") - def validate_botorch_transformers(cls, v): + def validate_transformers(cls, v): if not isinstance(v, list): raise ValueError("Transformers must be passed as list.") loaded_transformers = [] @@ -301,7 +303,10 @@ def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor: Tensor of transformed inputs to be passed to the model. 
""" for transformer in self.input_transformers: - input_tensor = transformer.transform(input_tensor) + if isinstance(transformer, ReversibleInputTransform): + input_tensor = transformer.transform(input_tensor) + else: + input_tensor = transformer(input_tensor) return input_tensor def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor: @@ -314,7 +319,10 @@ def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor: (Un-)Transformed output tensor. """ for transformer in self.output_transformers: - output_tensor = transformer.untransform(output_tensor) + if isinstance(transformer, ReversibleInputTransform): + output_tensor = transformer.untransform(output_tensor) + else: + output_tensor = transformer(output_tensor) return output_tensor def _parse_outputs(self, output_tensor: torch.Tensor) -> dict[str, torch.Tensor]: From 5d4f519143677a69502a72f76c778d5b0111c475 Mon Sep 17 00:00:00 2001 From: Tobias Boltz Date: Wed, 4 Dec 2024 14:01:42 -0800 Subject: [PATCH 2/4] Add example for conversion between torch and botorch transformers --- examples/transformer_conversion.ipynb | 146 ++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 examples/transformer_conversion.ipynb diff --git a/examples/transformer_conversion.ipynb b/examples/transformer_conversion.ipynb new file mode 100644 index 0000000..2314a35 --- /dev/null +++ b/examples/transformer_conversion.ipynb @@ -0,0 +1,146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "235c92cd-cc05-42b8-a516-1185eeac5f0c", + "metadata": {}, + "source": [ + "# Transformer Conversion" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "56725817-2b21-4bea-98b0-151dea959f77", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import torch\n", + "from botorch.models.transforms.input import AffineInputTransform\n", + "\n", + "sys.path.append(\"../\")\n", + "from lume_model.models import TorchModel, TorchModule" + ] + }, + { + "cell_type": 
"code", + "execution_count": 2, + "id": "9feaf8a2-f533-4787-a588-22aba0844e53", + "metadata": {}, + "outputs": [], + "source": [ + "# load exemplary model\n", + "torch_model = TorchModel(\"../tests/test_files/california_regression/torch_model.yml\")\n", + "torch_module = TorchModule(model=torch_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9ab4f3bf-cfb6-43f8-beaa-3847d7caf1bf", + "metadata": {}, + "outputs": [], + "source": [ + "# conversion\n", + "def convert_torch_transformer(t: torch.nn.Linear) -> AffineInputTransform:\n", + " \"\"\"Creates an AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.\n", + "\n", + " Args:\n", + " t: The torch transformer to convert.\n", + "\n", + " Returns:\n", + " AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.\n", + " \"\"\"\n", + " m = AffineInputTransform(\n", + " d=t.bias.size(-1),\n", + " coefficient=1 / t.weight.diagonal(),\n", + " offset=-t.bias / t.weight.diagonal(),\n", + " ).to(t.bias.dtype)\n", + " m.offset.requires_grad = t.bias.requires_grad\n", + " m.coefficient.requires_grad = t.weight.requires_grad\n", + " if not t.training:\n", + " m.eval()\n", + " return m\n", + "\n", + "\n", + "def convert_botorch_transformer(t: AffineInputTransform) -> torch.nn.Linear:\n", + " \"\"\"Creates a torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.\n", + "\n", + " Args:\n", + " t: The botorch transformer to convert.\n", + "\n", + " Returns:\n", + " torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.\n", + " \"\"\"\n", + " d = t.offset.size(-1)\n", + " m = torch.nn.Linear(in_features=d, out_features=d).to(t.offset.dtype)\n", + " m.bias = torch.nn.Parameter(-t.offset / t.coefficient)\n", + " weight_matrix = torch.zeros((d, d))\n", + " weight_matrix = weight_matrix.fill_diagonal_(1.0) / t.coefficient\n", + " m.weight = 
torch.nn.Parameter(weight_matrix)\n", + " m.bias.requires_grad = t.offset.requires_grad\n", + " m.weight.requires_grad = t.coefficient.requires_grad\n", + " if not t.training:\n", + " m.eval()\n", + " return m" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ff3bfd02-dbc1-4236-9ff6-77c4f8a7dcb2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n", + "True\n" + ] + } + ], + "source": [ + "# test on exemplary input\n", + "input_dict = torch_model.random_input(n_samples=1)\n", + "x = torch.tensor([input_dict[k] for k in torch_module.input_order]).unsqueeze(0)\n", + "botorch_transformer = torch_model.input_transformers[0].to(x.dtype)\n", + "torch_transformer = convert_botorch_transformer(botorch_transformer)\n", + "converted_botorch_transformer = convert_torch_transformer(torch_transformer)\n", + "\n", + "print(torch.all(torch.isclose(botorch_transformer(x), torch_transformer(x), atol=1e-6)).item())\n", + "print(torch.all(torch.isclose(torch_transformer(x), converted_botorch_transformer(x), atol=1e-6)).item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b96c564-7037-4a2c-8f46-2f84fc75c2e2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:lume-model-dev]", + "language": "python", + "name": "conda-env-lume-model-dev-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.20" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 22441d003814c77227c76121a329512f60df35c5 Mon Sep 17 00:00:00 2001 From: Tobias Boltz Date: Wed, 4 Dec 2024 14:41:12 -0800 Subject: [PATCH 3/4] Mirror untransform behavior for linear torch transformers --- lume_model/models/torch_model.py | 21 +++++++++++++++++---- 
1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/lume_model/models/torch_model.py b/lume_model/models/torch_model.py index e58d7b2..bd89b82 100644 --- a/lume_model/models/torch_model.py +++ b/lume_model/models/torch_model.py @@ -213,12 +213,24 @@ def update_input_variables_to_transformer(self, transformer_loc: int) -> list[In x = x_old[key] # compute previous limits at transformer location for i in range(transformer_loc): - x = self.input_transformers[i].transform(x) + if isinstance(self.input_transformers[i], ReversibleInputTransform): + x = self.input_transformers[i].transform(x) + else: + x = self.input_transformers[i](x) # untransform of transformer to adjust for - x = self.input_transformers[transformer_loc].untransform(x) + if isinstance(self.input_transformers[transformer_loc], ReversibleInputTransform): + x = self.input_transformers[transformer_loc].untransform(x) + else: + w = self.input_transformers[transformer_loc].weight + b = self.input_transformers[transformer_loc].bias + x = torch.matmul((x - b), torch.linalg.inv(w.T)) # backtrack through transformers for transformer in self.input_transformers[:transformer_loc][::-1]: - x = transformer.untransform(x) + if isinstance(transformer, ReversibleInputTransform): + x = transformer.untransform(x) + else: + w, b = transformer.weight, transformer.bias + x = torch.matmul((x - b), torch.linalg.inv(w.T)) x_new[key] = x updated_variables = deepcopy(self.input_variables) for i, var in enumerate(updated_variables): @@ -322,7 +334,8 @@ def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor: if isinstance(transformer, ReversibleInputTransform): output_tensor = transformer.untransform(output_tensor) else: - output_tensor = transformer(output_tensor) + w, b = transformer.weight, transformer.bias + output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T)) return output_tensor def _parse_outputs(self, output_tensor: torch.Tensor) -> dict[str, 
torch.Tensor]: From 5e4aaefc27250f0d2b3d4f4d3dafeb231f768619 Mon Sep 17 00:00:00 2001 From: Tobias Boltz Date: Wed, 4 Dec 2024 14:42:33 -0800 Subject: [PATCH 4/4] Extend transformer conversion example to full model call --- examples/transformer_conversion.ipynb | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/transformer_conversion.ipynb b/examples/transformer_conversion.ipynb index 2314a35..88c6f37 100644 --- a/examples/transformer_conversion.ipynb +++ b/examples/transformer_conversion.ipynb @@ -96,7 +96,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "True\n", "True\n" ] } @@ -105,18 +104,29 @@ "# test on exemplary input\n", "input_dict = torch_model.random_input(n_samples=1)\n", "x = torch.tensor([input_dict[k] for k in torch_module.input_order]).unsqueeze(0)\n", - "botorch_transformer = torch_model.input_transformers[0].to(x.dtype)\n", - "torch_transformer = convert_botorch_transformer(botorch_transformer)\n", - "converted_botorch_transformer = convert_torch_transformer(torch_transformer)\n", "\n", - "print(torch.all(torch.isclose(botorch_transformer(x), torch_transformer(x), atol=1e-6)).item())\n", - "print(torch.all(torch.isclose(torch_transformer(x), converted_botorch_transformer(x), atol=1e-6)).item())" + "torch_input_transformers = [\n", + " convert_botorch_transformer(t) for t in torch_model.input_transformers\n", + "]\n", + "torch_output_transformers = [\n", + " convert_botorch_transformer(t) for t in torch_model.output_transformers\n", + "]\n", + "new_torch_model = TorchModel(\n", + " input_variables=torch_model.input_variables,\n", + " output_variables=torch_model.output_variables,\n", + " model=torch_model.model,\n", + " input_transformers=torch_input_transformers,\n", + " output_transformers=torch_output_transformers,\n", + ")\n", + "new_torch_module = TorchModule(model=new_torch_model)\n", + "\n", + "print(torch.isclose(torch_module(x), new_torch_module(x)).item())" ] }, { "cell_type": 
"code", "execution_count": null, - "id": "5b96c564-7037-4a2c-8f46-2f84fc75c2e2", + "id": "a45608c5-dae7-48f7-b602-b2bcd6e9d453", "metadata": {}, "outputs": [], "source": []