Skip to content

Commit

Permalink
tests for enable_grad=False
Browse files Browse the repository at this point in the history
  • Loading branch information
mhavasi committed Dec 17, 2024
1 parent 68f0f96 commit ee5be84
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/solver/test_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ def test_gradients(self):
self.constant_velocity_model.a.grad, 2.0, delta=1e-4
)

def test_no_gradients(self):
x_init = torch.tensor([1.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

method = "euler"
self.constant_velocity_model.zero_grad()
result = self.constant_velocity_solver.sample(
x_init=x_init,
step_size=step_size,
time_grid=time_grid,
method=method,
)
loss = result.sum()

with self.assertRaises(RuntimeError):
loss.backward()

def test_compute_likelihood(self):
x_1 = torch.tensor([[0.0, 0.0]])
step_size = 0.1
Expand All @@ -114,6 +132,9 @@ def dummy_log_p(x: Tensor) -> Tensor:
self.assertIsInstance(log_likelihood, Tensor)
self.assertEqual(x_1.shape[0], log_likelihood.shape[0])

with self.assertRaises(RuntimeError):
log_likelihood.backward()

def test_compute_likelihood_exact_divergence(self):
x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
step_size = 0.1
Expand Down

0 comments on commit ee5be84

Please sign in to comment.