Skip to content

Commit

Permalink
patched edge-case in LinearInterpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
packquickly committed Sep 1, 2023
1 parent 737bf39 commit e050614
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions diffrax/global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,11 @@ def _index(_ys):
prev_t = self.ts[index]
next_t = self.ts[index + 1]
diff_t = next_t - prev_t

return (
prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
).ω
return jnp.where(
diff_t >= jnp.finfo(diff_t.dtype).eps,
(prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)).ω,
prev_ys
)

@eqx.filter_jit
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
Expand Down

0 comments on commit e050614

Please sign in to comment.