Skip to content

Commit

Permalink
1/25 Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eibarolle authored Jan 25, 2025
1 parent 280776d commit a811429
Showing 1 changed file with 10 additions and 38 deletions.
48 changes: 10 additions & 38 deletions test_community/acquisition/test_latent_information_gain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import unittest
import torch
from torch import nn
from torch.distributions import Normal
from botorch_community.acquisition.latent_information_gain import LatentInformationGain
from botorch_community.models.np_regression import NeuralProcessModel

Expand All @@ -11,6 +9,8 @@ def setUp(self):
self.y_dim = 1
self.r_dim = 8
self.z_dim = 3
self.context_x = torch.rand(10, self.x_dim)
self.context_y = torch.rand(10, self.y_dim)
self.r_hidden_dims = [16, 16]
self.z_hidden_dims = [32, 32]
self.decoder_hidden_dims = [16, 16]
Expand All @@ -25,66 +25,38 @@ def setUp(self):
z_dim=self.z_dim,
)
self.acquisition_function = LatentInformationGain(
context_x=self.context_x,
context_y=self.context_y,
model=self.model,
num_samples=self.num_samples,
)
self.context_x = torch.rand(10, self.x_dim)
self.context_y = torch.rand(10, self.y_dim)
self.candidate_x = torch.rand(5, self.x_dim)

def test_initialization(self):
self.assertEqual(self.acquisition_function.num_samples, self.num_samples)
self.assertEqual(self.acquisition_function.model, self.model)

def test_acquisition_shape(self):
lig_score = self.acquisition_function.acquisition(
candidate_x=self.candidate_x,
context_x=self.context_x,
context_y=self.context_y,
lig_score = self.acquisition_function.forward(
candidate_x=self.candidate_x
)
self.assertTrue(torch.is_tensor(lig_score))
self.assertEqual(lig_score.shape, ())

def test_acquisition_kl(self):
lig_score = self.acquisition_function.acquisition(
candidate_x=self.candidate_x,
context_x=self.context_x,
context_y=self.context_y,
lig_score = self.acquisition_function.forward(
candidate_x=self.candidate_x
)
self.assertGreaterEqual(lig_score.item(), 0)

def test_acquisition_samples(self):
lig_1 = self.acquisition_function.acquisition(
candidate_x=self.candidate_x,
context_x=self.context_x,
context_y=self.context_y,
)
lig_1 = self.acquisition_function.forward(candidate_x=self.candidate_x)

self.acquisition_function.num_samples = 20
lig_2 = self.acquisition_function.acquisition(
candidate_x=self.candidate_x,
context_x=self.context_x,
context_y=self.context_y,
)
lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x)
self.assertTrue(lig_2.item() < lig_1.item())
self.assertTrue(abs(lig_2.item() - lig_1.item()) < 0.2)

def test_acquisition_invalid_inputs(self):
invalid_context_x = torch.rand(10, self.x_dim + 5)
with self.assertRaises(Exception):
self.acquisition_function.acquisition(
candidate_x=self.candidate_x,
context_x=invalid_context_x,
context_y=self.context_y,
)

invalid_candidate_x = torch.rand(5, self.x_dim + 5)
with self.assertRaises(Exception):
self.acquisition_function.acquisition(
candidate_x=invalid_candidate_x,
context_x=self.context_x,
context_y=self.context_y,
)


if __name__ == "__main__":
Expand Down

0 comments on commit a811429

Please sign in to comment.