Commit

Merge branch 'main' of https://github.com/slaclab/lume-model into validation
pluflou committed Dec 18, 2024
2 parents e6c7b6a + fd6896f commit 45dfde8
Showing 2 changed files with 185 additions and 8 deletions.
156 changes: 156 additions & 0 deletions examples/transformer_conversion.ipynb
@@ -0,0 +1,156 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "235c92cd-cc05-42b8-a516-1185eeac5f0c",
"metadata": {},
"source": [
"# Transformer Conversion"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "56725817-2b21-4bea-98b0-151dea959f77",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import torch\n",
"from botorch.models.transforms.input import AffineInputTransform\n",
"\n",
"sys.path.append(\"../\")\n",
"from lume_model.models import TorchModel, TorchModule"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9feaf8a2-f533-4787-a588-22aba0844e53",
"metadata": {},
"outputs": [],
"source": [
"# load exemplary model\n",
"torch_model = TorchModel(\"../tests/test_files/california_regression/torch_model.yml\")\n",
"torch_module = TorchModule(model=torch_model)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9ab4f3bf-cfb6-43f8-beaa-3847d7caf1bf",
"metadata": {},
"outputs": [],
"source": [
"# conversion\n",
"def convert_torch_transformer(t: torch.nn.Linear) -> AffineInputTransform:\n",
" \"\"\"Creates an AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.\n",
"\n",
" Args:\n",
" t: The torch transformer to convert.\n",
"\n",
" Returns:\n",
" AffineInputTransform module which mirrors the behavior of the given torch.nn.Linear module.\n",
" \"\"\"\n",
" m = AffineInputTransform(\n",
" d=t.bias.size(-1),\n",
" coefficient=1 / t.weight.diagonal(),\n",
" offset=-t.bias / t.weight.diagonal(),\n",
" ).to(t.bias.dtype)\n",
" m.offset.requires_grad = t.bias.requires_grad\n",
" m.coefficient.requires_grad = t.weight.requires_grad\n",
" if not t.training:\n",
" m.eval()\n",
" return m\n",
"\n",
"\n",
"def convert_botorch_transformer(t: AffineInputTransform) -> torch.nn.Linear:\n",
" \"\"\"Creates a torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.\n",
"\n",
" Args:\n",
" t: The botorch transformer to convert.\n",
"\n",
" Returns:\n",
" torch.nn.Linear module which mirrors the behavior of the given AffineInputTransform module.\n",
" \"\"\"\n",
" d = t.offset.size(-1)\n",
" m = torch.nn.Linear(in_features=d, out_features=d).to(t.offset.dtype)\n",
" m.bias = torch.nn.Parameter(-t.offset / t.coefficient)\n",
" weight_matrix = torch.zeros((d, d))\n",
" weight_matrix = weight_matrix.fill_diagonal_(1.0) / t.coefficient\n",
" m.weight = torch.nn.Parameter(weight_matrix)\n",
" m.bias.requires_grad = t.offset.requires_grad\n",
" m.weight.requires_grad = t.coefficient.requires_grad\n",
" if not t.training:\n",
" m.eval()\n",
" return m"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ff3bfd02-dbc1-4236-9ff6-77c4f8a7dcb2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"# test on exemplary input\n",
"input_dict = torch_model.random_input(n_samples=1)\n",
"x = torch.tensor([input_dict[k] for k in torch_module.input_order]).unsqueeze(0)\n",
"\n",
"torch_input_transformers = [\n",
" convert_botorch_transformer(t) for t in torch_model.input_transformers\n",
"]\n",
"torch_output_transformers = [\n",
" convert_botorch_transformer(t) for t in torch_model.output_transformers\n",
"]\n",
"new_torch_model = TorchModel(\n",
" input_variables=torch_model.input_variables,\n",
" output_variables=torch_model.output_variables,\n",
" model=torch_model.model,\n",
" input_transformers=torch_input_transformers,\n",
" output_transformers=torch_output_transformers,\n",
")\n",
"new_torch_module = TorchModule(model=new_torch_model)\n",
"\n",
"print(torch.isclose(torch_module(x), new_torch_module(x)).item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a45608c5-dae7-48f7-b602-b2bcd6e9d453",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:lume-model-dev]",
"language": "python",
"name": "conda-env-lume-model-dev-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
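
A quick aside on why the conversions above are exact: a torch.nn.Linear with diagonal weight W = diag(w) and bias b computes y = w * x + b, while an AffineInputTransform with coefficient c and offset o computes y = (x - o) / c, so the mappings c = 1 / w, o = -b / w (and their reverse) produce identical outputs. Below is a minimal round-trip sanity check, a sketch that assumes the notebook's two converter functions are in scope:

import torch
from botorch.models.transforms.input import AffineInputTransform

# hypothetical round-trip: AffineInputTransform -> Linear -> AffineInputTransform
t = AffineInputTransform(d=3, coefficient=torch.rand(3) + 0.5, offset=torch.randn(3))
linear = convert_botorch_transformer(t)        # mirrors t as a Linear layer
roundtrip = convert_torch_transformer(linear)  # recovers an equivalent transform

x = torch.randn(5, 3)
print(torch.allclose(t.transform(x), linear(x)))               # expected: True
print(torch.allclose(t.transform(x), roundtrip.transform(x)))  # expected: True
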
37 changes: 29 additions & 8 deletions lume_model/models/torch_model.py
@@ -31,8 +31,8 @@ class TorchModel(LUMEBaseModel):
         computation is deactivated.
     """
     model: torch.nn.Module
-    input_transformers: list[ReversibleInputTransform] = []
-    output_transformers: list[ReversibleInputTransform] = []
+    input_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None
+    output_transformers: list[Union[ReversibleInputTransform, torch.nn.Linear]] = None
     output_format: str = "tensor"
     device: Union[torch.device, str] = "cpu"
     fixed_model: bool = True
@@ -47,6 +47,8 @@ def __init__(self, *args, **kwargs):
             **kwargs: See class attributes.
         """
         super().__init__(*args, **kwargs)
+        self.input_transformers = [] if self.input_transformers is None else self.input_transformers
+        self.output_transformers = [] if self.output_transformers is None else self.output_transformers

         # dtype property sets precision across model and transformers
         self.dtype;
@@ -89,7 +91,7 @@ def validate_torch_model(cls, v):
         return v

     @field_validator("input_transformers", "output_transformers", mode="before")
-    def validate_botorch_transformers(cls, v):
+    def validate_transformers(cls, v):
         if not isinstance(v, list):
             raise ValueError("Transformers must be passed as list.")
         loaded_transformers = []
@@ -261,12 +263,24 @@ def update_input_variables_to_transformer(self, transformer_loc: int) -> list[ScalarVariable]:
             x = x_old[key]
             # compute previous limits at transformer location
             for i in range(transformer_loc):
-                x = self.input_transformers[i].transform(x)
+                if isinstance(self.input_transformers[i], ReversibleInputTransform):
+                    x = self.input_transformers[i].transform(x)
+                else:
+                    x = self.input_transformers[i](x)
             # untransform of transformer to adjust for
-            x = self.input_transformers[transformer_loc].untransform(x)
+            if isinstance(self.input_transformers[transformer_loc], ReversibleInputTransform):
+                x = self.input_transformers[transformer_loc].untransform(x)
+            else:
+                w = self.input_transformers[transformer_loc].weight
+                b = self.input_transformers[transformer_loc].bias
+                x = torch.matmul((x - b), torch.linalg.inv(w.T))
             # backtrack through transformers
             for transformer in self.input_transformers[:transformer_loc][::-1]:
-                x = transformer.untransform(x)
+                if isinstance(transformer, ReversibleInputTransform):
+                    x = transformer.untransform(x)
+                else:
+                    w, b = transformer.weight, transformer.bias
+                    x = torch.matmul((x - b), torch.linalg.inv(w.T))
             x_new[key] = x
         updated_variables = deepcopy(self.input_variables)
         for i, var in enumerate(updated_variables):
@@ -350,7 +364,10 @@ def _transform_inputs(self, input_tensor: torch.Tensor) -> torch.Tensor:
             Tensor of transformed inputs to be passed to the model.
         """
         for transformer in self.input_transformers:
-            input_tensor = transformer.transform(input_tensor)
+            if isinstance(transformer, ReversibleInputTransform):
+                input_tensor = transformer.transform(input_tensor)
+            else:
+                input_tensor = transformer(input_tensor)
         return input_tensor

     def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
@@ -363,7 +380,11 @@ def _transform_outputs(self, output_tensor: torch.Tensor) -> torch.Tensor:
             (Un-)Transformed output tensor.
         """
         for transformer in self.output_transformers:
-            output_tensor = transformer.untransform(output_tensor)
+            if isinstance(transformer, ReversibleInputTransform):
+                output_tensor = transformer.untransform(output_tensor)
+            else:
+                w, b = transformer.weight, transformer.bias
+                output_tensor = torch.matmul((output_tensor - b), torch.linalg.inv(w.T))
         return output_tensor

     def _parse_outputs(self, output_tensor: torch.Tensor) -> dict[str, torch.Tensor]:
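
The Linear branch added to _transform_outputs above inverts y = x @ W.T + b analytically as x = (y - b) @ inv(W.T), which requires the weight matrix to be square and invertible. A minimal standalone sketch of that inversion (hypothetical shapes, not from the commit):

import torch

d = 4
linear = torch.nn.Linear(d, d)  # assumed square; a random weight matrix is almost surely invertible
x = torch.randn(2, d)
y = linear(x)  # forward: y = x @ W.T + b

w, b = linear.weight, linear.bias
x_rec = torch.matmul(y - b, torch.linalg.inv(w.T))  # inverse, as in _transform_outputs
print(torch.allclose(x, x_rec, atol=1e-5))  # expected: True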
