jaxopt.Bisection not jittable, vmappable #422

Closed
deasmhumhna opened this issue Apr 10, 2023 · 2 comments

Comments

@deasmhumhna

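```python
import jax
import jax.numpy as jnp
import jaxopt

def F(x, factor):
  return factor * x ** 3 - x - 2

def root(factor):
  bisec = jaxopt.Bisection(optimality_fun=F, lower=1, upper=2)
  return bisec.run(factor=factor).params

# Raises ConcretizationTypeError under vmap
print(jax.vmap(root)(jnp.array([2.0, 2.0])))
```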
```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<BatchTrace(level=1/1)> with
  val = Array([False, False], dtype=bool, weak_type=True)
  batch_dim = 0
The problem arose with the `bool` function.
This BatchTracer with object id 140682262908736 was created on line:
  <ipython-input-30-5f76524299dc>:7 (root)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```

It turns out this is because `root` is not jittable, though I don't see any reason why it shouldn't be, since its internal loop is a `lax.while_loop`. Maybe there are some Python `if`s that could be `lax.cond`s?
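
The suspected failure mode can be reproduced in isolation. This is a minimal sketch using a hypothetical one-step bisection update (`bisect_step_python` and `bisect_step_cond` are illustrative names, not jaxopt functions): a Python `if` on a traced value fails under `vmap` (or `jit`), while `lax.cond` keeps the branch inside the traced computation. It assumes `f` is increasing on the bracket.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bisect_step_python(lo, hi, f):
    # Hypothetical bisection step with Python control flow.
    # `if` calls bool() on f(mid) > 0; under jit/vmap that is a tracer with
    # no concrete value, so JAX raises a ConcretizationTypeError.
    mid = (lo + hi) / 2
    if f(mid) > 0:
        return lo, mid
    return mid, hi


def bisect_step_cond(lo, hi, f):
    # Same step with lax.cond, which stages the branch into the traced
    # computation (under vmap a batched predicate becomes an elementwise select).
    mid = (lo + hi) / 2
    return lax.cond(
        f(mid) > 0,
        lambda a, m, b: (a, m),  # root lies in [lo, mid] (f assumed increasing)
        lambda a, m, b: (m, b),  # root lies in [mid, hi]
        lo, mid, hi,
    )


def f(x):
    # Increasing on [0, 2], root at sqrt(2).
    return x ** 2 - 2.0


lo = jnp.array([1.0, 0.0])
hi = jnp.array([2.0, 2.0])

new_lo, new_hi = jax.vmap(lambda a, b: bisect_step_cond(a, b, f))(lo, hi)
print(new_lo, new_hi)

try:
    jax.vmap(lambda a, b: bisect_step_python(a, b, f))(lo, hi)
except jax.errors.ConcretizationTypeError as err:
    # Newer JAX versions raise TracerBoolConversionError, a subclass.
    print(type(err).__name__)
```

Called eagerly with concrete Python floats, `bisect_step_python` works fine; it only breaks once a transformation replaces its inputs with tracers.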

@mblondel
Collaborator

You need to set `check_bracket=False` (see here).

@deasmhumhna
Author

Ha, of course. Thank you!
