bgawrych commented on code in PR #21112:
URL: https://github.com/apache/incubator-mxnet/pull/21112#discussion_r935664043


##########
src/operator/numpy/np_repeat_op-inl.h:
##########
@@ -219,15 +225,86 @@ void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs,
       stride *= ishape[i];
     }
 
-    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(inputs[0].type_flag_, IType, {
+      MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(outputs[0].type_flag_, OType, {
         mxnet_op::Kernel<repeat_axis_fwd, xpu>::Launch(
             s, out_data.Size(), out_data.dptr<OType>(), in_data.dptr<IType>(), 
ind, stride);
       });
     });
   }
 }
 
+template <typename xpu>
+void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  int repeats                 = 0;
+  int axis                    = -1;
+  dmlc::optional<int> axisOpt;
+  const RepeatsParam& param = nnvm::get<RepeatsParam>(attrs.parsed);
+  GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis);
+
+  if (!shape_is_known(ishape) || repeats == 0)
+    return;
+
+  mxnet::Tuple<int> repts = param.repeats.value();
+  if (repts.ndim() == 1) {
+    int len = static_cast<bool>(axisOpt) ? ishape[axis] : ishape.Size();
+    std::vector<int> temp(len, repeats);
+    repts = mxnet::Tuple<int>(temp);
+  }
+
+  // If axis was specified then perform swapaxis before and after calling 
repeat function
+  if (axisOpt.has_value() && axisOpt.value() != 0) {
+    int type_size          = mshadow_sizeof(inputs[0].type_flag_);
+    size_t total_temp_size = inputs[0].shape_.Size() * type_size +
+                             outputs[0].shape_.Size() * type_size + 
repts.ndim() * sizeof(int);
+    Tensor<xpu, 1, char> temp_space = ctx.requested[0].get_space_typed<xpu, 1, 
char>(
+        Shape1(total_temp_size), ctx.get_stream<xpu>());
+    void* swap_output_tmp_dptr   = temp_space.dptr_;
+    void* repeat_output_tmp_dptr = temp_space.dptr_ + inputs[0].shape_.Size() 
* type_size;
+    int* repeat_tmp_dptr =
+        reinterpret_cast<int*>(temp_space.dptr_ + inputs[0].shape_.Size() * 
type_size +

Review Comment:
   Reuse the already-computed `repeat_output_tmp_dptr` instead of repeating the same offset arithmetic. Note that `repeat_output_tmp_dptr` is declared `void*`, so it must be cast to `char*` before adding a byte offset — arithmetic on `void*` is a non-standard compiler extension:
   ```suggestion
           reinterpret_cast<int*>(static_cast<char*>(repeat_output_tmp_dptr) +
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to