apeforest commented on a change in pull request #16218: Improving performance 
of argmax operator
URL: https://github.com/apache/incubator-mxnet/pull/16218#discussion_r328381052
 
 

 ##########
 File path: src/operator/tensor/broadcast_reduce_op.h
 ##########
 @@ -556,6 +556,227 @@ inline bool ReduceAxesOpForwardStorage(const 
nnvm::NodeAttrs& attrs,
   return dispatched;
 }
 
+
+using namespace mshadow_op::isnan_typed;
+
// Element type of the scratch buffer that holds the per-worker argmax
// candidate indices (see the storageType switches in argmax / argmax_reduce).
enum idxStorageType {
  int64_st,      // buffer holds index_t entries
  uint32_st,     // buffer holds uint32_t entries
  uint16_st,     // buffer holds uint16_t entries
  undefined_st   // no buffer type chosen (yet)
};
+
+struct argmax {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, const int nWorkers,
+                                  const DType *in_data, DType *out_data,
+                                  size_t nSteps, size_t step, size_t shift,
+                                  void *pIdxStorage, const idxStorageType 
storageType) {
+    // i - index of launched thread
+    // nWorkers - number of threads, assigned to work on one row/column
+    // iw - index of current thread among workers assigned to the same vector
+    int iw = 0;
+    const DType *pCurr = in_data;
+    if (nWorkers > 1) {
+      // in - the vector number which current thread is assigned to
+      const auto in = i / nWorkers;
+      iw = i % nWorkers;
+      pCurr += in % step + shift * (in / step) + iw * step;
+      nSteps = (nSteps + nWorkers - 1 - iw) / nWorkers;
+      step *= nWorkers;
+    } else {
+      pCurr += i % step + shift * (i / step);
+    }
+
+    size_t maxIdx = 0;
+    DType maxVal = *pCurr;
+    while (IsNan(maxVal) && ++maxIdx < nSteps)
+      maxVal = *(pCurr += step);
+
+    if (maxIdx < nSteps) {
+      for (size_t j = maxIdx + 1; j < nSteps; ++j) {
+        const auto val = *(pCurr += step);
+        if (IsNan(val) || maxVal >= val)
+          continue;
+
+        maxVal = val;
+        maxIdx = j;
+      }
+    } else {
+      maxIdx = 0;
+    }
+
+    if (nWorkers > 1) {
+      // saving index of best element found by current thread
+      maxIdx = maxIdx * nWorkers + iw;
+      switch (storageType) {
+        case int64_st:
+          *(static_cast<index_t *>(pIdxStorage) + i) = maxIdx;
+          break;
+        case uint32_st:
+          *(static_cast<uint32_t *>(pIdxStorage) + i) = maxIdx;
+          break;
+        default:  // uint16_st
+          *(static_cast<uint16_t *>(pIdxStorage) + i) = maxIdx;
+      }
+    } else {
+      out_data[i] = maxIdx;    // output of argmax
+    }
+  }
+};
+
+struct argmax_reduce {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static size_t BestIdx(const int nWorkers, const DType *pCurr,
+                                        const size_t step, const IType 
*pIdxStorage) {
+    size_t maxIdx = *pIdxStorage;
+    DType maxVal = *(pCurr + step * maxIdx);
+    int j = 0;
+    while (IsNan(maxVal) && ++j < nWorkers)
+      maxVal = *(pCurr + step * (maxIdx = pIdxStorage[j]));
+
+    if (j == nWorkers)
+      return *pIdxStorage;
+
+    while (++j < nWorkers) {
+      const auto val = *(pCurr + step * pIdxStorage[j]);
+      if (IsNan(val) || maxVal >= val)
+        continue;
+
+      maxVal = val;
+      maxIdx = pIdxStorage[j];
+    }
+
+    return maxIdx;
+  }
+
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, const int nWorkers,
+                                  const DType *in_data, DType *out_data,
+                                  const size_t step, const size_t shift,
+                                  void *pIdx, const idxStorageType 
storageType) {
+    const DType *pCurr = in_data + i % step + shift * (i / step);
+    switch (storageType) {
+      case int64_st:
+        out_data[i] = BestIdx(nWorkers, pCurr, step, static_cast<index_t 
*>(pIdx) + i * nWorkers);
+        break;
+      case uint32_st:
+        out_data[i] = BestIdx(nWorkers, pCurr, step, static_cast<uint32_t 
*>(pIdx) + i * nWorkers);
+        break;
+      case uint16_st:
+        out_data[i] = BestIdx(nWorkers, pCurr, step, static_cast<uint16_t 
*>(pIdx) + i * nWorkers);
+        break;
+      default: break;  // That should never have happened
+    }
+  }
+};
+
+template<typename xpu, typename DType>
+DType *AllocateDTypeMemory(const OpContext& ctx, const size_t num_items) {
+  const size_t memSize = num_items * sizeof(DType);
+  mshadow::Tensor<xpu, 1, uint8_t> workspace =
+    ctx.requested[0].get_space_typed<xpu, 1, uint8_t>(
+      mshadow::Shape1(memSize), ctx.get_stream<xpu>());
+  return reinterpret_cast<DType *>(workspace.dptr_);
+}
+
+template<typename xpu, typename DType>
+void ArgMax(const OpContext& ctx, const TShape &shape, int axis, size_t step, 
size_t shift,
+            const int nWorkers, void *pIdxStorage, const idxStorageType 
storageType,
+            const TBlob& input, const TBlob& output) {
+  using namespace mxnet_op;
+  const auto pIn = input.dptr<DType>();
+  auto *pOut = output.dptr<DType>();
+  const auto nSize = shape[axis];
+  const auto num_threads = shape.Size() / nSize;
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  Kernel<argmax, xpu>::Launch(s, num_threads * nWorkers, nWorkers, pIn, pOut, 
nSize,
+                              step, shift, pIdxStorage, storageType);
+  if (nWorkers > 1) {
+    Kernel<argmax_reduce, xpu>::Launch(s, num_threads, nWorkers, pIn, pOut,
+                                       step, shift, pIdxStorage, storageType);
+  }
+}
+
+template<typename xpu>
+int DefineNumbWorkers(const TShape &shape, int axis);
+
+template<typename xpu>
+void ArgMax(const nnvm::NodeAttrs& attrs,
+            const OpContext& ctx,
+            const std::vector<TBlob>& inputs,
+            const std::vector<OpReqType>& req,
+            const std::vector<TBlob>& outputs) {
+  const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
+  auto shape = inputs[0].shape_;
+  size_t shift = shape.Size();     // to reading data starting of each thread
+  CHECK_NE(shift, 0U) << "attempt to search an empty sequence";
+
+  size_t step = 1;  // step over the input array which will be used by each 
thread in kernel
+  int axis;
+  if (param.axis.has_value()) {
+    axis = CheckAxis(param.axis.value(), shape.ndim());
+    // cannot do argmax in an empty dimension
+    CHECK_NE(shape[axis], 0)
+      << "searching input tensor of shape " << shape
+      << " along axis = " << axis << " of zero dim-size is not allowed";
+
+    if (shape.ndim() == 1)
+      shape = AxisShapeCompact(shape, &axis, true);
+
+    // Calculate step & shift
+    auto i = shape.ndim();
+    while (--i > axis)
+      step *= shape[i];
+
+    shift = i? step*shape[i] : 1;
+  } else {
+    // If global reduction, reshape the input tensor into 2D shape (1, 
inputs[0].shape_.Size())
+    // and search on axis = 1.
+    shape = mxnet::TShape(2, 1);
+    shape[1] = shift;
+    axis = 1;
+  }
+
+  void *pIdxMemory = nullptr;
+  auto nWorkers = DefineNumbWorkers<xpu>(shape, axis);
+  idxStorageType storageType = undefined_st;
+  while (nWorkers > 1) {
+    const size_t num_items = nWorkers * shape.Size() / shape[axis];
+#if MXNET_USE_INT64_TENSOR_SIZE
+    pIdxMemory = AllocateDTypeMemory<xpu, index_t>(ctx, num_items);
+    if (pIdxMemory) {
+      storageType = int64_st;
+      break;
+    }
+#endif
+    if (shape[axis] <= UINT32_MAX) {
+      pIdxMemory = AllocateDTypeMemory<xpu, uint32_t>(ctx, num_items);
+      if (pIdxMemory) {
+        storageType = uint32_st;
+        break;
+      }
+    }
+    // Check if indexes can be stored in uint16 format
+    if (shape[axis] <= UINT16_MAX) {
+      pIdxMemory = AllocateDTypeMemory<xpu, uint16_t>(ctx, num_items);
 
 Review comment:
   Would this cause a problem? I thought the index type inside a tensor object is
   either int32 or int64. Would this cause an overflow if uint16_t is passed in?
   Also, based on the Google C++ style guide, we should avoid using unsigned
   integer types as the index storage type:
   https://google.github.io/styleguide/cppguide.html#Integer_Types

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to