This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 87425d2 fix boolean_mask for 0-size output (#15731)
87425d2 is described below
commit 87425d2adf22ca4bb7bbd19016ee0a8d9bafdcb2
Author: Hao Jin <[email protected]>
AuthorDate: Thu Aug 1 22:12:00 2019 -0700
fix boolean_mask for 0-size output (#15731)
---
include/mxnet/ndarray.h | 6 +-----
src/imperative/imperative.cc | 4 +++-
src/ndarray/ndarray.cc | 7 +++++++
src/operator/contrib/boolean_mask.cc | 1 +
src/operator/contrib/boolean_mask.cu | 15 +++++++++------
tests/python/unittest/test_operator.py | 15 +++++++++++++++
6 files changed, 36 insertions(+), 12 deletions(-)
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 428245b..176aa0a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -190,11 +190,7 @@ class NDArray {
/*!
* \brief set the correct shape of NDArray directly from the storage_shape
of its own chunk.
*/
- void SetShapeFromChunk() {
- if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
- shape_ = ptr_->storage_shape;
- }
- }
+ void SetShapeFromChunk();
/*
* This indicates whether an array is a view of another array (created by
* reshape or slice). If an array is a view and the data is stored in
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e2c0c9d..c00021c 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -313,7 +313,9 @@ std::vector<NDArray*> Imperative::Backward(
} else {
info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
true, outputs[i]->dtype());
- info.outputs.back() = static_cast<real_t>(1.0);
+ if (info.outputs.back().shape().Size() != 0) {
+ info.outputs.back() = static_cast<real_t>(1.0);
+ }
}
}
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index bee8bef..37c32c0 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -96,6 +96,13 @@ NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Con
dtype, aux_types, aux_shapes);
}
+void NDArray::SetShapeFromChunk() {
+ if (Imperative::Get()->is_np_shape() ||
+ !(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
+ shape_ = ptr_->storage_shape;
+ }
+}
+
struct ChunkMem {
Storage::Handle h;
std::vector<Storage::Handle> aux_h;
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index 4d66e1e..f431d77 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -143,6 +143,7 @@ inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
// set the output shape forcefully
mxnet::TShape s = data.shape();
s[axis] = valid_num;
+
const_cast<NDArray &>(out).Init(s);
// do the copy
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu
index 47335bf..c4a06d2 100644
--- a/src/operator/contrib/boolean_mask.cu
+++ b/src/operator/contrib/boolean_mask.cu
@@ -79,7 +79,6 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
Stream<gpu>::GetStream(s));
CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
cudaMemcpyDeviceToHost));
- CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
// Set the output shape forcefully
mxnet::TShape data_shape = data.shape();
data_shape[axis] = valid_num;
@@ -88,8 +87,10 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
size_t col_size = input_size / idx.shape()[0];
// Do the copy
MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
- mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
- s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+ if (valid_num > 0) {
+ mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+ s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+ }
});
}
@@ -143,9 +144,11 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
size_t col_size = input_size / idx_size;
// Backward pass
MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
- mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
- s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
- prefix_sum, col_size);
+ if (input_size > 0) {
+ mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+ s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
+ prefix_sum, col_size);
+ }
});
}
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 8f1c253..72bf586 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5589,6 +5589,21 @@ def test_boolean_mask():
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
+ # test 0-size output
+ mx.set_np_shape(True)
+ data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
+ index = mx.nd.array([0, 0, 0])
+ data.attach_grad()
+ with mx.autograd.record():
+ out = mx.nd.contrib.boolean_mask(data, index)
+ out.backward()
+ data.grad.wait_to_read()
+ expected = np.zeros((0, 3))
+ expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
+ assert same(out.asnumpy(), expected)
+ assert same(data.grad.asnumpy(), expected_grad)
+ mx.set_np_shape(False)
+
# test gradient
shape = (100, 30)
a = mx.nd.random.randint(0, 100, shape=shape)