Skip to content

Commit

Permalink
add asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Jan 18, 2021
1 parent 9b8c129 commit 4d75600
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 4 deletions.
7 changes: 7 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ The following involve no memory copies.
# arr = cuvec.zeros((1337, 42), "float32")
my_custom_lib.some_cpython_api_func(arr.cuvec)
**CPython API** to **Python**

.. code:: python
import cuvec, my_custom_lib
arr = cuvec.asarray(my_custom_lib.some_cpython_api_func())
**CPython API** to **C++**

.. code:: cpp
Expand Down
4 changes: 2 additions & 2 deletions cuvec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# classes
'CuVec',
# functions
'dev_sync', 'copy', 'zeros', 'cu_copy', 'cu_zeros',
'dev_sync', 'copy', 'zeros', 'asarray', 'cu_copy', 'cu_zeros',
# data
'typecodes', 'vec_types'] # yapf: disable

Expand All @@ -35,7 +35,7 @@
from warnings import warn
warn(str(err), UserWarning)
else:
from .helpers import CuVec, copy, zeros
from .helpers import CuVec, asarray, copy, zeros
from .pycuvec import cu_copy, cu_zeros, typecodes, vec_types

# for use in `cmake -DCMAKE_PREFIX_PATH=...`
Expand Down
16 changes: 14 additions & 2 deletions cuvec/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class CuVec(np.ndarray):
"""
_Vector_types = tuple(vec_types.values())

def __new__(cls, arr, cuvec=None):
def __new__(cls, arr, raw=None):
"""arr: `cuvec.CuVec`, raw `cuvec.Vector_*`, or `numpy.ndarray`"""
if isinstance(arr, CuVec._Vector_types):
if isinstance(arr, CuVec._Vector_types) or raw:
log.debug("wrap raw %s", type(arr))
obj = np.asarray(arr).view(cls)
obj.cuvec = arr
Expand Down Expand Up @@ -53,3 +53,15 @@ def copy(arr):
(`cuvec` equivalent of `numpy.copy`).
"""
return CuVec(cu_copy(arr))


def asarray(cuvec):
"""
Returns `CuVec(cuvec, raw=True)`.
Intended to wrap CPython API functions returning `PyCuVec<T> *` PyObjects.
This is needed since `CuVec(cuvec, False)` won't work if
`isinstance(cuvec, CuVec) == False` due to external libraries
`#include "pycuvec.cuh"` making a distinct type object.
"""
return CuVec(cuvec, raw=True)
1 change: 1 addition & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def test_CuVec_creation(caplog):
assert not caplog.record_tuples
w = cuvec.CuVec(v)
assert [i[1:] for i in caplog.record_tuples] == [(10, "new view")]
assert cuvec.asarray(w.cuvec).cuvec == w.cuvec

caplog.clear()
assert w[0, 0, 0] == 1
Expand Down

0 comments on commit 4d75600

Please sign in to comment.