This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new bb7e0cb Fix binary broadcast shape large tensor (#19070)
bb7e0cb is described below
commit bb7e0cb29da846bf2110693d76073cec5a76862c
Author: Zhaoqi Zhu <[email protected]>
AuthorDate: Thu Sep 3 01:04:26 2020 -0700
Fix binary broadcast shape large tensor (#19070)
* fix binary broadcast shape
* tweak
* Revert "tweak"
This reverts commit f40c844a27390dbba0d716f3dc36f451fbcc528d.
---
src/operator/tensor/elemwise_binary_broadcast_op.h | 2 +-
tests/nightly/test_np_large_array.py | 6 ++----
2 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h
b/src/operator/tensor/elemwise_binary_broadcast_op.h
index e3ba92d..a5bfdd7 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -58,7 +58,7 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
const int bl = out.ndim() - lhs.ndim();
const int br = out.ndim() - rhs.ndim();
for (int i = 0; i < out.ndim(); ++i) {
- int l = 1, r = 1;
+ dim_t l = 1, r = 1;
if (i >= bl) l = lhs[i-bl];
if (i >= br) r = rhs[i-br];
if (!mxnet::dim_size_is_known(l) || !mxnet::dim_size_is_known(r)) continue;
diff --git a/tests/nightly/test_np_large_array.py
b/tests/nightly/test_np_large_array.py
index 606b4b5..e0ec0df 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -142,11 +142,8 @@ def test_add():
assert A.grad.shape == (INT_OVERFLOW, 2)
assert A.grad[0][0] == 1
-# this will fail; broadcast needs to be fixed
-# TODO add backward test after forward is fixed
@use_np
-@pytest.mark.skip(reason='Does not support large tensor; to be fixed')
-def test_add_broadcast():
+def test_binary_broadcast():
A = np.ones((INT_OVERFLOW, 2))
B = np.ones((INT_OVERFLOW, 1))
C = np.add(A, B)
@@ -571,6 +568,7 @@ def test_slice_assign():
B[-1] = 2
assert B[-1, 0] == 2 and B[-1, 1] == 2
+
'''
_ _
_ _ _ _ _ __ _ __ _ _ _____ _| |_ ___ _ _ __(_)___ _ _