This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 31144c7 Fix (#17674)
31144c7 is described below
commit 31144c763bfd0fe199b7fe0f23a20555c9731e7a
Author: reminisce <[email protected]>
AuthorDate: Mon Feb 24 19:58:25 2020 -0800
Fix (#17674)
---
src/nnvm/plan_memory.cc | 25 +++++++++++++------------
tests/python/unittest/test_numpy_gluon.py | 21 +++++++++++++++++++++
2 files changed, 34 insertions(+), 12 deletions(-)
diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 6c6e02d..3815f23 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -38,21 +38,22 @@ namespace {
// Return bytes of data flag.
static int MXGetDTypeSize(int type_flag) {
switch (type_flag) {
- case kUint8:
- case kInt8:
+ case mshadow::kUint8:
+ case mshadow::kInt8:
+ case mshadow::kBool:
return 1;
- case kFloat16:
- case kBfloat16:
- case kInt16:
- case kUint16:
+ case mshadow::kFloat16:
+ case mshadow::kBfloat16:
+ case mshadow::kInt16:
+ case mshadow::kUint16:
return 2;
- case kFloat32:
- case kInt32:
- case kUint32:
+ case mshadow::kFloat32:
+ case mshadow::kInt32:
+ case mshadow::kUint32:
return 4;
- case kFloat64:
- case kInt64:
- case kUint64:
+ case mshadow::kFloat64:
+ case mshadow::kInt64:
+ case mshadow::kUint64:
return 8;
default:
LOG(FATAL) << "unknown type_flag=" << type_flag;
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index 6ce9e18..0d1e5fe 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -400,6 +400,27 @@ def test_net_symbol_save_load():
mx.np.random.normal(0, 1, (10, 5, 8))])
+@with_seed()
+@use_np
+def test_hybridize_boolean_dtype():
+ class Foo(gluon.HybridBlock):
+ def __init__(self, prefix=None, params=None):
+ super(Foo, self).__init__(prefix=prefix, params=params)
+
+ def hybrid_forward(self, F, valid_length):
+ mask = ((F.np.ones((10,)) / 2) < valid_length)
+ return mask
+
+ valid_length = mx.np.random.uniform(size=(10,))
+ foo = Foo()
+ out1 = foo(valid_length)
+
+ foo = Foo()
+ foo.hybridize()
+ out2 = foo(valid_length)
+
+ assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())
+
if __name__ == '__main__':
import nose