samskalicky commented on a change in pull request #17885:
URL: https://github.com/apache/incubator-mxnet/pull/17885#discussion_r411800538
##########
File path: include/mxnet/lib_api.h
##########
@@ -507,16 +553,13 @@ class OpResource {
void *rand_cpu_states, *rand_gpu_states;
};
-/*!
- * \brief Json utility to parse serialized subgraph symbol
- */
/*! \brief Macro to help passing serialized subgraph through attribute dict */
#define MX_STR_SUBGRAPH_SYM_JSON "subgraph_sym_json"
-#define MX_STR_DTYPE "__dtype__"
-#define MX_STR_SHAPE "__shape__"
+#define MX_STR_DTYPE "__ext_dtype__"
+#define MX_STR_SHAPE "__ext_shape__"
/* \brief get shape value from list of shapes string
- * format: [[1]] or [[1],[2]]
+ * format: [[1]] or [[1],[2,3]], returns "[1]" or "[2,3]"
Review comment:
done
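For readers following the thread: the format comment above implies a small string parser, picking the index-th "[...]" group out of a serialized list like "[[1],[2,3]]". A minimal sketch of that behavior, assuming a plain bracket scan (the helper name getShapeAt matches the lib_api.h comment, but this body is illustrative, not the actual implementation):

#include <cassert>
#include <string>

// Illustrative only: return the index-th "[...]" group from a serialized
// list of shapes, e.g. index 1 of "[[1],[2,3]]" yields "[2,3]".
std::string getShapeAt(const std::string& shapes, unsigned index) {
  size_t pos = 1;  // skip the outer '['
  for (unsigned i = 0; i <= index; ++i) {
    size_t start = shapes.find('[', pos);
    size_t end = shapes.find(']', start);
    if (i == index)
      return shapes.substr(start, end - start + 1);
    pos = end + 1;
  }
  return "";
}

int main() {
  assert(getShapeAt("[[1]]", 0) == "[1]");
  assert(getShapeAt("[[1],[2,3]]", 1) == "[2,3]");
  return 0;
}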
##########
File path: example/extensions/lib_custom_op/relu_lib.cu
##########
@@ -29,93 +29,93 @@
#define NumThreadPerBlock 256 // mxnet recommended cuda thread number per block
__global__ void relu_gpu_forward(float *out, float *in, int64_t N) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- if (tid < N)
- out[tid] = in[tid] > 0 ? in[tid] : 0;
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ if (tid < N)
+ out[tid] = in[tid] > 0 ? in[tid] : 0;
}
__global__ void relu_gpu_backward(float *ingrad, float *outgrad, float *indata, int64_t N) {
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
- if (tid < N)
- ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0;
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ if (tid < N)
+ ingrad[tid] = indata[tid] > 0 ? 1 * outgrad[tid] : 0;
}
-MXReturnValue forwardCPU(std::map<std::string, std::string> attrs,
- std::vector<MXTensor> inputs,
- std::vector<MXTensor> outputs,
- OpResource res) {
- float* in_data = inputs[0].data<float>();
- float* out_data = outputs[0].data<float>();
- for (int i=0; i<inputs[0].size(); i++) {
- out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
- }
- return MX_SUCCESS;
+MXReturnValue forwardCPU(const std::unordered_map<std::string, std::string>& attrs,
+ std::vector<MXTensor>* inputs,
+ std::vector<MXTensor>* outputs,
+ const OpResource& res) {
+ float* in_data = inputs->at(0).data<float>();
+ float* out_data = outputs->at(0).data<float>();
+ for (int i=0; i<inputs->at(0).size(); i++) {
+ out_data[i] = in_data[i] > 0 ? in_data[i] : 0;
+ }
+ return MX_SUCCESS;
}
-MXReturnValue backwardCPU(std::map<std::string, std::string> attrs,
- std::vector<MXTensor> inputs,
- std::vector<MXTensor> outputs,
- OpResource res) {
- float* out_grad = inputs[0].data<float>();
- float* in_data = inputs[1].data<float>();
- float* in_grad = outputs[0].data<float>();
- for (int i=0; i<inputs[1].size(); i++) {
- in_grad[i] = in_data[i] > 0 ? 1 * out_grad[i] : 0;
- }
- return MX_SUCCESS;
+MXReturnValue backwardCPU(const std::unordered_map<std::string, std::string>& attrs,
+ std::vector<MXTensor>* inputs,
+ std::vector<MXTensor>* outputs,
+ const OpResource& res) {
+ float* out_grad = inputs->at(0).data<float>();
+ float* in_data = inputs->at(1).data<float>();
+ float* in_grad = outputs->at(0).data<float>();
+ for (int i=0; i<inputs->at(1).size(); i++) {
+ in_grad[i] = in_data[i] > 0 ? 1 * out_grad[i] : 0;
+ }
+ return MX_SUCCESS;
}
-MXReturnValue forwardGPU(std::map<std::string, std::string> attrs,
- std::vector<MXTensor> inputs,
- std::vector<MXTensor> outputs,
- OpResource res) {
- float* in_data = inputs[0].data<float>();
- float* out_data = outputs[0].data<float>();
+MXReturnValue forwardGPU(const std::unordered_map<std::string, std::string>& attrs,
+ std::vector<MXTensor>* inputs,
+ std::vector<MXTensor>* outputs,
+ const OpResource& res) {
+ float* in_data = inputs->at(0).data<float>();
+ float* out_data = outputs->at(0).data<float>();
- mx_stream_t cuda_stream = res.get_cuda_stream();
- int64_t N = inputs[0].size();
- int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+ mx_stream_t cuda_stream = res.get_cuda_stream();
+ int64_t N = inputs->at(0).size();
+ int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
- relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);
+ relu_gpu_forward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(out_data, in_data, N);
- return MX_SUCCESS;
+ return MX_SUCCESS;
}
-MXReturnValue backwardGPU(std::map<std::string, std::string> attrs,
- std::vector<MXTensor> inputs,
- std::vector<MXTensor> outputs,
- OpResource res) {
- float* out_grad = inputs[0].data<float>();
- float* in_data = inputs[1].data<float>();
- float* in_grad = outputs[0].data<float>();
-
- mx_stream_t cuda_stream = res.get_cuda_stream();
- int64_t N = inputs[0].size();
- int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+MXReturnValue backwardGPU(const std::unordered_map<std::string, std::string>& attrs,
+ std::vector<MXTensor>* inputs,
+ std::vector<MXTensor>* outputs,
+ const OpResource& res) {
+ float* out_grad = inputs->at(0).data<float>();
+ float* in_data = inputs->at(1).data<float>();
+ float* in_grad = outputs->at(0).data<float>();
- relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
+ mx_stream_t cuda_stream = res.get_cuda_stream();
+ int64_t N = inputs->at(0).size();
+ int num_block = (N + NumThreadPerBlock - 1) / NumThreadPerBlock;
+ relu_gpu_backward<<<num_block,NumThreadPerBlock,0,cuda_stream>>>(in_grad, out_grad, in_data, N);
- return MX_SUCCESS;
+ return MX_SUCCESS;
}
-MXReturnValue parseAttrs(std::map<std::string, std::string> attrs, int* num_in, int* num_out) {
- *num_in = 1;
- *num_out = 1;
- return MX_SUCCESS;
+MXReturnValue parseAttrs(const std::unordered_map<std::string, std::string>& attrs,
+ int* num_in, int* num_out) {
+ *num_in = 1;
+ *num_out = 1;
+ return MX_SUCCESS;
}
-MXReturnValue inferType(std::map<std::string, std::string> attrs,
- std::vector<int> &intypes,
- std::vector<int> &outtypes) {
- outtypes[0] = intypes[0];
- return MX_SUCCESS;
+MXReturnValue inferType(const std::unordered_map<std::string, std::string>& attrs,
+ const std::vector<int>& intypes,
Review comment:
done
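As a usage note, the rewritten signatures slot into the same registration flow the custom-op examples use. A minimal sketch, assuming the REGISTER_OP macro with per-context setForward/setBackward from lib_api.h; the op name and the inferShape hook (defined outside the hunk shown) are taken from the surrounding example:

// Registers CPU and GPU implementations under one op name,
// keyed by context string.
REGISTER_OP(my_relu)
.setParseAttrs(parseAttrs)
.setInferType(inferType)
.setInferShape(inferShape)
.setForward(forwardCPU, "cpu")
.setForward(forwardGPU, "gpu")
.setBackward(backwardCPU, "cpu")
.setBackward(backwardGPU, "gpu");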
##########
File path: example/extensions/lib_custom_op/gemm_lib.cc
##########
@@ -136,13 +137,13 @@ MXReturnValue inferType(std::map<std::string, std::string> attrs,
}
}
- outtypes[0] = intypes[0];
+ outtypes->at(0) = intypes[0];
return MX_SUCCESS;
}
-MXReturnValue inferShape(std::map<std::string, std::string> attrs,
- std::vector<std::vector<unsigned int>> &inshapes,
- std::vector<std::vector<unsigned int>> &outshapes) {
+MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
+                         const std::vector<std::vector<unsigned int>>& inshapes,
Review comment:
done
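For reference, the hunk above cuts off before the output parameter; following the same convention as outtypes (results written through a pointer), a complete new-style inferShape for 2-D GEMM would look roughly like this sketch (simplified; the real gemm_lib.cc also validates ranks and the shared k dimension):

MXReturnValue inferShape(const std::unordered_map<std::string, std::string>& attrs,
                         const std::vector<std::vector<unsigned int>>& inshapes,
                         std::vector<std::vector<unsigned int>>* outshapes) {
  // (m, k) x (k, n) -> (m, n)
  outshapes->at(0) = {inshapes[0][0], inshapes[1][1]};
  return MX_SUCCESS;
}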
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]