Intermediate saved values are sometimes inf #488

Open
dkweiss31 opened this issue Aug 16, 2024 · 3 comments

@dkweiss31
Contributor

Hi Patrick! I've run into an issue over in dynamiqs/dynamiqs#666 when t0=t1 and I try to save intermediate values. It seems to be independent of the stepsize_controller that I use (adaptive or constant). Here is a minimal example using constant steps.

```python
import diffrax as dx
import jax.numpy as jnp

term = dx.ODETerm(lambda t, y, _: y)
y0 = jnp.array([1.0])
ts = jnp.array([0.0, 0.0])
saveat = dx.SaveAt(subs=[dx.SubSaveAt(ts=ts), dx.SubSaveAt(t1=True)])

solution = dx.diffeqsolve(
    term,
    dx.Tsit5(),
    ts[0],
    ts[-1],
    0.1,
    y0,
    saveat=saveat,
)
print(solution.ys[0])  # [[inf] [inf]]
print(solution.ys[1])  # [[1.]]
```

@patrick-kidger
Owner

Ah! Good catch. It seems that we don't try to fill in SaveAt(ts=...) in the t0=t1 case. In this case we never enter our integration loop, so we never trigger any of the code for saving SaveAt(ts=...):

```python
interpolator = solver.interpolation_cls(
    t0=state.tprev, t1=state.tnext, **dense_info
)
save_state = state.save_state
dense_ts = state.dense_ts
dense_infos = state.dense_infos
dense_save_index = state.dense_save_index

def save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
    if subsaveat.ts is not None:
        save_state = save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
    return save_state

def save_ts_impl(ts, fn, save_state: SaveState) -> SaveState:
    def _cond_fun(_save_state):
        return (
            keep_step
            & (ts[_save_state.saveat_ts_index] <= state.tnext)
            & (_save_state.saveat_ts_index < len(ts))
        )

    def _body_fun(_save_state):
        _t = ts[_save_state.saveat_ts_index]
        _y = interpolator.evaluate(_t)
        _ts = _save_state.ts.at[_save_state.save_index].set(_t)
        _ys = jtu.tree_map(
            lambda __y, __ys: __ys.at[_save_state.save_index].set(__y),
            fn(_t, _y, args),
            _save_state.ys,
        )
        return SaveState(
            saveat_ts_index=_save_state.saveat_ts_index + 1,
            ts=_ts,
            ys=_ys,
            save_index=_save_state.save_index + 1,
        )

    return inner_while_loop(
        _cond_fun,
        _body_fun,
        save_state,
        max_steps=len(ts),
        buffers=_inner_buffers,
        checkpoints=len(ts),
    )

save_state = jtu.tree_map(
    save_ts, saveat.subs, save_state, is_leaf=_is_subsaveat
)
```
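A minimal toy model of this failure mode, assuming (as the reported `[[inf] [inf]]` output suggests) that the save buffer starts out filled with inf. `toy_solve` is a hypothetical function, not part of diffrax: it takes fixed steps for dy/dt = y and, like the excerpt above, only writes to the buffer from inside the body of the stepping loop. With t0 == t1 the outer `jax.lax.while_loop` runs zero iterations, so the buffer comes back untouched.

```python
import jax
import jax.numpy as jnp


def toy_solve(t0, t1, dt, y0, ts):
    # Hypothetical toy model (not diffrax's code): fixed steps for dy/dt = y,
    # saving the requested `ts` only from inside the stepping loop.
    t0 = jnp.asarray(t0, dtype=y0.dtype)
    t1 = jnp.asarray(t1, dtype=y0.dtype)
    ts = jnp.asarray(ts, dtype=y0.dtype)
    n = ts.shape[0]
    # Save buffer pre-filled with inf (assumed, matching the reported output).
    ys = jnp.full((n,) + y0.shape, jnp.inf, dtype=y0.dtype)

    def step_cond(carry):
        t = carry[0]
        return t < t1  # False immediately when t0 == t1

    def step_body(carry):
        t, y, ys, i = carry
        tnext = jnp.minimum(t + dt, t1)

        def save_cond(c):
            _, j = c
            # Clamp the index so the lookup stays in bounds once j == n.
            return (j < n) & (ts[jnp.minimum(j, n - 1)] <= tnext)

        def save_body(c):
            buf, j = c
            # Exact value inside the step for dy/dt = y, standing in for
            # the interpolator used by a real solver.
            buf = buf.at[j].set(y * jnp.exp(ts[j] - t))
            return buf, j + 1

        ys, i = jax.lax.while_loop(save_cond, save_body, (ys, i))
        return tnext, y * jnp.exp(tnext - t), ys, i

    init = (t0, y0, ys, jnp.array(0, dtype=jnp.int32))
    _, y1, ys, _ = jax.lax.while_loop(step_cond, step_body, init)
    return ys, y1
```

For example, `toy_solve(0.0, 0.0, 0.1, jnp.array([1.0]), jnp.array([0.0, 0.0]))` returns an all-inf buffer, whereas the same call with `t1=1.0` fills both entries with 1.0.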

I suppose we should add a special case in our integration loop that explicitly handles this. (Something like ys = lax.cond(t0 == t1, lambda: jnp.broadcast_to(y0, (len(ts),) + y0.shape), ...)?)
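A minimal sketch of that special case, assuming `y0` is a single array rather than a general PyTree and that the loop's output buffer `ys_from_loop` has shape `(len(ts),) + y0.shape`. `fill_if_degenerate` is a hypothetical helper, not part of diffrax. When t0 == t1 the solution never moves, so every requested save time should simply hold `y0`.

```python
import jax
import jax.numpy as jnp


def fill_if_degenerate(t0, t1, y0, ys_from_loop):
    # Hypothetical helper (not diffrax's API). If t0 == t1, replace the
    # never-written buffer with copies of y0; otherwise keep what the
    # integration loop saved.
    return jax.lax.cond(
        jnp.asarray(t0) == jnp.asarray(t1),
        lambda: jnp.broadcast_to(y0, ys_from_loop.shape).astype(ys_from_loop.dtype),
        lambda: ys_from_loop,
    )
```

Both branches of `lax.cond` must return the same shape and dtype, which is why the broadcast result is cast to the buffer's dtype.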

I'd be happy to take a PR on this.

@dkweiss31
Contributor Author

Gotcha! Nice, I'd be happy to give this a go. I've been meaning to learn about while loops and buffers, so no time like the present!

@patrick-kidger
Owner

Great!
And the good news is that I think this can happen before/after the loop (I misspoke a little above), so there won't be any need to change the integration loop itself.
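A minimal sketch of such a post-loop fix-up, assuming a single-array `y0` and a save function with the same `fn(t, y, args)` call shape as `fn(_t, _y, args)` in the excerpt above. `postprocess_saved_ys` is a hypothetical name, not something that exists in diffrax. Unlike a plain broadcast of `y0`, it passes `y0` through `fn` so that a `SubSaveAt` with a custom `fn` would still get transformed values.

```python
import jax
import jax.numpy as jnp


def postprocess_saved_ys(ts, t0, t1, y0, ys, fn, args=None):
    # Hypothetical post-loop fix-up (not diffrax's API): it runs after the
    # integration loop, so the loop itself is unchanged.
    # Evaluate fn(t, y0, args) at every requested save time.
    filled = jax.vmap(lambda t: fn(t, y0, args))(ts)
    # Use those values only when t0 == t1, i.e. when the loop never ran.
    return jnp.where(jnp.asarray(t0) == jnp.asarray(t1), filled, ys)
```

`jnp.where` evaluates both candidates, costing one extra `fn` call per save time. Under `vmap` with a batched predicate, a `lax.cond` would be lowered to a select that evaluates both branches anyway, so the simpler form loses little.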
