NP Regression Model w/ LIG Acquisition #2683
base: main
Conversation
Hi @eibarolle! Thanks for the PR! I'll review it shortly.
Thanks for the PR! This is looking good. I left some comments inline.
import torch
import torch.nn as nn
import matplotlib.pyplot as plts
# %matplotlib inline
Suggested change: remove the # %matplotlib inline line.
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (RBF, Matern, RationalQuadratic,
                                              ExpSineSquared, DotProduct,
                                              ConstantKernel)
Suggested change: remove these sklearn imports.
Let's avoid the sklearn dependency since it isn't used.
ConstantKernel)
from typing import Callable, List, Optional, Tuple
from torch.nn import Module, ModuleDict, ModuleList
from sklearn import preprocessing
Suggested change: remove the from sklearn import preprocessing line.
from typing import Callable, List, Optional, Tuple
from torch.nn import Module, ModuleDict, ModuleList
from sklearn import preprocessing
from scipy.stats import multivariate_normal
Suggested change: remove the from scipy.stats import multivariate_normal line.
Let's remove the unused imports.
from scipy.stats import multivariate_normal
from gpytorch.distributions import MultivariateNormal

device = torch.device("cpu")
Suggested change: remove the hard-coded device = torch.device("cpu") line.
Let's make the code agnostic to the device used.
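A minimal sketch of one device-agnostic pattern (the class and names here are illustrative, not from the PR): create any auxiliary tensors on the device of the incoming data instead of referencing a global device.

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Illustrative module: no global device is ever referenced."""
    def __init__(self, d_in: int, d_out: int) -> None:
        super().__init__()
        self.linear = nn.Linear(d_in, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Auxiliary tensors inherit x's device and dtype, so moving the
        # module with .to("cuda") or .to("mps") is all a caller needs.
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        return self.linear(x + 0.01 * noise)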
if n == 1:
    eps = torch.autograd.Variable(logvar.data.new(self.z_dim).normal_()).to(device)
else:
    eps = torch.autograd.Variable(logvar.data.new(n, self.z_dim).normal_()).to(device)
Suggested change: replace the branch above with
shape = [n, self.z_dim]
if n == 1:
    shape = shape[1:]
eps = torch.autograd.Variable(logvar.data.new(*shape).normal_()).to(device)
This is a bit more concise.
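As an aside (not part of the suggestion above): torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4, so an equivalent modern, device-agnostic spelling of the same sampling step might look like this (the function name is hypothetical):

import torch

def sample_eps(logvar: torch.Tensor, n: int, z_dim: int) -> torch.Tensor:
    # new_empty allocates on logvar's device/dtype; normal_ fills it with
    # standard-normal noise, matching the two branches above.
    shape = (z_dim,) if n == 1 else (n, z_dim)
    return logvar.new_empty(shape).normal_()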
mu: torch.Tensor,
logvar: torch.Tensor,
n: int = 1,
min_std: float = 0.1,
This default seems high, no?
def load_state_dict(
    self,
    state_dict: dict,
    strict: bool = True
) -> None:
    """
    Initialize the fully Bayesian model before loading the state dict.

    Args:
        state_dict (dict): A dictionary containing the parameters.
        strict (bool): Case matching strictness.
    """
    super().load_state_dict(state_dict, strict=strict)
Not needed, since it just calls the parent class's method.
ind = np.arange(x.shape[0])
mask = np.random.choice(ind, size=n_context, replace=False)
x_c = torch.from_numpy(x[mask])
y_c = torch.from_numpy(y[mask])
x_t = torch.from_numpy(np.delete(x, mask, axis=0))
y_t = torch.from_numpy(np.delete(y, mask, axis=0))
Any reason to not do this in pytorch?
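For reference, a sketch of the same context/target split in pure PyTorch (assuming x and y already arrive as tensors; the function name is hypothetical):

import torch

def split_context_target(x: torch.Tensor, y: torch.Tensor, n_context: int):
    perm = torch.randperm(x.shape[0])        # random permutation of row indices
    mask, rest = perm[:n_context], perm[n_context:]
    x_c, y_c = x[mask], y[mask]              # context points
    x_t, y_t = x[rest], y[rest]              # remaining target points
    return x_c, y_c, x_t, y_t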
import torch
# reference: https://arxiv.org/abs/2106.02770

class LatentInformationGain:
Can we implement this as a subclass of Acquisition, so that we can use it more organically in botorch? Likely the context would need to be provided in LatentInformationGain.__init__.
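A rough sketch of that structure (the class name and placeholder body are illustrative; the PR's actual information-gain computation would go in forward):

from torch import Tensor
from botorch.acquisition import AcquisitionFunction

class LatentInformationGainSketch(AcquisitionFunction):
    def __init__(self, model, context_x: Tensor, context_y: Tensor,
                 num_samples: int = 10) -> None:
        super().__init__(model=model)  # AcquisitionFunction is an nn.Module
        self.register_buffer("context_x", context_x)
        self.register_buffer("context_y", context_y)
        self.num_samples = num_samples

    def forward(self, X: Tensor) -> Tensor:
        # BoTorch convention: X is N x q x d and the output is an N-dim
        # tensor, one acquisition value per q-batch. A real implementation
        # would compute the latent information gain here.
        return X.new_zeros(X.shape[0])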
I applied the suggested changes to my new code. Note that for the decoder, the other functions are designed around its current state.
@hvarfner curious if you have any thoughts on this PR re the information gain aspects
Interesting! I'll check out the paper quickly and get back to you |
self.acquisition_function.num_samples = 20
lig_2 = self.acquisition_function.forward(candidate_x=self.candidate_x)
self.assertTrue(lig_2.item() < lig_1.item())
Why would this unit test necessarily be a good check? I am not sure on the details, but it seems to me that this just improves the accuracy of the acquisition computation.
self.context_x = context_x.to(device)
self.context_y = context_y.to(device)

def forward(self, candidate_x):
It seems to me like the acquisition function computes the information gain on one batch of points, with the batch being in the first dimension. Thus, the output of this forward would be one scalar. This would run contrary to the acquisition function convention, and so it wouldn't be able to be used with optimize_acqf etc.
You would want the forward to be able to handle an N x q x D-shaped input, where you are currently only computing the q-element part (but that aspect seems correct as far as I can tell right now). This may be a bit challenging, but certainly looks doable!
I think this is a pretty cool idea that could be useful generally in latent space models, so it would be nice to have some fairly general naming for the encoding steps if this were to be used with other encoder-decoder based architectures.
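One possible way to lift a per-q-batch computation to N x q x D inputs (a sketch; lig_of_batch stands in for whatever scalar computation the current forward performs on a single q x D candidate set):

import torch
from torch import Tensor

def batched_forward(lig_of_batch, X: Tensor) -> Tensor:
    if X.dim() == 2:  # allow a bare q x D input
        X = X.unsqueeze(0)
    # Evaluate each of the N candidate sets independently and return an
    # N-dim tensor, matching the acquisition function convention.
    return torch.stack([lig_of_batch(x) for x in X.unbind(0)])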
The Latent Information Gain forward function has been updated with the correct dimensions, and the test cases have been adjusted as needed.
Please let me know when you can check over the new Latent Information Gain function. @hvarfner
Motivation
This pull request adds a Neural Process Regression Model with a Latent Information Gain acquisition function to BoTorch.
Have you read the Contributing Guidelines on pull requests?
Yes, and I've followed all the steps, including testing.
Test Plan
I wrote my own unit tests for both the model and the acquisition function, and all of them passed. The test files are in the appropriate folder. In addition, I ran pytest on my files and all tests passed.
Related
I made a repository holding the pushed files at https://github.com/eibarolle/np_regression, and it has the appropriate API documentation.