Automatically use the correct device in xp.clip when Python number literals are passed as bounds #177

Closed
ogrisel opened this issue Aug 8, 2024 · 7 comments

Comments


ogrisel commented Aug 8, 2024

I would like the following not to fail with PyTorch:

>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

At the moment, we need to be overly verbose to use xp.clip with pytorch on non-cpu tensors:

>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')

ogrisel commented Aug 8, 2024

Note that I have not investigated if other array API namespaces suffer from the same problem.


asmeurer commented Aug 8, 2024

Ah, I completely forgot about devices when I wrote this wrapper.
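
For illustration, here is a minimal sketch of the kind of conversion the wrapper could apply: Python scalar bounds are turned into 0-d arrays on the input array's device before any comparison. The _scalar_to_array helper and the simplified clip body below are made up for this example; they are not the actual array-api-compat implementation, which also handles NaNs and the no-promotion rules.

    from array_api_compat import device


    def _scalar_to_array(value, a, xp):
        # Hypothetical helper: allocate a Python scalar on the same device
        # (and with the same dtype) as the array being clipped, so
        # expressions such as `out < a` never mix devices.
        if isinstance(value, (int, float)):
            return xp.asarray(value, dtype=a.dtype, device=device(a))
        return value


    def clip(a, min=None, max=None, *, xp):
        # Simplified sketch of the wrapper body; the real one is more involved.
        if min is not None:
            min = _scalar_to_array(min, a, xp)
            a = xp.where(a < min, min, a)
        if max is not None:
            max = _scalar_to_array(max, a, xp)
            a = xp.where(a > max, max, a)
        return a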


ogrisel commented Aug 9, 2024

BTW, is this something that should be made explicit in the spec itself? Or would that just make the spec unnecessarily verbose?

Maybe it could just be tested in array-api-tests.


rgommers commented Aug 9, 2024

I think it's covered by the general design principles, e.g. https://data-apis.org/array-api/latest/design_topics/device_support.html: "This standard chooses to add support for method 3 (local control), with the convention that execution takes place on the same device where all argument arrays are allocated."

We could add a bullet point to that list stating that Python scalars and other non-array-library objects should not influence device assignment.

And +1 for a test.


asmeurer commented Aug 9, 2024

Yes, unfortunately, device support is not tested at all in the test suite right now.
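
For concreteness, a rough sketch of the kind of check such a test could make. This is not actual array-api-tests code; the test name is invented and the torch/MPS device is just the one from the report above.

    import array_api_compat.torch as xp
    from array_api_compat import device


    def test_clip_scalar_bounds_preserve_device():
        # Clipping with Python scalar bounds should not move the result off
        # the device of the input array.
        data = xp.linspace(0, 1, num=5, device="mps")
        out = xp.clip(data, 0.1, 0.9)
        assert device(out) == device(data)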


ogrisel commented Sep 2, 2024

Note that the dtype should similarly be derived from the first argument to avoid unwanted up-casting. That is, I would like the following to hold automatically:

a = xp.linspace(-1, 1, 10, dtype=xp.float32)
assert xp.clip(a, 0, 1).dtype == xp.float32

EDIT: it seems that dtype handling is part of #166.


asmeurer commented Sep 3, 2024

The clip wrapper has been a little annoying to get right, primarily because of the "no promotion" behavior plus the fact that it accepts scalars. But I've hopefully ironed out all the wrinkles in #166 (except for a minor known issue that dask won't handle some cases with uint64 arrays correctly because of the way NumPy upcasts to float64).
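
For context on that caveat: NumPy has no integer dtype that can hold the full range of both uint64 and int64, so mixing the two promotes to float64, which cannot represent all uint64 values exactly. A quick illustration in plain NumPy (the dask caveat mentioned above stems from this behaviour):

    import numpy as np

    # There is no common integer type for uint64 and int64, so NumPy falls
    # back to float64.
    print(np.result_type(np.uint64, np.int64))  # float64

    a = np.asarray([2**63], dtype=np.uint64)
    b = np.asarray([-1], dtype=np.int64)
    print((a + b).dtype)  # float64, so large uint64 values lose precision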
