Copy slice along arbitrary axis (apache#3259)
* rnn-cell demo (push to server for testing)

* a running example with cuDNN RNN cell

* add copy slice along arbitrary axis for NDArray

* copy_slice_to as an ndarray operator

* Python interface to the _copy_slice_to operator

* fix lint error
pluskid authored Sep 9, 2016
1 parent 873b928 commit d67964c
Showing 6 changed files with 146 additions and 3 deletions.
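In effect, the commit adds an operator that copies arr[:, ..., start:stop, ...] along a chosen axis into a separate buffer. A minimal usage sketch of the new (private) Python entry point, assuming an MXNet build that includes this commit:

import mxnet as mx
import numpy as np

arr = mx.nd.array(np.arange(24).reshape((2, 4, 3)))
out = mx.nd.zeros((2, 2, 3))       # pre-allocated target: axis 1 sliced to length 2
arr._copy_slice_to(1, 1, 3, out)   # copy arr[:, 1:3, :] into out
assert np.array_equal(out.asnumpy(), arr.asnumpy()[:, 1:3, :])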
3 changes: 3 additions & 0 deletions .gitignore
@@ -121,3 +121,6 @@ dmlc-core
ps-lite
nnvm
lib

# Visual Studio Code
.vscode
6 changes: 3 additions & 3 deletions include/mxnet/c_api.h
@@ -310,7 +310,7 @@ MXNET_DLL int MXNDArrayWaitAll();
MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
/*!
* \brief Slice the NDArray along axis 0.
- * \param handle the handle to the narraya
+ * \param handle the handle to the NDArray
* \param slice_begin The beginning index of slice
* \param slice_end The ending index of slice
* \param out The NDArrayHandle of sliced NDArray
@@ -322,9 +322,9 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle,
NDArrayHandle *out);
/*!
* \brief Index the NDArray along axis 0.
- * \param handle the handle to the narraya
+ * \param handle the handle to the NDArray
* \param idx the index
- * \param out The NDArrayHandle of sliced NDArray
+ * \param out The NDArrayHandle of output NDArray
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayAt(NDArrayHandle handle,
25 changes: 25 additions & 0 deletions include/mxnet/ndarray.h
@@ -72,6 +72,19 @@ class NDArray {
});
return res;
}
/*!
* \brief Get a flat 1-D chunk of this NDArray's raw data as a TBlob.
* \param offset the element offset from the start of this NDArray's data
* \param length the number of elements in the chunk
* \return a chunk of raw data in TBlob
*/
inline TBlob raw_data(index_t offset, index_t length) const {
TBlob res;
TShape raw_shape(1);
raw_shape[0] = length;
MSHADOW_TYPE_SWITCH(dtype_, DType, {
res = TBlob(static_cast<DType*>(ptr_->shandle.dptr)
+ offset_ + offset, raw_shape, ptr_->shandle.ctx.dev_mask());
});
return res;
}
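// For intuition: on a 2x3 NDArray `a` laid out row-major, a.raw_data(2, 3)
// returns a 1-D TBlob of length 3 viewing the flat elements a(0,2), a(1,0),
// a(1,1) of the underlying buffer (a view, not a copy).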
/*!
* \return the context of NDArray, this function is only valid when the NDArray is not empty
*/
@@ -368,6 +381,18 @@ class NDArray {
*/
void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0);

/*!
* \brief Copy a slice along an arbitrary axis.
* \param from the NDArray we want to slice from
* \param slice_dim the axis along which to slice
* \param start the beginning index of the slice (inclusive)
* \param end the ending index of the slice (exclusive)
* \param to the pre-allocated NDArray to copy the slice into
* \param priority the priority of the task
*/
void CopySliceTo(const NDArray &from, int slice_dim, index_t start, index_t end,
NDArray *to, int priority = 0);

/*!
* \brief Perform elementwise sum over each data from source, store result into out.
* \param source the ndarray we want to sum
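For reference, CopySliceTo has the semantics of a numpy basic slice along one axis copied into a pre-allocated buffer. A rough numpy equivalent of the contract (copy_slice_to here is a hypothetical illustration, not part of either API):

import numpy as np

def copy_slice_to(src, slice_dim, start, end, dst):
    # dst[...] = src restricted to [start, end) along axis slice_dim
    index = [slice(None)] * src.ndim
    index[slice_dim] = slice(start, end)
    dst[...] = src[tuple(index)]

src = np.random.uniform(size=(3, 4, 2))
dst = np.empty((3, 2, 2))
copy_slice_to(src, 1, 1, 3, dst)
assert np.array_equal(dst, src[:, 1:3, :])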
27 changes: 27 additions & 0 deletions python/mxnet/ndarray.py
@@ -276,6 +276,33 @@ def _slice(self, start, stop):
self.handle, start, stop, ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

def _copy_slice_to(self, axis, start, stop, target):
"""Copy a slice along an axis.
Parameters
----------
axis : int
The axis along which to do slicing.
start : int
The starting index of the slice.
stop : int
The finishing index of the slice.
target : NDArray or Context
If an NDArray, must be pre-allocated with compatible shape.
If a Context, a new NDArray will be created.
Returns
-------
The sliced copy of the NDArray.
"""
if isinstance(target, Context):
shape = list(self.shape)
shape[axis] = stop - start
target = NDArray(_new_alloc_handle(shape, target, True, self.dtype))

assert isinstance(target, NDArray)
return _internal._copy_slice_to(self, axis, start, stop, out=target)
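# Example (a sketch, not part of this commit): both accepted target forms,
# for a source NDArray `a` of shape (3, 4, 2):
#   pre = mx.nd.zeros((3, 2, 2))
#   a._copy_slice_to(1, 1, 3, pre)             # fill a pre-allocated NDArray
#   new = a._copy_slice_to(1, 1, 3, mx.cpu())  # allocate the output on CPU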

def _at(self, idx):
"""Return a sub NDArray that shares memory with current one.
73 changes: 73 additions & 0 deletions src/ndarray/ndarray.cc
@@ -223,6 +223,66 @@ void ScalarOp(const NDArray &lhs,
}
}

void CopySliceTo(const NDArray &from, int slice_dim, index_t start, index_t end,
NDArray *to, int priority) {
CHECK(from.shape().ndim() == to->shape().ndim())
<< "from and to must have the same number of dimensions";
CHECK(slice_dim < from.shape().ndim())
<< "slice dimension out of bounds";
CHECK(start < end)
<< "slice is empty";
CHECK(end <= from.shape()[slice_dim])
<< "slice out of bounds";

mshadow::Shape<3> from_shape = from.shape().FlatTo3D(slice_dim);
mshadow::Shape<3> to_shape = to->shape().FlatTo3D(slice_dim);
CHECK(from_shape[0] == to_shape[0] && from_shape[2] == to_shape[2])
<< "shape incompatible";
CHECK(end - start == to_shape[1])
<< "shape incompatible";

int a = from.ctx().dev_mask();
int b = to->ctx().dev_mask();

std::vector<Engine::VarHandle> const_vars{from.var()};
NDArray ret = *to;

#define MXNET_COPYSLICETO_IMPL(xpu1, xpu2) \
Engine::Get()->PushSync([from, ret, from_shape, start, end](RunContext ctx) { \
ret.CheckAndAlloc(); \
for (index_t i = 0; i < from_shape[0]; ++i) { \
index_t src_idx = i * (from_shape[1] * from_shape[2]) + \
start * from_shape[2]; \
index_t length = from_shape[2] * (end - start); \
index_t dst_idx = i * length; \
\
TBlob blob_from = from.raw_data(src_idx, length); \
TBlob blob_to = ret.raw_data(dst_idx, length); \
ndarray::Copy<xpu1, xpu2>(blob_from, &blob_to, \
from.ctx(), ret.ctx(), ctx); \
} \
}, from.ctx(), const_vars, {ret.var()}, \
FnProperty::kNormal, priority)

if (a == cpu::kDevMask && b == cpu::kDevMask) {
MXNET_COPYSLICETO_IMPL(cpu, cpu);
} else {
#if MXNET_USE_CUDA
if (a == cpu::kDevMask && b == gpu::kDevMask) {
MXNET_COPYSLICETO_IMPL(cpu, gpu);
} else if (a == gpu::kDevMask && b == cpu::kDevMask) {
MXNET_COPYSLICETO_IMPL(gpu, cpu);
} else if (a == gpu::kDevMask && b == gpu::kDevMask) {
MXNET_COPYSLICETO_IMPL(gpu, gpu);
} else {
LOG(FATAL) << "unknown device mask";
}
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
}
}

void CopyFromTo(const NDArray &from, NDArray *to, int priority) {
CHECK(from.shape() == to->shape())
<< "operands shape mismatch";
@@ -743,6 +803,19 @@ MXNET_REGISTER_NDARRAY_FUN(_copyto)
.set_function(CopyFromToSimple)
.set_type_mask(kNDArrayArgBeforeScalar);

MXNET_REGISTER_NDARRAY_FUN(_copy_slice_to)
.set_body([](NDArray **u, real_t *s, NDArray **out,
int num_params, char **param_keys, char **param_vals) {
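// Scalars arrive in Python argument order: s[0] = axis, s[1] = start,
// s[2] = stop; the single use-var is the source and the single mutate-var
// is the output.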
CopySliceTo(*u[0],
static_cast<index_t>(s[0]),
static_cast<index_t>(s[1]),
static_cast<index_t>(s[2]), out[0]);
})
.set_num_use_vars(1)
.set_num_scalars(3)
.set_num_mutate_vars(1)
.set_type_mask(kNDArrayArgBeforeScalar);

// register random number generators
MXNET_REGISTER_NDARRAY_FUN(_random_uniform)
.set_body([](NDArray **u, real_t *s, NDArray **out,
15 changes: 15 additions & 0 deletions tests/python/unittest/test_ndarray.py
@@ -186,6 +186,20 @@ def test_ndarray_slice():
A[3:8] = A2[3:8]
assert same(A[3:8].asnumpy(), A2[3:8])


def test_ndarray_slice_along_axis():
arr = mx.nd.array(np.random.uniform(-10, 10, (3, 4, 2, 3)))
sub_arr = mx.nd.zeros((3, 2, 2, 3))
arr._copy_slice_to(1, 1, 3, sub_arr)

# test that we sliced correctly
assert same(arr.asnumpy()[:, 1:3, :, :], sub_arr.asnumpy())

# test that the slice is a copy, not shared memory
sub_arr[:] = 0
assert not same(arr.asnumpy()[:, 1:3, :, :], sub_arr.asnumpy())


def test_clip():
shape = (10,)
A = mx.random.uniform(-10, 10, shape)
@@ -261,6 +275,7 @@ def test_broadcast_to():
test_broadcast_to()

if __name__ == '__main__':
test_ndarray_slice_along_axis()
test_ndarray_slice()
test_ndarray_pickle()
test_ndarray_saveload()
