xidulu commented on a change in pull request #18852:
URL: https://github.com/apache/incubator-mxnet/pull/18852#discussion_r468292891



##########
File path: src/operator/numpy/random/np_gamma_op.h
##########
@@ -394,6 +401,83 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const 
OpContext &ctx,
   }
 }
 
+template<typename xpu, int ndim, typename DType>
+inline void GammaReparamBackwardImpl(const OpContext& ctx,
+                                            const std::vector<TBlob>& inputs,
+                                            const std::vector<OpReqType>& req,
+                                            const std::vector<TBlob>& outputs,
+                                            const mxnet::TShape& new_ishape,
+                                            const mxnet::TShape& new_oshape,
+                                            const float scale) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace broadcast;
+  using namespace mxnet_op;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob igrad = outputs[0].reshape(new_ishape);
+  // inputs: [grad_from_samples, alpha_tensor, samples]
+  const TBlob ograd = inputs[0].reshape(new_oshape);
+  const TBlob alpha = inputs[1].reshape(new_ishape);
+  TBlob samples = inputs[2].reshape(new_oshape);
+  size_t workspace_size =
+      ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
+  // Convert samples to standard gamma
+  // Kernel<StandarizeKernel<DType>, xpu>::Launch(
+  //       s, samples.Size(), samples.dptr<DType>(), scale);
+  Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
+    s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), 
DType(scale));
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), 
s);
+  Reduce<red::sum, ndim, DType, op::mshadow_op::mul, 
op::mshadow_op::gamma_implicit_grad>(

Review comment:
       This represents the multiplication of d(Gamma(x; \alpha, \beta)) by the 
gradient from the downstream ops. It is similar to: 
https://github.com/apache/incubator-mxnet/blob/743bbcbc7c8c85661a146d94ebd3196306650677/src/operator/tensor/elemwise_binary_broadcast_op.h#L745




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to