piiswrong closed pull request #10507: Fix infer storage type. URL: https://github.com/apache/incubator-mxnet/pull/10507
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index 60895bdca7b..1da6628e68d 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -552,9 +552,10 @@ def infer_storage_type(self, in_stype): "Default infer_storage_type implementation doesnt allow non default stypes: " \ "found non default stype '%s' for in_stype[%d]. Please implement " \ "infer_storage_type and infer_storage_type_backward interface " \ - "in your custom operator if you have non-default input stypes" % (stype, i) - return in_stype, [in_stype[0]]*len(self.list_outputs()), \ - [in_stype[0]]*len(self.list_auxiliary_states()) + "in your custom operator if you have non-default input/output stypes" % (stype, i) + return in_stype, \ + [_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_outputs()), \ + [_STORAGE_TYPE_ID_TO_STR[_STORAGE_TYPE_DEFAULT]]*len(self.list_auxiliary_states()) def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype): """infer_storage_type_backward interface. 
Used to infer storage diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index ca1c42c2386..67eefabd33d 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4176,6 +4176,42 @@ def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_st assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy(), rtol=rtol, atol=atol) assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy(), rtol=rtol, atol=atol) + class NoInputOp(mx.operator.CustomOp): + def __init__(self, length, depth): + super(NoInputOp, self).__init__() + self.output = np.ones(shape=(length, depth), dtype=np.float32) + + def forward(self, is_train, req, in_data, out_data, aux): + self.assign(out_data[0], req[0], self.output) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + pass + + @mx.operator.register("no_input_op") + class NoInputOpProp(mx.operator.CustomOpProp): + def __init__(self, length, depth): + super(NoInputOpProp, self).__init__() + self.length = int(length) + self.depth = int(depth) + + def list_arguments(self): + return [] + + def list_outputs(self): + return ['output'] + + def infer_shape(self, in_shape): + return [], [(self.length, self.depth)], [] + + def infer_type(self, in_type): + return [], [np.float32], [] + + def create_operator(self, ctx, shapes, dtypes): + return NoInputOp(length=self.length, depth=self.depth) + + with mx.autograd.record(): + x = mx.nd.Custom(length=10, depth=10, op_type="no_input_op") + assert_almost_equal(x.asnumpy(), np.ones(shape=(10, 10), dtype=np.float32)) @with_seed() def test_psroipooling(): ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services