piiswrong commented on a change in pull request #9029: Update ndarray binary ops to use kernel launch instead of mshadow operations
URL: https://github.com/apache/incubator-mxnet/pull/9029#discussion_r172749967
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -805,15 +808,71 @@ void BinaryOp(const NDArray &lhs,
       CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
     }
     CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape()))
-        << "target shape mismatch";
+      << "target shape mismatch";
   }
+  std::vector<Engine::VarHandle> const_vars;
+  // prepare const variables for engine
+  if (lhs.var() != out->var()) const_vars.push_back(lhs.var());
+  if (rhs.var() != out->var()) const_vars.push_back(rhs.var());
+  return const_vars;
+}
+
+/*!
+* \brief run a binary operation using the kernel launch method
+* \param lhs left operand
+* \param rhs right operand
+* \param out the output ndarray
+* \param binary_op the real
+*/
+template<typename OP>
+void BinaryOpKernel(const NDArray &lhs,
+                    const NDArray &rhs,
+                    NDArray *out) {
+  std::vector<Engine::VarHandle> const_vars = BinaryOpPrepare<OP>(lhs, rhs, out);
   // important: callback must always capture by value
   NDArray ret = *out;
-  // get the const variables
-  std::vector<Engine::VarHandle> const_vars;
-  if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
-  if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());
+  switch (lhs.ctx().dev_mask()) {
+    case cpu::kDevMask: {
+      Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
+        TBlob tmp = ret.data();
+        mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+        ndarray::BinaryOpKernelImpl<OP>(s, lhs.data(), rhs.data(), &tmp);
+      },
+      lhs.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
+      break;
+    }
+#if MXNET_USE_CUDA
+    case gpu::kDevMask: {
+  Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
 
 Review comment:
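   Fix the indentation: this `Engine::Get()->PushSync(...)` call in the `gpu::kDevMask` case should be indented to match the one in the `cpu::kDevMask` case above.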

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to