eric-haibin-lin commented on a change in pull request #10374: Sparse support 
for Custom Op
URL: https://github.com/apache/incubator-mxnet/pull/10374#discussion_r179902293
 
 

 ##########
 File path: src/operator/custom/custom.cc
 ##########
 @@ -266,97 +292,243 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context 
ctx,
   return OpStatePtr::Create<CustomParam>(state);
 }
 
-void Forward(const OpStatePtr& state,
-             const OpContext& ctx,
-             const std::vector<TBlob>& inputs,
-             const std::vector<OpReqType>& req,
-             const std::vector<TBlob>& outputs) {
+void ForwardEx(const OpStatePtr& state, const OpContext& ctx,
+               const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
   std::vector<void*> ptrs;
+  // Tags are provided to the callback to provide the frontend
   std::vector<int> tags;
   std::vector<NDArray> cpys;
 
+  // info on what ndarray is at each position in the input and output vector
+  // 0 - Input
+  // 1 - Output
+  // 4 - aux
+  std::unordered_set<int> input_tags({0, 4});
+  std::unordered_set<int> output_tags({1});
+
   auto dev_id = ctx.run_ctx.ctx.dev_id;
 
   for (size_t i = 0; i < params.num_args; ++i) {
-    NDArray *nd = new NDArray(inputs[i], dev_id);
+    NDArray* nd;
+    AllocateNDArrayCopy(&nd, inputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(0);
   }
 
   for (size_t i = 0; i < params.num_outs; ++i) {
-    NDArray *nd = new NDArray(outputs[i], dev_id);
+    NDArray* nd;
+    AllocateNDArrayCopy(&nd, outputs, i, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(1);
   }
 
   for (size_t i = 0; i < params.num_auxs; ++i) {
-    NDArray *nd = new NDArray(inputs[i+params.num_args], dev_id);
+    size_t idx = i + params.num_args;
+    NDArray* nd;
+    AllocateNDArrayCopy(&nd, inputs, idx, dev_id);
     cpys.push_back(*nd);
     ptrs.push_back(reinterpret_cast<void*>(nd));
     tags.push_back(4);
   }
 
   CustomOperator::Get()->Push(
-    [=]() {
-      
CHECK(reinterpret_cast<CustomOpFBFunc>(params.info->callbacks[kCustomOpForward])(
-        ptrs.size(), const_cast<void**>(ptrs.data()), 
const_cast<int*>(tags.data()),
-        reinterpret_cast<const int*>(req.data()), 
static_cast<int>(ctx.is_train),
-        params.info->contexts[kCustomOpForward]));
-    }, ctx, false, ctx.is_train, cpys);
+      [=]() {
+        CHECK(reinterpret_cast<CustomOpFBFunc>(
+            params.info->callbacks[kCustomOpForward])(
+            ptrs.size(), const_cast<void**>(ptrs.data()),
+            const_cast<int*>(tags.data()),
+            reinterpret_cast<const int*>(req.data()),
+            static_cast<int>(ctx.is_train),
+            params.info->contexts[kCustomOpForward]));
+      },
+      ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
 }
 
-
-void Backward(const OpStatePtr& state,
-              const OpContext& ctx,
-              const std::vector<TBlob>& inputs,
-              const std::vector<OpReqType>& req,
-              const std::vector<TBlob>& outputs) {
+void BackwardEx(const OpStatePtr& state, const OpContext& ctx,
+                const std::vector<NDArray>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<NDArray>& outputs) {
   const CustomParam& params = state.get_state<CustomParam>();
 
-  size_t total = 2*params.num_args + 2*params.num_outs + params.num_auxs;
-  std::vector<void*> ptrs(params.num_args + 2*params.num_outs, nullptr);
+  size_t total = 2 * params.num_args + 2 * params.num_outs + params.num_auxs;
+  std::vector<void*> ptrs(params.num_args + 2 * params.num_outs, nullptr);
+
   std::vector<int> tags;
   std::vector<NDArray> cpys;
 
   ptrs.reserve(total);
   tags.reserve(total);
+  cpys.reserve(total);
+
+  std::unordered_set<int> input_tags({3, 0, 1, 4});
 
Review comment:
   Could you add a comment explaining what these tag values (3, 0, 1, 4) represent, similar to the one in `ForwardEx`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to