piiswrong closed pull request #10507: Fix infer storage type
URL: https://github.com/apache/incubator-mxnet/pull/10507
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (opened from a fork), the diff is supplied
below, since GitHub would not otherwise display it:

diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index 60895bdca7b..1da6628e68d 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -552,9 +552,10 @@ def infer_storage_type(self, in_stype):
             "Default infer_storage_type implementation doesnt allow non 
default stypes: " \
             "found non default stype '%s' for in_stype[%d]. Please implement " 
\
             "infer_storage_type and infer_storage_type_backward interface " \
-            "in your custom operator if you have non-default input stypes" % 
(stype, i)
-        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
-            [in_stype[0]]*len(self.list_auxiliary_states())
+            "in your custom operator if you have non-default input/output 
stypes" % (stype, i)
+        return in_stype, \
+               
[_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_outputs()), \
+               
[_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_auxiliary_states())
 
     def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, 
igrad_stype, aux_stype):
         """infer_storage_type_backward interface. Used to infer storage
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index ca1c42c2386..67eefabd33d 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4176,6 +4176,42 @@ def infer_storage_type_backward(self, ograd_stype, 
in_stype, out_stype, igrad_st
     assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy(), rtol=rtol, 
atol=atol)
     assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy(), rtol=rtol, 
atol=atol)
 
+    class NoInputOp(mx.operator.CustomOp):
+        def __init__(self, length, depth):
+            super(NoInputOp, self).__init__()
+            self.output = np.ones(shape=(length, depth), dtype=np.float32)
+
+        def forward(self, is_train, req, in_data, out_data, aux):
+            self.assign(out_data[0], req[0], self.output)
+
+        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+            pass
+
+    @mx.operator.register("no_input_op")
+    class NoInputOpProp(mx.operator.CustomOpProp):
+        def __init__(self, length, depth):
+            super(NoInputOpProp, self).__init__()
+            self.length = int(length)
+            self.depth = int(depth)
+
+        def list_arguments(self):
+            return []
+
+        def list_outputs(self):
+            return ['output']
+
+        def infer_shape(self, in_shape):
+            return [], [(self.length, self.depth)], []
+
+        def infer_type(self, in_type):
+            return [], [np.float32], []
+
+        def create_operator(self, ctx, shapes, dtypes):
+            return NoInputOp(length=self.length, depth=self.depth)
+
+    with mx.autograd.record():
+        x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op")
+    assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32))
 
 @with_seed()
 def test_psroipooling():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to