
weak_type inconsistency causes jit recompilation #451

Closed
mblondel opened this issue Jun 28, 2023 · 4 comments
Labels
bug Something isn't working


@mblondel
Collaborator

Each solver in JAXopt maintains a state containing several attributes, some of which are scalar-valued (e.g., `stepsize`). In some solvers, the state returned by `init_state` and the state returned by `update` have an inconsistent `weak_type` for some of these attributes, which triggers a JIT recompilation of `update`. The code example below (code by @fllinares) shows this:

```python
import jax
import jax.numpy as jnp
import jaxopt
import sklearn.datasets
import functools
import time


def make_fun_with_aux(fun, dtype=None, has_aux=False):
  """Wraps `fun`, optionally casting its output and returning auxiliary values."""
  @functools.wraps(fun)
  def wrapper(*args, **kwargs):
    value = fun(*args, **kwargs)
    if dtype is not None:
      value = value.astype(dtype)
    aux = {'plus_one': value + 1.0, 'times_two': 2.0 * value}
    return (value, aux) if has_aux else value
  return wrapper


N_SAMPLES = 100
N_FEATURES = 20
N_CLASSES = 3
N_INFORMATIVE = 5

PARAMS_DTYPE = jnp.bfloat16
FUN_DTYPE = jnp.float32
HAS_AUX = True

# make_classification returns an (X, y) tuple.
data = sklearn.datasets.make_classification(n_samples=N_SAMPLES,
                                            n_features=N_FEATURES,
                                            n_classes=N_CLASSES,
                                            n_informative=N_INFORMATIVE,
                                            random_state=0)

init_params = jnp.zeros([N_FEATURES, N_CLASSES], dtype=PARAMS_DTYPE)
fun = make_fun_with_aux(
    jaxopt.objective.multiclass_logreg, dtype=FUN_DTYPE, has_aux=HAS_AUX)

solver = jaxopt.LBFGS(fun=fun, stepsize=1e-2, has_aux=HAS_AUX)
# solver = jaxopt.GradientDescent(fun=fun, stepsize=1e-2, has_aux=HAS_AUX)

update = jax.jit(solver.update)

# jax.tree_map is deprecated in newer JAX releases; jax.tree_util.tree_map is equivalent.
data = jax.tree_util.tree_map(jax.device_put, data)
state0 = jax.jit(solver.init_state)(init_params, data=data)
params0 = init_params

print("state0.stepsize.weak_type", state0.stepsize.weak_type)
print()

# First call: update is traced and compiled for the signature of state0.
tic = time.time()
params1, state1 = update(params0, state0, data)
print('First call:', time.time() - tic)
print("state1.stepsize.weak_type", state1.stepsize.weak_type)
print()

# Second call: state1.stepsize is weakly typed but state0.stepsize was not,
# so update is traced and compiled again.
tic = time.time()
params2, state2 = update(params1, state1, data)
print('Second call:', time.time() - tic)
print("state2.stepsize.weak_type", state2.stepsize.weak_type)
print()

# Third and fourth calls: the input signature matches the second call,
# so the cached executable is reused.
tic = time.time()
params3, state3 = update(params2, state2, data)
print('Third call:', time.time() - tic)
print("state3.stepsize.weak_type", state3.stepsize.weak_type)
print()

tic = time.time()
params4, state4 = update(params3, state3, data)
print("state4.stepsize.weak_type", state4.stepsize.weak_type)
print('Fourth call:', time.time() - tic)
```

Output:

```
$ python recompilation_issue.py
state0.stepsize.weak_type False

First call: 0.5013151168823242
state1.stepsize.weak_type True

Second call: 0.463458776473999
state2.stepsize.weak_type True

Third call: 0.00026679039001464844
state3.stepsize.weak_type True

state4.stepsize.weak_type True
Fourth call: 0.0002238750457763672
```

What's happening: `state0` is obtained from the `init_state` call and is passed as input to `update`. A first JIT compilation happens, with `state0.stepsize.weak_type = False`.
Then `update` outputs a new state `state1` with `state1.stepsize.weak_type = True`. When that state is passed to `update`, a JIT recompilation happens because `weak_type` has changed. Subsequent calls to `update` do not recompile, since `weak_type` remains `True`.

As with the dtype and aux consistency checks in `common_test.py`, we need to check `weak_type` consistency for each solver in a systematic manner, and make fixes where necessary.
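A minimal sketch of what such a check could look like, assuming every leaf of a solver state is a `jax.Array` and that both states share the same pytree structure. `ToyState` and `weak_type_mismatches` are hypothetical names for illustration, not part of JAXopt or `common_test.py`; the two toy states mimic the issue, where the init path sets an explicit dtype and the update path does not.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  """Hypothetical stand-in for a JAXopt solver state."""
  iter_num: jax.Array
  stepsize: jax.Array


def weak_type_mismatches(state_a, state_b):
  """Returns key paths of leaves whose weak_type differs between two states.

  Assumes both states have the same pytree structure and that every leaf is a
  jax.Array, so it exposes a `weak_type` attribute.
  """
  path_leaves_a, treedef_a = jax.tree_util.tree_flatten_with_path(state_a)
  leaves_b, treedef_b = jax.tree_util.tree_flatten(state_b)
  if treedef_a != treedef_b:
    raise ValueError(f"Pytree structures differ: {treedef_a} vs {treedef_b}")
  return [
      jax.tree_util.keystr(path)
      for (path, leaf_a), leaf_b in zip(path_leaves_a, leaves_b)
      if leaf_a.weak_type != leaf_b.weak_type
  ]


# Like in the issue: the "init" state uses an explicit dtype, the "update"
# state does not. iter_num is weakly typed in both, so it matches.
state_init = ToyState(iter_num=jnp.asarray(0),
                      stepsize=jnp.asarray(1e-2, dtype=jnp.float32))
state_update = ToyState(iter_num=jnp.asarray(1),
                        stepsize=jnp.asarray(1e-2))
print(weak_type_mismatches(state_init, state_update))
```

In a real solver test, one would presumably compare the state from `solver.init_state(...)` with the state returned by `solver.update(...)` and assert that the list of mismatches is empty.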

@froystig Your opinion on the best way to fix would be welcome.

@mblondel added the bug label on Jun 28, 2023
@jakevdp
jakevdp commented Jun 28, 2023

> Similarly to the dtype and aux consistency checks in common_test.py, we need to check weak_type consistency for each solver in a systematic manner, and make fixes if necessary.

This sounds like the right fix to me. A change in weak type for a function input is like changing the dtype: it can change the function's behavior, and must trigger a re-compilation. If you want to avoid recompilation, you need to make sure the inputs of the second function call match the inputs of the first function call.
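A small sketch of this effect, assuming JAX's default 32-bit mode (with x64 enabled the weak input would also differ in dtype, which likewise forces a retrace). The function `double` and the counter are illustrative only: a Python side effect inside a jitted function runs only while JAX traces it, so the counter shows how many distinct input signatures were seen.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def double(x):
  global trace_count
  # Runs only at trace time, i.e. once per new input signature.
  trace_count += 1
  return 2 * x


strong = jnp.asarray(1.0, dtype=jnp.float32)  # float32, weak_type=False
weak = jnp.asarray(1.0)                        # float32, weak_type=True

double(strong)
double(strong)  # Same shape, dtype and weak_type: served from the cache.
count_after_strong = trace_count
double(weak)    # Same shape and dtype, different weak_type: traced again.
count_after_weak = trace_count
print(count_after_strong, count_after_weak)
```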

@mblondel
Collaborator Author

Thanks @jakevdp! I think I tracked down the inconsistency to this behavior:

```python
>>> a = jnp.asarray(0.0, dtype=jnp.float32)
>>> a.weak_type
False
>>> a.dtype
dtype('float32')
>>> b = jnp.asarray(0.0)
>>> b.weak_type
True
>>> b.dtype
dtype('float32')
```

That is, `weak_type` is `False` when we specify `dtype` explicitly, and `True` when we don't, even though the resulting `dtype` is the same...
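One way to keep the two code paths consistent, sketched under assumptions: `ToyState`, `toy_init_state` and the two `toy_update_*` functions are hypothetical and not the change made in #458. The idea is to build scalar state attributes with an explicit `dtype` in both `init_state` and `update` (here, the dtype already recorded in the incoming state), so that `weak_type` is `False` on both sides.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
  stepsize: jax.Array


def toy_init_state(params, stepsize=1e-2):
  # Explicit dtype, so stepsize has weak_type=False.
  return ToyState(stepsize=jnp.asarray(stepsize, dtype=params.dtype))


@jax.jit
def toy_update_inconsistent(params, state):
  new_params = params - state.stepsize * params
  # No dtype given: the new stepsize is weakly typed.
  new_stepsize = jnp.asarray(1e-2)
  return new_params, ToyState(stepsize=new_stepsize)


@jax.jit
def toy_update_consistent(params, state):
  new_params = params - state.stepsize * params
  # Reuse the dtype of the incoming state, so weak_type stays False.
  new_stepsize = jnp.asarray(1e-2, dtype=state.stepsize.dtype)
  return new_params, ToyState(stepsize=new_stepsize)


params = jnp.ones(3, dtype=jnp.float32)
state0 = toy_init_state(params)
_, state_bad = toy_update_inconsistent(params, state0)
_, state_good = toy_update_consistent(params, state0)
print(state0.stepsize.weak_type,
      state_bad.stepsize.weak_type,
      state_good.stepsize.weak_type)
```

The opposite choice (leaving `dtype` unspecified in both places, so both sides are weakly typed) would also be consistent; the sketch simply picks the explicit dtype.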

@jakevdp

jakevdp commented Jun 28, 2023

Yes, that's expected. Roughly, the mental model of a "weak type" is that it's a value whose dtype has not been specified by the user. It's the mechanism that allows `(x + 1).dtype == x.dtype` to hold true within JAX code.
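A short illustration of this rule with the `bfloat16` parameters used in the issue, assuming JAX's standard type promotion: a weakly typed `float32` value defers to the dtype of `x`, while a strongly typed one takes part in promotion and widens the result.

```python
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)

weak_one = jnp.asarray(1.0)                       # float32, weak_type=True
strong_one = jnp.asarray(1.0, dtype=jnp.float32)  # float32, weak_type=False

print((x + 1).dtype)           # bfloat16: Python scalars are weakly typed
print((x + weak_one).dtype)    # bfloat16: weak value defers to x's dtype
print((x + strong_one).dtype)  # float32: strong value is promoted with x
```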

@mblondel
Collaborator Author

fixed by #458
