Each solver in JAXopt maintains a state containing a number of attributes, some of which are scalar-valued (e.g., stepsize). In some solvers, the state returned by init_state and the state returned by update have an inconsistent weak_type for some of these attributes, which triggers a JIT recompilation of update. This can be seen with the code example below (code by @fllinares):
$ python recompilation_issue.py
state0.stepsize.weak_type False
First call: 0.5013151168823242
state1.stepsize.weak_type True
Second call: 0.463458776473999
state2.stepsize.weak_type True
Third call: 0.00026679039001464844
state3.stepsize.weak_type True
state4.stepsize.weak_type True
Fourth call: 0.0002238750457763672
What's happening: state0 is obtained from the init_state call and is given as input to update. A first JIT compilation happens, with state0.stepsize.weak_type = False.
Then update outputs a new state state1 with state1.stepsize.weak_type = True. When we use that state as input to update, a JIT recompilation happens since weak_type has changed. For the following calls to update, no recompilation occurs, since weak_type remains True.
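The retracing described above can be reproduced without JAXopt. This is a minimal sketch, not the script from the issue: `update` is a hypothetical toy function standing in for a solver's update, and `trace_log` is an illustrative counter that relies on Python side effects running only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended each time JAX traces `update`


@jax.jit
def update(stepsize):
    # Python side effects run only during tracing, so this records
    # every (re)compilation of `update`.
    trace_log.append("traced")
    return stepsize * 0.5


# Explicit dtype: strongly typed scalar, as init_state returned in the report.
strong = jnp.asarray(1.0, dtype=jnp.float32)
# Built from a bare Python float: weakly typed scalar.
weak = jnp.asarray(1.0)

update(strong)  # first call: traces and compiles
update(weak)    # same shape and dtype, different weak_type: traces again
update(weak)    # abstract input unchanged: reuses the cached compilation

print("strong.weak_type:", strong.weak_type)
print("weak.weak_type:", weak.weak_type)
print("number of traces:", len(trace_log))
```

Both inputs have the same shape and dtype, so the second trace comes only from the difference in weak_type.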
Similarly to the dtype and aux consistency checks in common_test.py, we need to check weak_type consistency for each solver in a systematic manner, and make fixes if necessary.
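One possible shape for such a check is sketched below. It is an assumption-laden illustration, not the test in common_test.py: `ToyState` and `weak_type_mismatches` are hypothetical names, and the helper assumes a flat NamedTuple state whose fields are JAX arrays.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a solver state (not a JAXopt class)."""
    iter_num: jax.Array
    stepsize: jax.Array


def weak_type_mismatches(state_a, state_b):
    """Return the names of fields whose weak_type differs between two states."""
    mismatched = []
    for name in state_a._fields:
        # Non-array fields have no weak_type; treat them as strongly typed.
        weak_a = getattr(getattr(state_a, name), "weak_type", False)
        weak_b = getattr(getattr(state_b, name), "weak_type", False)
        if weak_a != weak_b:
            mismatched.append(name)
    return mismatched


# State as it might come out of init_state: every field strongly typed.
init = ToyState(
    iter_num=jnp.asarray(0, dtype=jnp.int32),
    stepsize=jnp.asarray(1.0, dtype=jnp.float32),
)
# State as it might come out of update: stepsize built from a Python float.
after_update = ToyState(
    iter_num=jnp.asarray(1, dtype=jnp.int32),
    stepsize=jnp.asarray(0.5),
)

mismatches = weak_type_mismatches(init, after_update)
print(mismatches)
```

A test along these lines would assert that the returned list is empty for every solver.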
@froystig Your opinion on the best way to fix would be welcome.
> Similarly to the dtype and aux consistency checks in common_test.py, we need to check weak_type consistency for each solver in a systematic manner, and make fixes if necessary.
This sounds like the right fix to me. A change in weak type for a function input is like changing the dtype: it can change the function's behavior, and must trigger a re-compilation. If you want to avoid recompilation, you need to make sure the inputs of the second function call match the inputs of the first function call.
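One way to make the inputs match, sketched under assumptions: converting a weakly typed value with an explicit dtype yields a strongly typed array, so it hits the same compilation cache entry as the initial value. `update` and `trace_log` are hypothetical illustration names, and this is not necessarily the fix JAXopt adopted.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended each time JAX traces `update`


@jax.jit
def update(stepsize):
    trace_log.append("traced")  # runs only during tracing
    return stepsize * 0.5


# Strongly typed initial value, as init_state would produce it.
init_stepsize = jnp.asarray(1.0, dtype=jnp.float32)

# A weakly typed value, e.g. one derived from a bare Python float.
weak_stepsize = jnp.asarray(0.5)

# Passing an explicit dtype drops the weak type, so the abstract input
# (shape, dtype, weak_type) now matches the initial value.
fixed_stepsize = jnp.asarray(weak_stepsize, dtype=init_stepsize.dtype)

update(init_stepsize)   # traces once
update(fixed_stepsize)  # same abstract input: no retrace

print("fixed_stepsize.weak_type:", fixed_stepsize.weak_type)
print("number of traces:", len(trace_log))
```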
Yes, that's expected. Roughly, the mental model of "weak type" is that it's a value whose dtype has not been specified by the user. It's the mechanism that allows (x + 1).dtype == x.dtype to hold true within JAX code.
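A small illustration of that mental model, with variable names chosen only for this example: a bare Python scalar is weakly typed and adopts the array's dtype, while a scalar with an explicit dtype takes part in normal type promotion.

```python
import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int8)

# The Python int 1 is weakly typed, so it does not widen the result.
y = x + 1

# jnp.int32(1) has a user-specified dtype, so int8 + int32 promotes to int32.
z = x + jnp.int32(1)

print(y.dtype, z.dtype)
```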