samskalicky commented on a change in pull request #17270: Dynamic custom operator GPU support
URL: https://github.com/apache/incubator-mxnet/pull/17270#discussion_r369386749
 
 

 ##########
 File path: src/c_api/c_api.cc
 ##########
 @@ -99,7 +99,135 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e,
 // NOTE: return value is added in API_END
 
 /*!
- * \brief Loads dynamic library and initializes it
+ * \brief Common compute function dispatcher for forward/backward and stateful 
forward/backward
+ * state_ptr will be nullptr for regular ops; fcomp_fp is nullptr for stateful 
ops
+ */
+void CustomFComputeDispatcher(const std::string op_name,
+                              const opCallFComp_t callFComp,
+                              const fcomp_t fcomp_fp,
+                              const nnvm::NodeAttrs* attrs,
+                              const opCallFStatefulComp_t callFStatefulComp,
+                              int stateful_forward_flag,
+                              const OpStatePtr* state_ptr,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  std::vector<void*> in_data, out_data;
+  std::vector<const int64_t *> in_shapes, out_shapes;
+  std::vector<int> in_dims, out_dims;
+  std::vector<int> in_types, out_types;
+  std::vector<size_t> in_verIDs, out_verIDs;
+  std::vector<const char*> in_dev_type, out_dev_type;
+  std::vector<int> in_dev_id, out_dev_id;
+
+  // convert inputs/outpus NDArray to C types to be passed to lib_api.h
+  for (size_t i = 0; i < inputs.size(); i++) {
+    in_data.push_back(inputs[i].data().dptr_);
+    in_shapes.push_back(inputs[i].shape().data());
+    in_dims.push_back(inputs[i].shape().ndim());
+    in_types.push_back(inputs[i].dtype());
+    in_verIDs.push_back(inputs[i].version());
+    const char* ctx_str = inputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" 
: "gpu";
+    in_dev_type.push_back(ctx_str);
+    in_dev_id.push_back(inputs[i].ctx().real_dev_id());
+  }
+
+  for (size_t i = 0; i < outputs.size(); i++) {
+    out_data.push_back(outputs[i].data().dptr_);
+    out_shapes.push_back(outputs[i].shape().data());
+    out_dims.push_back(outputs[i].shape().ndim());
+    out_types.push_back(outputs[i].dtype());
+    out_verIDs.push_back(outputs[i].version());
+    const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" 
: "gpu";
+    out_dev_type.push_back(ctx_str);
+    out_dev_id.push_back(outputs[i].ctx().real_dev_id());
+  }
+
+  // get memory resource and mxnet backend streams
+  const Resource &resource = ctx.requested[0];
+  mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
+  mshadow::Stream<mxnet::gpu> *gpu_stream = ctx.get_stream<mxnet::gpu>();
+
+  // create lambda that captures stream & resource objects
+  // this temp workspace holds memory allocated by custom library via 
OpResource
+  auto cpu_alloc = [&](int size) {
+    mshadow::Tensor<mxnet::cpu, 1, char> workspace =
+      resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size), 
cpu_stream);
+    return workspace.dptr_;
+  };
+  auto gpu_alloc = [&](int size) {
+    mshadow::Tensor<mxnet::gpu, 1, char> workspace =
+      resource.get_space_typed<mxnet::gpu, 1, char>(mshadow::Shape1(size), 
gpu_stream);
+    return workspace.dptr_;
+  };
+
+  // create lambda without captures so that we can cast it to function pointer
+  // this needs to be a lambda function so that we can do the decltype cast
+  typedef decltype(cpu_alloc) alloc_type_cpu;
+  auto cpu_malloc = [](void* _cpu_alloc, int size) {
+    // cast the void* argument to the type for the cpu_alloc lambda function
+    alloc_type_cpu* cpualloc = static_cast<alloc_type_cpu*>(_cpu_alloc);
+    // call cpu_alloc to actually allocate memory and get the pointer
+    void* ptr = (*cpualloc)(size);
 
 Review comment:
   Is there a reason we're doing this in two lines instead of just `return (*cpualloc)(size)`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to