anirudh2290 commented on a change in pull request #10374: [MXNET-93] Sparse
support for Custom Op
URL: https://github.com/apache/incubator-mxnet/pull/10374#discussion_r180322202
##
File path: python/mxnet/operator.py
##
@@ -522,6 +526,69 @@ def infer_type(self, in_type):
         return in_type, [in_type[0]]*len(self.list_outputs()), \
             [in_type[0]]*len(self.list_auxiliary_states())
+    def infer_storage_type(self, in_stype):
+        """infer_storage_type interface. Used to infer storage type of
+        inputs and outputs in the forward pass.
+
+        Parameters
+        ----------
+        in_stype : list of stypes, Valid stypes are default, row_sparse and
+            csr
+
+        Returns
+        -------
+        in_stype : list
+            list of argument stypes.
+        out_stype : list
+            list of output types calculated from in_stype,
+            in the same order as declared in list_outputs.
+        aux_type : Optional, list
+            list of aux types calculated from in_stype,
+            in the same order as declared in list_auxiliary_states.
+        """
+        return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+            [in_stype[0]]*len(self.list_auxiliary_states())
+
+    def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype, igrad_stype, aux_stype):
+        """infer_storage_type_backward interface. Used to infer storage
+        type of inputs and outputs in the backward pass.
+
+        Will raise an error if undefined storage type is returned.
+        Returned lists have to be the same size as the input lists to infer_storage_type_backward,
+        otherwise an exception will be thrown. When this interface is not implemented,
+        all stypes will fallback to default.
Review comment:
The default implementation supports only default stypes; that is why I
replicated the stype from in_stype. I have now added asserts to
infer_storage_type and infer_storage_type_backward to prevent misuse.
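
A minimal sketch of what such asserts could look like, not the code from the PR itself. It assumes the dense storage type is named by the string "default", and SketchProp and _assert_all_default are hypothetical stand-ins used only for illustration (the real methods live on the operator property class in python/mxnet/operator.py):

```python
DEFAULT_STYPE = "default"  # assumed string name of the dense storage type


class SketchProp:
    """Hypothetical stand-in for an operator property class; not MXNet code."""

    def list_outputs(self):
        return ["output"]

    def list_auxiliary_states(self):
        return []

    @staticmethod
    def _assert_all_default(name, stypes):
        # Reject any non-default stype so the default path cannot be misused.
        for i, stype in enumerate(stypes):
            assert stype == DEFAULT_STYPE, (
                "default implementation only supports '%s' stypes, "
                "found '%s' for %s[%d]" % (DEFAULT_STYPE, stype, name, i))

    def infer_storage_type(self, in_stype):
        self._assert_all_default("in_stype", in_stype)
        return (in_stype,
                [DEFAULT_STYPE] * len(self.list_outputs()),
                [DEFAULT_STYPE] * len(self.list_auxiliary_states()))

    def infer_storage_type_backward(self, ograd_stype, in_stype, out_stype,
                                    igrad_stype, aux_stype):
        groups = [("ograd_stype", ograd_stype), ("in_stype", in_stype),
                  ("out_stype", out_stype), ("igrad_stype", igrad_stype),
                  ("aux_stype", aux_stype)]
        for name, stypes in groups:
            self._assert_all_default(name, stypes)
        # Each returned list has the same size as the matching input list.
        return ograd_stype, in_stype, out_stype, igrad_stype, aux_stype
```

With this shape, an operator that receives csr or row_sparse inputs fails early with a message naming the offending entry, rather than silently propagating a stype the default path does not handle.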