eric-haibin-lin commented on a change in pull request #10374: Sparse support 
for Custom Op
URL: https://github.com/apache/incubator-mxnet/pull/10374#discussion_r178686915
 
 

 ##########
 File path: src/operator/custom/custom.cc
 ##########
 @@ -266,97 +267,237 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context 
ctx,
   return OpStatePtr::Create<CustomParam>(state);
 }
 
-void Forward(const OpStatePtr& state,
-             const OpContext& ctx,
-             const std::vector<TBlob>& inputs,
+void Forward(const OpStatePtr& state, const OpContext& ctx,
+             const std::vector<NDArray>& inputs,
              const std::vector<OpReqType>& req,
-             const std::vector<TBlob>& outputs) {
+             const std::vector<NDArray>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
   std::vector<void*> ptrs;
   std::vector<int> tags;
   std::vector<NDArray> cpys;
+  std::unordered_set<int> input_tags({0, 4});
+  std::unordered_set<int> output_tags({1});
 
   auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (size_t i = 0; i < params.num_args; ++i) {
-    NDArray *nd = new NDArray(inputs[i], dev_id);
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, inputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(0);
   }
 
   for (size_t i = 0; i < params.num_outs; ++i) {
-    NDArray *nd = new NDArray(outputs[i], dev_id);
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, outputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(1);
   }
 
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray *nd = new NDArray(inputs[i+params.num_args], dev_id);
+    size_t idx = i + params.num_args;
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, inputs, idx, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
   CustomOperator::Get()->Push(
-    [=]() {
-      
CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
-        ptrs.size(), const_cast<void**>(ptrs.data()), 
const_cast<int*>(tags.data()),
-        reinterpret_cast<const int*>(req.data()), 
static_cast<int>(ctx.is_train),
-        params.info->contexts[kCustomOpForward]));
-    }, ctx, false, ctx.is_train, cpys);
+      [=]() {
+        CHECK(reinterpret_cast<CustomOpFBFunc>(
+            params.info->callbacks[kCustomOpForward])(
+            ptrs.size(), const_cast<void**>(ptrs.data()),
+            const_cast<int*>(tags.data()),
+            reinterpret_cast<const int*>(req.data()),
+            static_cast<int>(ctx.is_train),
+            params.info->contexts[kCustomOpForward]));
+      },
+      ctx, false, ctx.is_train, cpys, tags, input_tags, output_tags, inputs, 
outputs);
 }
 
-
-void Backward(const OpStatePtr& state,
-              const OpContext& ctx,
-              const std::vector<TBlob>& inputs,
+void Backward(const OpStatePtr& state, const OpContext& ctx,
+              const std::vector<NDArray>& inputs,
               const std::vector<OpReqType>& req,
-              const std::vector<TBlob>& outputs) {
+              const std::vector<NDArray>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
 
-  size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs;
-  std::vector<void*> ptrs(params.num_args + 2*params.num_outs, nullptr);
+  size_t total = 2 * params.num_args + 2 * params.num_outs + params.num_auxs;
+  std::vector<void*> ptrs(params.num_args + 2 * params.num_outs, nullptr);
+
   std::vector<int> tags;
   std::vector<NDArray> cpys;
 
   ptrs.reserve(total);
   tags.reserve(total);
+  cpys.reserve(total);
+
+  std::unordered_set<int> input_tags({3, 0, 1, 4});
+  std::unordered_set<int> output_tags({2});
+
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(3);
   for (size_t i = 0; i < params.num_args; ++i) tags.push_back(0);
   for (size_t i = 0; i < params.num_outs; ++i) tags.push_back(1);
 
   auto dev_id = ctx.run_ctx.ctx.dev_id;
 
+
   for (size_t i = 0; i < params.bwd_idx.size(); ++i) {
-    NDArray *nd = new NDArray(inputs[i], dev_id);
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, inputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs[params.bwd_idx[i]] = reinterpret_cast<void*>(nd);
   }
   for (size_t i = 0; i < ptrs.size(); ++i) {
-    if (ptrs[i] == nullptr) ptrs[i] = reinterpret_cast<void*>(new NDArray());
+    NDArray* nd;
+    if (ptrs[i] == nullptr) {
+        nd = new NDArray();
+        ptrs[i] = reinterpret_cast<void*>(nd);
+    }
   }
-  for (const auto& i : outputs) {
-    NDArray* nd = new NDArray(i, dev_id);
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, outputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(2);
   }
+
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray* nd = new NDArray(inputs[inputs.size()-params.num_auxs+i], dev_id);
+    size_t idx = inputs.size() - params.num_auxs + i;
+    NDArray* nd;
+    allocate_ndarray_copy(&nd, inputs, idx, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
-
   CustomOperator::Get()->Push(
     [=]() {
       
CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpBackward])(
         ptrs.size(), const_cast<void**>(ptrs.data()), 
const_cast<int*>(tags.data()),
         reinterpret_cast<const int*>(req.data()), 
static_cast<int>(ctx.is_train),
         params.info->contexts[kCustomOpBackward]));
-    }, ctx, false, ctx.is_train, cpys);
+    }, ctx, false, ctx.is_train, cpys, tags, input_tags, output_tags, inputs, 
outputs);
+}
+
+// infer storage backward function for custom op which assigns kDefaultStorage 
for
+// all undefined stypes and dispatches on DispatchMode::kFComputeEx.
+inline bool BackwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int>* iattr,
+                                     std::vector<int>* oattr) {
+  const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
+
+  if (params.info->num_callbacks <= kCustomOpPropBackwardInferStorageType) {
+    for (size_t i = 0; i < iattr->size(); i++) {
+      STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
 
 Review comment:
   What if one of the inputs/outputs is sparse? Would the check fail? Shouldn't 
it assign an stype only to the undefined ones?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to