UnderdampedLangevinDiffusionTerm inheritance #548

Open
lockwo opened this issue Dec 31, 2024 · 5 comments
Labels
question User queries

Comments

@lockwo
Contributor

lockwo commented Dec 31, 2024

Should UnderdampedLangevinDiffusionTerm inherit from _AbstractControlTerm? (I mean in theory; I'm not talking about whether this reduces LoC.) I ask because, while working on the stateful control, I'm parsing terms to check whether they are control terms (done purely via an inheritance check). Given that the diffusion term is a control term, it seems like it should inherit, even if that means adding some new "AbstractControlTerm" that has no code and only unifies the inheritance. If not, should there just be a hard-coded check for UnderdampedLangevinDiffusionTerm, or what is the best practice for parsing some terms and determining which are control terms?

@patrick-kidger
Owner

So _AbstractControlTerm is a private implementation detail of ControlTerm and WeaklyDiagonalControlTerm, just to factor out shared code. Although FWIW as we can now pass lineax.DiagonalLinearOperators to ControlTerm then I'm thinking I may simplify things by:

  • folding _AbstractControlTerm into ControlTerm
  • maintaining backward compatibility by just turning WeaklyDiagonalControlTerm into a stub with a __new__ that returns a ControlTerm.
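As a rough illustration of the stub idea in the list above, here is a minimal sketch using plain-Python stand-ins rather than the real Diffrax or Lineax classes. The `Diagonal` wrapper and the way the vector field is wrapped are assumptions made for the example; they are not taken from Diffrax itself.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Diagonal:
    """Hypothetical stand-in for a diagonal linear operator (not lineax)."""

    diag: Any


class ControlTerm:
    """Stand-in for the merged term class; not the real diffrax.ControlTerm."""

    def __init__(self, vector_field, control):
        self.vector_field = vector_field
        self.control = control


class WeaklyDiagonalControlTerm:
    """Backward-compatible stub: constructing it yields a ControlTerm."""

    def __new__(cls, vector_field, control):
        # Wrap the old-style vector field so its output is marked as diagonal.
        def diagonal_vf(t, y, args):
            return Diagonal(vector_field(t, y, args))

        # Returning an object that is not an instance of `cls` means Python
        # skips WeaklyDiagonalControlTerm.__init__ entirely.
        return ControlTerm(diagonal_vf, control)


term = WeaklyDiagonalControlTerm(lambda t, y, args: 2.0 * y, control=None)
print(type(term).__name__)
print(term.vector_field(0.0, 3.0, None))
```

One consequence of this pattern is that `isinstance(term, WeaklyDiagonalControlTerm)` is False for objects built this way, so any code that dispatches on the old class would need to check for `ControlTerm` instead.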

On to the nature of your question: there's really no such thing as a 'control term' -- these are not any different from regular terms. I think what you want is to have a pytree-of-states corresponding to the pytree-of-terms. And it will simply happen that e.g. ODETerm has None as its state.

@patrick-kidger patrick-kidger added the question User queries label Dec 31, 2024
@lockwo
Contributor Author

lockwo commented Dec 31, 2024

Yea I have it such that it's a pytree of states where ODE is just None state.

It was mostly during an initialization step: the ODETerm has no control member variable to do an initialization operation with (which could be checked via attributes). That is to say, when creating the pytree of states, it basically wants to call the AbstractPath's init method (which an ODETerm doesn't have, so I would return None for that state), and so my first instinct was just to parse which terms I call init for (the control terms). Writing a new init method for the term seemed unnecessary at first, since it would just be calling control.init anyway.

I know the mostly uniform treatment of terms is one of the attributes of diffrax, so I'm curious whether you ever thought about having the ODETerm have a control as well (where the control is just the trivial t1 - t0 like it is now, but as a separate entity). Then there would be even less difference between operating with a control term and an ODETerm (not advocating for that idea, just curious).
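To make the "trivial t1 - t0 control" idea concrete, here is a small sketch with plain-Python stand-ins. The method names `vf`, `contr` and `prod` are chosen to resemble Diffrax's term interface, but the classes below (including `Time`) are hypothetical and only illustrate that the two formulations give the same explicit Euler step.

```python
class Time:
    """Hypothetical trivial control: its increment over [t0, t1] is t1 - t0."""

    def evaluate(self, t0, t1):
        return t1 - t0


class ODETerm:
    """Stand-in ODE term with the time increment built in."""

    def __init__(self, vector_field):
        self.vector_field = vector_field

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return t1 - t0

    def prod(self, vf, control):
        return vf * control


class ControlTerm:
    """Stand-in control term that asks a separate control for its increment."""

    def __init__(self, vector_field, control):
        self.vector_field = vector_field
        self.control = control

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return self.control.evaluate(t0, t1)

    def prod(self, vf, control):
        return vf * control


def euler_step(term, t0, t1, y0, args=None):
    # One explicit Euler step, written only against the term interface.
    control = term.contr(t0, t1)
    return y0 + term.prod(term.vf(t0, y0, args), control)


def decay(t, y, args):
    return -y


y_ode = euler_step(ODETerm(decay), 0.0, 0.5, 1.0)
y_ctrl = euler_step(ControlTerm(decay, Time()), 0.0, 0.5, 1.0)
print(y_ode, y_ctrl)
```

The solver step never needs to know which kind of term it was given, which is the uniformity being discussed; the trade-off is the extra `Time()` argument for the plain ODE case.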

@lockwo
Contributor Author

lockwo commented Jan 1, 2025

To be concrete, sometimes code to the effect of (just my intermediate hack)

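import jax.tree_util as jtu
from diffrax import AbstractPath, AbstractTerm, MultiTerm, UnderdampedLangevinDiffusionTerm
from diffrax._term import _AbstractControlTerm  # private module path (assumed); may change

# Runs inside the solve, so t0, t1, y0, args and max_steps come from the enclosing scope.
def _path_init(term, end):
    if isinstance(term, (_AbstractControlTerm, UnderdampedLangevinDiffusionTerm)):
        # Only path controls get a state; `init` here is the work-in-progress stateful API.
        if isinstance(term.control, AbstractPath):
            return term.control.init(t0, end, y0, args, max_steps)
        return None
    elif isinstance(term, MultiTerm):
        # Recurse so the states mirror the nested structure of the terms.
        return jtu.tree_map(
            lambda x: _path_init(x, end),
            term.terms,
            is_leaf=lambda x: isinstance(x, AbstractTerm),
        )
    # Any other term (e.g. ODETerm) has no control and therefore no state.
    return None

path_state = jtu.tree_map(
    lambda x: _path_init(x, t1),
    terms,
    is_leaf=lambda x: isinstance(x, AbstractTerm),
)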

comes up, since ODETerm doesn't have a control and terms don't have an init. I could add either one to solve it (or refine the hack, or do the blind-faith hasattr check).

@patrick-kidger
Owner

So my first reaction is indeed to add init to all terms, for the ODETerm one to just return None, and for the ControlTerm one to delegate onwards to control.init.

I'm not super happy with that but I'm definitely less happy with the other options!

  • First of all, historically: I originally decided against expressing ODEs as ControlTerm(vf, Time()) because that felt a bit too complicated an API to do the 95% use case of solving an ODE. In addition the original versions of Diffrax needed to be able to express the distinction between e.g. ControlTerm and WeaklyDiagonalControlTerm, as Lineax wouldn't exist for another 2 years :)

    And now, relative to where we are today -- I think I'm still happy with the term abstraction. It allows us to express things like UnderdampedLangevin{Drift,Diffusion}Term, whilst keeping their individual components accessible as attributes, which wouldn't really be doable analogously with separate vector fields and controls.

    And, given the existence of a term abstraction -- I prefer having an ODETerm over e.g. def ode_term(vf): return ControlTerm(vf, Time()), as the latter would mean that we have a mix of classes and factory functions to create our term(s), and I think that makes for a more complicated UX.

    So overall even working from a blank slate, I think there's very little I would change about the terms and their class hierarchy!

  • I'd definitely prefer anything over using hasattr -- I feel pretty strongly against the use of protocols/hasattr/structural typing, in favour of ABCs/isinstance/nominal typing!
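The suggestion at the top of this comment (an init on every term, with ODETerm returning None and ControlTerm delegating to control.init) could look roughly like the sketch below. These are plain-Python stand-ins built on an ABC, not the Diffrax classes; the `init` signature is simplified and `CountingPath` is a hypothetical stateful control invented for the example.

```python
import abc


class AbstractTerm(abc.ABC):
    """Stand-in base class: every term must say what its initial state is."""

    @abc.abstractmethod
    def init(self, t0, t1, y0, args):
        """Return this term's initial state, or None if it is stateless."""


class ODETerm(AbstractTerm):
    def __init__(self, vector_field):
        self.vector_field = vector_field

    def init(self, t0, t1, y0, args):
        # No control, so there is nothing to initialise.
        return None


class ControlTerm(AbstractTerm):
    def __init__(self, vector_field, control):
        self.vector_field = vector_field
        self.control = control

    def init(self, t0, t1, y0, args):
        # Delegate onwards to the control.
        return self.control.init(t0, t1, y0, args)


class MultiTerm(AbstractTerm):
    def __init__(self, *terms):
        self.terms = terms

    def init(self, t0, t1, y0, args):
        # The states mirror the structure of the terms.
        return tuple(term.init(t0, t1, y0, args) for term in self.terms)


class CountingPath:
    """Hypothetical stateful control whose state records the start time."""

    def init(self, t0, t1, y0, args):
        return {"t0": t0, "evaluations": 0}


terms = MultiTerm(
    ODETerm(lambda t, y, args: -y),
    ControlTerm(lambda t, y, args: 0.1 * y, CountingPath()),
)
state = terms.init(0.0, 1.0, 1.0, None)
print(state)
```

With this shape the caller needs neither isinstance checks on private classes nor hasattr: the abstract method makes the requirement nominal, and a term that forgets to define init fails at construction time.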

@lockwo
Contributor Author

lockwo commented Jan 3, 2025

> So my first reaction is indeed to add init to all terms, for the ODETerm one to just return None, and for the ControlTerm one to delegate onwards to control.init.
> I'm not super happy with that but I'm definitely less happy with the other options!

Makes sense. It's not super appealing as a solution, but it seems like the best one.
