This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6660be1  Fix infer storage type (#10507)
6660be1 is described below

commit 6660be1f6592fbfb250fdbb53567f8938e788dbe
Author: Anirudh Subramanian <anirudh2...@gmail.com>
AuthorDate: Wed Apr 11 14:57:43 2018 -0700

    Fix infer storage type (#10507)
    
    * Fix infer_storage_type
    
    * Add test
    
    * Fix lint
    
    * Trigger CI
---
 python/mxnet/operator.py               |  7 ++++---
 tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 60895bd..1da6628 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -552,9 +552,10 @@ class CustomOpProp(object):
             "Default infer_storage_type implementation doesnt allow non default stypes: " \
             "found non default stype '%s' for in_stype[%d]. Please implement " \
             "infer_storage_type and infer_storage_type_backward interface " \
-            "in your custom operator if you have non-default input stypes" % (stype, i)
-        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
-            [in_stype[0]]*len(self.list_auxiliary_states())
+            "in your custom operator if you have non-default input/output stypes" % (stype, i)
+        return in_stype, \
+               [_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_outputs()), \
+               [_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_auxiliary_states())
 
     def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
         """infer_storage_type_backward interface. Used to infer storage
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ca1c42c..67eefab 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4176,6 +4176,42 @@ def test_custom_op():
     assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy(), rtol=rtol, atol=atol)
     assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy(), rtol=rtol, atol=atol)
 
+    class NoInputOp(mx.operator.CustomOp):
+        def __init__(self, length, depth):
+            super(NoInputOp, self).__init__()
+            self.output = np.ones(shape=(length, depth), dtype=np.float32)
+
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], self.output)
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            pass
+
+    @mx.operator.register("no_input_op")
+    class NoInputOpProp(mx.operator.CustomOpProp):
+        def __init__(self, length, depth):
+            super(NoInputOpProp, self).__init__()
+            self.length = int(length)
+            self.depth = int(depth)
+
+        def list_arguments(self):
+            return []
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return [], [(self.length, self.depth)], []
+
+        def infer_type(self, in_type):
+            return [], [np.float32], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return NoInputOp(length=self.length, depth=self.depth)
+
+    with mx.autograd.record():
+        x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op")
+    assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32))
 
 @with_seed()
 def test_psroipooling():

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to