UnderdampedLangevinDiffusionTerm inheritance #548
So
On to the nature of your question: there's really no such thing as a 'control term'; these are not any different from regular terms. I think what you want is to have a pytree-of-states corresponding to the pytree-of-terms. And it will simply happen that e.g.
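A minimal sketch of the pytree-of-states idea, using stdlib-only stand-ins. `Term`, `Path`, `init_state` and `map_terms` are hypothetical names for illustration, and `map_terms` plays the role that `jax.tree_util.tree_map(fn, terms, is_leaf=...)` would play in real JAX code. Every term gets a state slot in the same position, and terms without a stateful control simply get `None`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Path:
    """Hypothetical stateful control; `init` returns its initial state."""
    seed: int

    def init(self) -> dict:
        return {"seed": self.seed, "steps_taken": 0}


@dataclass
class Term:
    """Hypothetical term; ODE-like terms have no control (control=None)."""
    name: str
    control: Optional[Path] = None


def map_terms(fn: Callable[[Term], Any], tree: Any) -> Any:
    """Apply `fn` to every Term in nested tuples/lists, preserving structure.

    Stand-in for jax.tree_util.tree_map(fn, tree, is_leaf=lambda x: isinstance(x, Term)).
    """
    if isinstance(tree, Term):
        return fn(tree)
    if isinstance(tree, (tuple, list)):
        return type(tree)(map_terms(fn, subtree) for subtree in tree)
    raise TypeError(f"unexpected node: {tree!r}")


def init_state(term: Term) -> Any:
    # Terms with a control get that control's state; all others get None.
    return None if term.control is None else term.control.init()


# An ODE term plus a nested (drift, diffusion) pair, mirroring a MultiTerm layout.
terms = (Term("ode"), (Term("drift"), Term("diffusion", control=Path(seed=0))))
states = map_terms(init_state, terms)
print(states)  # (None, (None, {'seed': 0, 'steps_taken': 0}))
```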
Yeah, I have it such that it's a pytree of states where the ODE term just has a None state. The problem was mostly during an initialization step: the ODETerm has no control member variable to do an initialization operation with (which could be checked via attributes). That is to say, when creating the pytree of states, the code basically wants to call the AbstractPath's init method (which an ODETerm doesn't have, so I would return None for that state), and so my first instinct was just to parse which terms I call init for (which are the control terms). Writing a new init method for the term seemed unnecessary at first, since it would just be calling control.init anyway. I know treating terms in a mostly uniform manner is one of the attributes of diffrax, so I'm curious whether you ever thought about having the ODETerm have a control as well (where the control is just the trivial t1 - t0 like it is now, but as a separate entity). Then there would be even less difference between operating on a control term and an ODETerm (not advocating for that idea, just curious).
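To make the "ODETerm with an explicit control" idea concrete: diffrax's `ODETerm` already uses `t1 - t0` as its control increment internally, so the idea amounts to pulling that out into a separate path object. A rough sketch under that assumption, where `TimeIncrementPath`, `LinearPath` and `GenericControlTerm` are hypothetical names (not diffrax classes) and plain floats stand in for pytrees:

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class TimeIncrementPath:
    """Hypothetical trivial control: its increment over [t0, t1] is t1 - t0."""

    def evaluate(self, t0: float, t1: float) -> float:
        return t1 - t0


@dataclass
class LinearPath:
    """Hypothetical non-trivial control: increment is rate * (t1 - t0)."""
    rate: float

    def evaluate(self, t0: float, t1: float) -> float:
        return self.rate * (t1 - t0)


@dataclass
class GenericControlTerm:
    """Hypothetical term: vector field at t0 times the control increment."""
    vector_field: Callable[[float, float, Any], float]
    control: Any

    def vf_prod(self, t0: float, t1: float, y: float, args: Any) -> float:
        return self.vector_field(t0, y, args) * self.control.evaluate(t0, t1)


def decay(t: float, y: float, args: float) -> float:
    return -args * y


# With the trivial control, an "ODE term" is just another control term,
# so one Euler step uses the same code path for both kinds of term.
ode_like = GenericControlTerm(decay, TimeIncrementPath())
cde_like = GenericControlTerm(decay, LinearPath(rate=2.0))

y0 = 1.0
print(y0 + ode_like.vf_prod(0.0, 0.5, y0, 2.0))  # 1.0 + (-2.0 * 0.5) = 0.0
print(y0 + cde_like.vf_prod(0.0, 0.5, y0, 2.0))  # 1.0 + (-2.0 * 1.0) = -1.0
```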
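To be concrete, code to the effect of the following (just my intermediate hack) sometimes comes up:

```python
import jax.tree_util as jtu
from diffrax import AbstractPath, AbstractTerm, MultiTerm, UnderdampedLangevinDiffusionTerm
from diffrax._term import _AbstractControlTerm


def _path_init(term, end):
    # t0, y0, args and max_steps come from the enclosing scope (e.g. inside diffeqsolve).
    # `AbstractPath.init` belongs to the work-in-progress stateful-control API.
    if isinstance(term, (_AbstractControlTerm, UnderdampedLangevinDiffusionTerm)):
        if isinstance(term.control, AbstractPath):
            return term.control.init(t0, end, y0, args, max_steps)
        return None
    elif isinstance(term, MultiTerm):
        # Recurse so the state pytree mirrors the MultiTerm's sub-terms.
        return jtu.tree_map(
            lambda x: _path_init(x, end),
            term.terms,
            is_leaf=lambda x: isinstance(x, AbstractTerm),
        )
    # ODETerm and other control-free terms carry no path state.
    return None


path_state = jtu.tree_map(
    lambda x: _path_init(x, t1),
    terms,
    is_leaf=lambda x: isinstance(x, AbstractTerm),
)
```

This comes up because ODETerm doesn't have a control, and AbstractTerm doesn't have an init. I could add either one to solve it (or refine the hack, or do the blind-faith `hasattr` check).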
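One way to avoid the `isinstance` chain is the option of giving every term an `init`: a default that returns `None`, overridden by terms that own a stateful control and by container terms. A stdlib-only sketch, assuming a simplified hypothetical `init(t0, t1)` signature (the work-in-progress version above also takes `y0`, `args` and `max_steps`) and stand-in class names rather than diffrax's own:

```python
from dataclasses import dataclass
from typing import Any


class TermBase:
    """Hypothetical base term: stateless unless a subclass overrides init."""

    def init(self, t0: float, t1: float) -> Any:
        return None


@dataclass
class CountingPath:
    """Hypothetical stateful control path."""

    def init(self, t0: float, t1: float) -> dict:
        return {"t0": t0, "t1": t1, "calls": 0}


class ODELikeTerm(TermBase):
    """No control, so it keeps the default init and gets a None state."""


@dataclass
class ControlledTerm(TermBase):
    """Owns a control and delegates state creation to it."""

    control: CountingPath

    def init(self, t0: float, t1: float) -> dict:
        return self.control.init(t0, t1)


@dataclass
class MultiLikeTerm(TermBase):
    """Container term: its state is a tuple mirroring its sub-terms."""

    terms: tuple

    def init(self, t0: float, t1: float) -> tuple:
        return tuple(term.init(t0, t1) for term in self.terms)


terms = MultiLikeTerm((ODELikeTerm(), ControlledTerm(CountingPath())))
state = terms.init(0.0, 1.0)
print(state)  # (None, {'t0': 0.0, 't1': 1.0, 'calls': 0})
```

With this shape the caller never needs to know which terms carry controls; the cost is one extra method on the term base class.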
So my first reaction is indeed to add […]. I'm not super happy with that, but I'm definitely less happy with the other options!
Makes sense. It's not a super appealing solution, but it seems like the best one.
Should `UnderdampedLangevinDiffusionTerm` inherit from `_AbstractControlTerm` (in theory; I'm not talking about whether this reduces LoC)? The reason I ask is that, while working on the stateful control, I'm parsing terms to check whether they are control terms (which is just done via an inheritance check), and given that the diffusion term is a control term, it seems like it should inherit (even if there's just some new "AbstractControlTerm" which doesn't have any code but unifies the inheritance). If this is not the case, should there just be a hard-coded check for `UnderdampedLangevinDiffusionTerm`, or what is the best practice for parsing a set of terms and determining which ones are control terms?
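The "unified inheritance" option in the question can be sketched with an empty abstract marker class. The names here (`AbstractControlTermMarker`, `MyControlTerm`, `MyLangevinDiffusionTerm`, `is_control_term`) are hypothetical stand-ins, not diffrax classes. diffrax terms are `equinox.Module`s, so a real marker would be one too; this sketch uses plain `abc.ABC` to stay dependency-free.

```python
from abc import ABC


class AbstractTerm(ABC):
    """Hypothetical stand-in for a base term class."""


class AbstractControlTermMarker(AbstractTerm):
    """Empty marker meaning 'this term is driven by a control'. Adds no code."""

    control: object  # annotation only; subclasses set the attribute


class ODELikeTerm(AbstractTerm):
    pass


class MyControlTerm(AbstractControlTermMarker):
    def __init__(self, control):
        self.control = control


class MyLangevinDiffusionTerm(AbstractControlTermMarker):
    def __init__(self, gamma, u, control):
        self.gamma, self.u, self.control = gamma, u, control


def is_control_term(term) -> bool:
    # A single isinstance check replaces per-class special cases.
    return isinstance(term, AbstractControlTermMarker)


terms = [ODELikeTerm(), MyControlTerm("bm"), MyLangevinDiffusionTerm(1.0, 1.0, "bm")]
print([is_control_term(t) for t in terms])  # [False, True, True]
```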