From cf97e0e7ed3021f02f1a7592aefb44a39f9ef2d1 Mon Sep 17 00:00:00 2001 From: Kathryn Baker Date: Mon, 19 Feb 2024 14:54:36 +0000 Subject: [PATCH] [CHANGE] fix for import as module & use in ensemble --- base.py | 2 +- decoupled_linear.py | 6 ++++-- utils.py | 5 ++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/base.py b/base.py index b8a0bec..ea0aed1 100644 --- a/base.py +++ b/base.py @@ -31,7 +31,7 @@ def forward(self, x): def to(self, device: str): self.model.to(device) - super().to(device) + return super().to(device) class ParameterModule(BaseModule, ABC): diff --git a/decoupled_linear.py b/decoupled_linear.py index f0275f6..dc2f910 100644 --- a/decoupled_linear.py +++ b/decoupled_linear.py @@ -5,8 +5,10 @@ from gpytorch.priors import Prior, NormalPrior, GammaPrior from gpytorch.constraints import Interval, Positive -from base import ParameterModule - +try: + from base import ParameterModule +except (ImportError, ModuleNotFoundError): + from .base import ParameterModule class InputOffset(ParameterModule): """Adds input offset calibration to the model. diff --git a/utils.py b/utils.py index 4749aa5..13e71f8 100644 --- a/utils.py +++ b/utils.py @@ -2,7 +2,10 @@ from torch import Tensor from botorch.models.transforms.input import AffineInputTransform -from base import BaseModule +try: + from base import BaseModule +except (ImportError, ModuleNotFoundError): + from .base import BaseModule def extract_input_transformer(module: BaseModule) -> AffineInputTransform: