samskalicky commented on a change in pull request #17762: Custom Operator Random Number Generator Support
URL: https://github.com/apache/incubator-mxnet/pull/17762#discussion_r399043504
##########
File path: example/extensions/lib_custom_op/relu_lib.cu
##########
@@ -180,6 +182,75 @@ REGISTER_OP(my_state_relu)
.setCreateOpState(createOpStateCPU, "cpu")
.setCreateOpState(createOpStateGPU, "gpu");
+
+
+/* ------------------------ Below is noisy relu operator example ---------------------*/
+
+#include <curand_kernel.h>
+#include <random>
+
+#define NumRandomPerThread 64 // mxnet recommended random numbers generated per thread
+
+__global__ void noisy_relu_gpu_forward(float *out, float *in, int64_t N, void *states, int step) {
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ curandStatePhilox4_32_10_t* global_states = (curandStatePhilox4_32_10_t*)states;
Review comment:
Can you add a comment explaining why you're casting `states` here? Where do the states get allocated? And how can you guarantee that the `states` pointer points to an array large enough to index with `tid`?
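The sizing question can be made concrete with a little arithmetic. Below is a minimal host-side C++ sketch, assuming (the diff does not confirm this) that each thread owns exactly one Philox state and processes `NumRandomPerThread` consecutive elements. Under that assumption the states array needs at least ceil(N / NumRandomPerThread) entries, and the launch must not start more threads than there are states, because every launched thread computes a `tid` and indexes `global_states[tid]`. All helper names here are hypothetical and are not part of MXNet's extension API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the diff's NumRandomPerThread macro.
constexpr int64_t kNumRandomPerThread = 64;

// Assumption: one generator state per thread, and each thread covers
// kNumRandomPerThread consecutive input elements.
// Number of threads (and therefore states) needed for n elements, rounded up.
int64_t threads_needed(int64_t n) {
  return (n + kNumRandomPerThread - 1) / kNumRandomPerThread;
}

// True if every thread id in a <<<grid, block>>> launch maps to a valid
// entry of a states array holding num_states elements.
bool states_cover_launch(int64_t num_states, int64_t grid, int64_t block) {
  return grid * block <= num_states;
}

// Host-side model of how many elements thread `tid` would process;
// returns 0 for surplus threads that start past the end of the input.
int64_t elements_for_thread(int64_t tid, int64_t n) {
  const int64_t start = tid * kNumRandomPerThread;
  if (start >= n) return 0;
  const int64_t end = start + kNumRandomPerThread;
  return (end < n ? end : n) - start;
}

// Picks a grid size for n elements and checks it against the state count.
int64_t plan_grid(int64_t n, int64_t block, int64_t num_states) {
  const int64_t grid = (threads_needed(n) + block - 1) / block;
  assert(states_cover_launch(num_states, grid, block) &&
         "states array too small for this launch");
  return grid;
}
```

With N = 1000 and 64 numbers per thread, 16 states suffice when the block size is 16; a block size of 32 would launch 32 threads and trip the assertion even though only 16 of them have work. Inside the kernel, the usual counterpart is an early return when `tid` is at or past the number of allocated states.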