From 59961a2837897cefcd9ad728239a6686c75adc67 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Mon, 27 Jan 2025 07:04:55 -0800 Subject: [PATCH] FullyBayesianSingleTaskGP.train should not return None Summary: This is for consistency with the signature of `Module.train`. Differential Revision: D68710923 --- botorch/models/fully_bayesian.py | 12 ++++++++++-- test/models/test_fully_bayesian.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py index 52361a9e41..91a09e2875 100644 --- a/botorch/models/fully_bayesian.py +++ b/botorch/models/fully_bayesian.py @@ -33,7 +33,7 @@ import math from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import Any +from typing import Any, TypeVar import pyro import torch @@ -67,6 +67,11 @@ from pyro.ops.integrator import register_exception_handler from torch import Tensor +# Can replace with Self type once 3.11 is the minimum version +TFullyBayesianSingleTaskGP = TypeVar( + "TFullyBayesianSingleTaskGP", bound="FullyBayesianSingleTaskGP" +) + _sqrt5 = math.sqrt(5) @@ -623,13 +628,16 @@ def _aug_batch_shape(self) -> torch.Size: aug_batch_shape += torch.Size([self.num_outputs]) return aug_batch_shape - def train(self, mode: bool = True) -> None: + def train( + self: TFullyBayesianSingleTaskGP, mode: bool = True + ) -> TFullyBayesianSingleTaskGP: r"""Puts the model in `train` mode.""" super().train(mode=mode) if mode: self.mean_module = None self.covar_module = None self.likelihood = None + return self def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None: r"""Load the MCMC hyperparameter samples into the model. diff --git a/test/models/test_fully_bayesian.py b/test/models/test_fully_bayesian.py index 43689428de..caffbd5a3f 100644 --- a/test/models/test_fully_bayesian.py +++ b/test/models/test_fully_bayesian.py @@ -434,7 +434,8 @@ def test_fit_model(self): # Make sure the model shapes are set correctly self.assertEqual(model.pyro_model.train_X.shape, torch.Size([n, d])) self.assertAllClose(model.pyro_model.train_X, train_X) - model.train() # Put the model in train mode + trained_model = model.train() # Put the model in train mode + self.assertIs(trained_model, model) self.assertAllClose(train_X, model.pyro_model.train_X) self.assertIsNone(model.mean_module) self.assertIsNone(model.covar_module)