From a8114292880cf96fa2a52b11b22636911e4aec58 Mon Sep 17 00:00:00 2001 From: Ernest Ibarolle <63222761+eibarolle@users.noreply.github.com> Date: Sat, 25 Jan 2025 06:48:43 -0800 Subject: [PATCH] 1/25 Updates --- .../test_latent_information_gain.py | 48 ++++--------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/test_community/acquisition/test_latent_information_gain.py b/test_community/acquisition/test_latent_information_gain.py index ff135bd2fa..856ef65efd 100644 --- a/test_community/acquisition/test_latent_information_gain.py +++ b/test_community/acquisition/test_latent_information_gain.py @@ -1,7 +1,5 @@ import unittest import torch -from torch import nn -from torch.distributions import Normal from botorch_community.acquisition.latent_information_gain import LatentInformationGain from botorch_community.models.np_regression import NeuralProcessModel @@ -11,6 +9,8 @@ def setUp(self): self.y_dim = 1 self.r_dim = 8 self.z_dim = 3 + self.context_x = torch.rand(10, self.x_dim) + self.context_y = torch.rand(10, self.y_dim) self.r_hidden_dims = [16, 16] self.z_hidden_dims = [32, 32] self.decoder_hidden_dims = [16, 16] @@ -25,11 +25,11 @@ def setUp(self): z_dim=self.z_dim, ) self.acquisition_function = LatentInformationGain( + context_x=self.context_x, + context_y=self.context_y, model=self.model, num_samples=self.num_samples, ) - self.context_x = torch.rand(10, self.x_dim) - self.context_y = torch.rand(10, self.y_dim) self.candidate_x = torch.rand(5, self.x_dim) def test_initialization(self): @@ -37,54 +37,26 @@ def test_initialization(self): self.assertEqual(self.acquisition_function.model, self.model) def test_acquisition_shape(self): - lig_score = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, + lig_score = self.acquisition_function.forward( + candidate_x=self.candidate_x ) self.assertTrue(torch.is_tensor(lig_score)) self.assertEqual(lig_score.shape, ()) def test_acquisition_kl(self): - lig_score = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, + lig_score = self.acquisition_function.forward( + candidate_x=self.candidate_x ) self.assertGreaterEqual(lig_score.item(), 0) def test_acquisition_samples(self): - lig_1 = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) + lig_1 = self.acquisition_function.forward(candidate_x=self.candidate_x) self.acquisition_function.num_samples = 20 - lig_2 = self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) + lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x) self.assertTrue(lig_2.item() < lig_1.item()) self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2) - def test_acquisition_invalid_inputs(self): - invalid_context_x = torch.rand(10, self.x_dim + 5) - with self.assertRaises(Exception): - self.acquisition_function.acquisition( - candidate_x=self.candidate_x, - context_x=invalid_context_x, - context_y=self.context_y, - ) - - invalid_candidate_x = torch.rand(5, self.x_dim + 5) - with self.assertRaises(Exception): - self.acquisition_function.acquisition( - candidate_x=invalid_candidate_x, - context_x=self.context_x, - context_y=self.context_y, - ) if __name__ == "__main__":