samskalicky commented on a change in pull request #17270: Dynamic custom operator GPU support
URL: https://github.com/apache/incubator-mxnet/pull/17270#discussion_r369388560
##########
File path: src/c_api/c_api.cc
##########
@@ -563,101 +623,30 @@ int MXLoadLib(const char *path) {
}
// create a pointer to hold custom op state object
+ // only create one stateful op depending on passing context
+ // user can add new supported context and call to custom library
void* state_op_inst = nullptr;
- CHECK(callCreateOpState(create_opstate_fp, attr_keys.data(),
attr_vals.data(),
- attr_keys.size(), &state_op_inst))
- << "Error calling CreateOpState for custom operator '" << name_str <<
"'";
-
+ if (ctx.dev_mask() == Context::kCPU) {
+ CHECK(createop_map.count("cpu") > 0)
+ << "CPU CreateOpState not implemented for '" << name_str << "'";
+ CHECK(callCreateOpState(createop_map.at("cpu"), attr_keys.data(),
attr_vals.data(),
+ attr_keys.size(), &state_op_inst))
+ << "Error calling CreateOpState CPU for custom operator '" << name_str
<< "'";
+ } else if (ctx.dev_mask() == Context::kGPU) {
+ CHECK(createop_map.count("gpu") > 0)
+ << "GPU CreateOpState not implemented for '" << name_str << "'";
+ CHECK(callCreateOpState(createop_map.at("gpu"), attr_keys.data(),
attr_vals.data(),
+ attr_keys.size(), &state_op_inst))
+ << "Error calling CreateOpState GPU for custom operator '" << name_str
<< "'";
+ }
CHECK(state_op_inst != nullptr)
<< "Error custom library failed to create stateful operator '" <<
name_str << "'";
CustomStatefulOp* state_op =
reinterpret_cast<CustomStatefulOp*>(state_op_inst);
return OpStatePtr::Create<CustomStatefulOpWrapper>(state_op);
};
- // stateful forward and backward
- auto fstateful_lambda = [=](bool is_forward,
- const OpStatePtr& state_ptr,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- std::vector<void*> in_data, out_data;
- std::vector<const int64_t *> in_shapes, out_shapes;
- std::vector<int> in_dims, out_dims;
- std::vector<int> in_types, out_types;
- std::vector<size_t> in_verIDs, out_verIDs;
-
- // convert input tensors to constituent parts
- for (size_t i = 0; i < inputs.size(); i++) {
- in_data.push_back(inputs[i].data().dptr_);
- in_shapes.push_back(inputs[i].shape().data());
- in_dims.push_back(inputs[i].shape().ndim());
- in_types.push_back(inputs[i].dtype());
- in_verIDs.push_back(inputs[i].version());
- }
-
- // convert output tensors to constituent parts
- for (size_t i = 0; i < outputs.size(); i++) {
- out_data.push_back(outputs[i].data().dptr_);
- out_shapes.push_back(outputs[i].shape().data());
- out_dims.push_back(outputs[i].shape().ndim());
- out_types.push_back(outputs[i].dtype());
- out_verIDs.push_back(outputs[i].version());
- }
-
- // get memory resource
- const Resource &resource = ctx.requested[0];
- mshadow::Stream<mxnet::cpu> *cpu_stream = ctx.get_stream<mxnet::cpu>();
-
- // create lambda that captures stream & resource objects
- // this temp workspace holds memory allocated by custom library via
OpResource
- auto cpu_alloc = [&](int size) {
- mshadow::Tensor<mxnet::cpu, 1, char> data =
- resource.get_space_typed<mxnet::cpu, 1, char>(mshadow::Shape1(size),
cpu_stream);
- return data.dptr_;
- };
-
- // create lambda without captures so that we can cast it to function
pointer
- // this needs to be a lambda function so that we can do the decltype cast
- typedef decltype(cpu_alloc) alloc_type;
- auto cpu_malloc = [](void* _cpu_alloc, int size) {
- // cast the void* argument to the type for the cpu_alloc lambda
function
- alloc_type* cpualloc = static_cast<alloc_type*>(_cpu_alloc);
- // call cpu_alloc to actually allocate memory and get the pointer
- void* ptr = (*cpualloc)(size);
- return ptr;
- };
-
- // retrieve op state object created from CreateOpState
- CustomStatefulOpWrapper& op =
state_ptr.get_state<CustomStatefulOpWrapper>();
- CustomStatefulOp* state_op_inst = op.get_instance();
- CHECK(state_op_inst != nullptr)
- << "Error MXNet cannot load custom stateful operator'" << name_str <<
"'";
-
- // call fcompute function
- CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(),
in_dims.data(),
- in_data.data(), in_types.data(),
in_verIDs.data(), in_data.size(),
- out_shapes.data(), out_dims.data(),
out_data.data(), out_types.data(),
- out_verIDs.data(), out_data.size(), cpu_malloc,
&cpu_alloc))
- << "Error calling FStatefulCompute for custom operator '" << name_str <<
"'";
- };
-
- auto fstateful_forward = [=](const OpStatePtr& state_ptr,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- fstateful_lambda(true, state_ptr, ctx, inputs, req, outputs);
- };
-
- auto fstateful_backward = [=](const OpStatePtr& state_ptr,
- const OpContext& ctx,
- const std::vector<NDArray>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<NDArray>& outputs) {
- fstateful_lambda(false, state_ptr, ctx, inputs, req, outputs);
- };
+ /* -------------- BELOW ARE CUSTOM OPERATOR REGISTRATION --------------- */
Review comment:
I think you mean "BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS"
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services