Skip to content

Commit

Permalink
Fix error when using gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrea Ponti committed Dec 21, 2023
1 parent 545e388 commit e258563
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/augmented_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def forward(self, X: Tensor) -> Tensor:
A `(b1 x ... bk)`-dim tensor of Augmented Upper Confidence Bound values at
the given design points `X`.
"""
alpha = torch.zeros(X.shape[0], dtype=X.dtype)
alpha = torch.zeros(X.shape[0], dtype=X.dtype, device=X.device)
agp_mean, agp_sigma = self._mean_and_sigma(X[..., :-1])
cb = (self.best_f if self.maximize else -self.best_f) + (
(agp_mean if self.maximize else -agp_mean) + self.beta.sqrt() * agp_sigma
Expand Down
5 changes: 2 additions & 3 deletions test/models/test_gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@


def _get_random_data_with_source(batch_shape, n, d, n_source, q=1, **tkwargs):
dtype = tkwargs.get("dtype", torch.float32)
rep_shape = batch_shape + torch.Size([1, 1])
bounds = torch.stack([torch.zeros(d), torch.ones(d)])
bounds[-1, -1] = n_source - 1
train_x = (
get_random_x_for_agp(n=n, bounds=bounds, q=q).repeat(rep_shape).type(dtype)
get_random_x_for_agp(n=n, bounds=bounds, q=q).repeat(rep_shape).to(**tkwargs)
)
train_y = torch.sin(train_x[..., :1] * (2 * math.pi)).type(dtype)
train_y = torch.sin(train_x[..., :1] * (2 * math.pi)).to(**tkwargs)
train_y = train_y + 0.2 * torch.randn(n, 1, **tkwargs).repeat(rep_shape)
return train_x, train_y

Expand Down

0 comments on commit e258563

Please sign in to comment.