yzhliu commented on a change in pull request #15938: Tvm broadcast backward
URL: https://github.com/apache/incubator-mxnet/pull/15938#discussion_r315585709
##########
File path: src/operator/contrib/tvmop/ufunc.cc
##########
@@ -37,29 +38,88 @@ namespace op {
static constexpr char func_vadd_cpu[] = "vadd";
static constexpr char func_vadd_gpu[] = "cuda_vadd";
+static constexpr char func_bakcward_vadd_cpu[] = "backward_vadd";
+static constexpr char func_bakcward_vadd_gpu[] = "cuda_backward_vadd";
template<const char* func>
-void TVMBroadcastCompute(const nnvm::NodeAttrs& attrs,
- const mxnet::OpContext& ctx,
- const std::vector<TBlob>& inputs,
- const std::vector<OpReqType>& req,
- const std::vector<TBlob>& outputs) {
+void TVMBinaryCompute(const nnvm::NodeAttrs& attrs,
+ const mxnet::OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
tvm::runtime::TVMOpModule::Get()->Call(func, ctx, {inputs[0], inputs[1],
outputs[0]});
}
+template<const char* func>
+void TVMBinaryBackwardComputeUseNone(const nnvm::NodeAttrs& attrs,
+ const mxnet::OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ CHECK_EQ(inputs.size(), 1U);
+ CHECK_EQ(outputs.size(), 2U);
+ int ndim = inputs[0].shape_.ndim();
+ for (int k = 0; k < 2; ++k) {
+ // dispatch by backward
+ std::vector<int> ov, iv;
+ const TBlob& ograd = inputs[0], igrad = outputs[k];
+ bool flag = ograd.size(0) != igrad.size(0);
+ for (int i = 0; i < ndim; ++i) {
+ if (i == 0 || (ograd.size(i) != igrad.size(i)) != (ograd.size(i - 1) !=
igrad.size(i - 1))) {
+ ov.push_back(ograd.size(i));
+ } else {
+ ov.back() *= ograd.size(i);
+ }
+ }
+ for (int i = flag; i < ov.size(); i += 2) {
+ iv.push_back(ov[i]);
+ }
+ TShape oshape(ov.begin(), ov.end()), ishape(iv.begin(), iv.end());
+ TBlob ograd_tvm(ograd.reshape(oshape).dltensor());
+ TBlob igrad_tvm(igrad.reshape(ishape).dltensor());
Review comment:
Please add comments explaining the idea behind this logic (how the shapes are collapsed and reshaped before calling the TVM backward kernel).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services