TaoLv closed pull request #13555: [MXNET-1253] fix control_flow_op
URL: https://github.com/apache/incubator-mxnet/pull/13555
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h
index 07252963c87..9d0e8cf9081 100644
--- a/src/operator/tensor/control_flow_op.h
+++ b/src/operator/tensor/control_flow_op.h
@@ -46,7 +46,7 @@ struct where {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
- MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
+ MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
const DType* x, const DType* y) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i]? x[i] : y[i]));
}
@@ -64,7 +64,7 @@ struct where_csr {
// CType is condition data type
// i is for i-th row in the output
template<typename DType, typename CType, typename IType>
- MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx,
+ MSHADOW_XINLINE static void Map(index_t i, DType* out, const IType* cond_idx,
const IType* cond_indptr, const CType*
cond_data,
const nnvm::dim_t num_cols, const DType* x) {
using nnvm::dim_t;
@@ -92,8 +92,8 @@ struct where_batch {
// DType is the output data type
// CType is the condition data type
template<typename DType, typename CType>
- MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
- const DType* x, const DType* y, int M) {
+ MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
+ const DType* x, const DType* y, index_t M) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i]));
}
};
@@ -109,7 +109,7 @@ struct where_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
- MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+ MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond) {
KERNEL_ASSIGN(grad_out[i], req,
@@ -130,7 +130,7 @@ struct where_backward_csr {
// CType is condition data type
// IType is condition aux data type
template<typename DType, typename CType, typename IType>
- MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+ MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond_data,
const IType* cond_idx,
@@ -161,9 +161,9 @@ struct where_batch_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
- MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+ MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
- const CType* cond, int M) {
+ const CType* cond, index_t M) {
KERNEL_ASSIGN(grad_out[i], req,
((0 == cond[i/M])^negate)? grad_in[i] : static_cast<DType>(0));
}
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index a301362f2db..696fdb1d417 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -134,6 +134,17 @@ def test_Dense(ctx=mx.cpu(0)):
res.wait_to_read()
assert res.shape == (50000000, 100)
+def test_where():
+ a = nd.ones(shape=(LARGE_X, SMALL_Y))
+ b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
+ b = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
+ res = nd.where(b > 100, a, b)
+ assert np.sum(res[-1].asnumpy() == 1) == b.shape[1]
+
+ csr_cond = nd.sparse.cast_storage(b < 10, 'csr')
+ res = nd.sparse.where(csr_cond, a, b)
+ assert np.sum(res[0].asnumpy() == 1) == b.shape[1]
+
if __name__ == '__main__':
import nose
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services