ptxas version issue #427

Closed
zohimchandani opened this issue May 9, 2023 · 2 comments

zohimchandani commented May 9, 2023

Running the following code snippet raises an error:

import jax
from jaxopt import GradientDescent

jax.devices('cpu')

def f(x): 
    return x**2


opt = GradientDescent(fun=f, stepsize=0.1, maxiter = 100, verbose = True, value_and_grad=False)

opt.run([3])

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[1], line 12
      7     return x**2
     10 opt = GradientDescent(fun=f, stepsize=0.1, maxiter = 100, verbose = True, value_and_grad=False)
---> 12 opt.run([3])

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/base.py:255, in IterativeSolver.run(self, init_params, *args, **kwargs)
    248   decorator = idf.custom_root(
    249       self.optimality_fun,
    250       has_aux=True,
    251       solve=self.implicit_diff_solve,
    252       reference_signature=reference_signature)
    253   run = decorator(run)
--> 255 return run(init_params, *args, **kwargs)

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
    249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

    [... skipping hidden 5 frame]

File ~/.local/lib/python3.10/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
    204 @jax.custom_vjp
...
    469 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    470 # to take in `host_callbacks`
--> 471 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: Couldn't get ptxas/nvlink version string: INTERNAL: Couldn't invoke ptxas --version

nouiz commented May 11, 2023

How did you install JAX? With pip on the command line?
What is your environment/container?
XLA needs ptxas and can't find it. Either it isn't installed, or it is installed somewhere XLA doesn't look.
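The error message says XLA tried to run `ptxas --version` and could not. One quick way to test the first possibility above (ptxas missing) is to look for ptxas on `PATH` and run that same command. This is a minimal standard-library sketch. It only searches `PATH`, and XLA may also look in other places, such as a CUDA installation directory. A `None` result is therefore a hint, not proof:

```python
import shutil
import subprocess
from typing import Optional


def ptxas_version() -> Optional[str]:
    """Try `ptxas --version`, the command named in the error message.

    Returns the command's output, or None if ptxas is not on PATH or
    cannot be run. Only PATH is searched here.
    """
    path = shutil.which("ptxas")
    if path is None:
        return None  # not installed, or installed outside PATH
    try:
        result = subprocess.run(
            [path, "--version"],
            capture_output=True,
            text=True,
            timeout=30,
            check=True,
        )
    except (OSError, subprocess.SubprocessError):
        return None  # found, but invoking it failed
    return result.stdout.strip() or None


if __name__ == "__main__":
    print(ptxas_version() or "ptxas not found on PATH or failed to run")
```

If this prints a version but JAX still fails, the second possibility is more likely: ptxas exists, but not where XLA searches. Installing the CUDA pip wheels, as in the fix that follows, presumably sidesteps this by putting a matching ptxas alongside JAX.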

zohimchandani (author) commented

This command fixed it:

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
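After reinstalling, a short smoke test can confirm which backend JAX picked up and that an XLA compilation succeeds. Compilation is the step that failed in the traceback above. This sketch uses only the public `jax` API. When the CUDA wheels work, `jax.default_backend()` should report the GPU backend; on a CPU-only install it reports `cpu`. The `cuda11_pip` extra and the `jax_cuda_releases.html` index come from JAX releases of this issue's era. Later releases use different extras, so check the current JAX installation guide before copying the command.

```python
import jax
import jax.numpy as jnp


def smoke_test():
    # Report the backend and devices JAX selected at startup.
    print("jax", jax.__version__, "backend:", jax.default_backend())
    print("devices:", jax.devices())
    # jit forces an XLA compilation on the default backend; on a GPU
    # backend this is the step that failed in the traceback.
    double = jax.jit(lambda x: x * 2.0)
    return double(jnp.arange(4.0))


if __name__ == "__main__":
    print(smoke_test())
```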
 
