sxjscience commented on a change in pull request #18545:
URL: https://github.com/apache/incubator-mxnet/pull/18545#discussion_r440405417



##########
File path: src/operator/tensor/index_update.cu
##########
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file index_update.cu
+ * \brief GPU implementation of index_update operator
+ */
+
+#include <cub/cub.cuh>
+#include "./index_update-inl.h"
+#include "../tensor/util/tensor_util-inl.cuh"
+#include "../tensor/util/tensor_util-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+template<typename DType>
+struct IndexUpdateForwardGPUKernel {
+  MSHADOW_XINLINE static void Map(size_t i, DType* out,
+                                  const DType* val,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                                  const int a_tail_size, const int ind_num,
+                                  const int ind_ndim, const int* ind,
+                                  const int a_ndim, const int seg) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += a_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= a_tail_size;
+    for (int _i = 0; _i < a_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_id =
+        mxnet_op::unravel(_i, a_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + a_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : a_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      out[id + _i] = val[val_dest];
+    }
+  }
+};
+
+template<typename xpu, typename DType>
+void IndexUpdateForwardCalc(mshadow::Stream<xpu> *s,
+                            const int ind_num, DType* out,
+                            const DType* val,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                            const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
+                            const int a_tail_size,
+                            const int ind_ndim, const int* ind,
+                            const int a_ndim) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
+  Kernel<IndexUpdateForwardGPUKernel<DType>, xpu>::Launch(
+    s, ind_num, out, val, a_tail_shape, a_pre_stride,
+    val_stride, val_shape, a_tail_size, ind_num,
+    ind_ndim, ind, a_ndim, seg);
+}
+
+
+struct IndexUpdateBackwardValGPUKernel {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(size_t i, DType* grad_val,
+                                  const DType* ograd, const int* ind_vec,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                                  const int ograd_tail_size, const int ind_num,
+                                  const int ind_ndim, const int out_ndim, const int seg) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += ograd_pre_stride[seg + dim] * ind_vec[dim * ind_num + i];
+    }
+    id *= ograd_tail_size;
+    for (int _i = 0; _i < ograd_tail_size; ++_i) {
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_id =
+        mxnet_op::unravel(_i, ograd_tail_shape);
+      mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_id;
+      for (int _j = 0; _j < seg; ++_j) {
+        val_id[_j] = 0;
+      }
+      for (int _j = seg; _j < seg + out_ndim; ++_j) {
+        val_id[_j] = (val_shape[_j] == 1) ? 0 : ograd_tail_id[_j];
+      }
+      val_id[seg + ind_ndim - 1] = (val_shape[seg + ind_ndim - 1] == 1) ? 0 : i;
+      index_t val_dest = mxnet_op::dot(val_id, val_stride);
+      atomicAdd(&grad_val[val_dest], ograd[id + _i]);
+    }
+  }
+};
+
+template<>
+void IndexUpdateOpBackwardValImpl<gpu>(const OpContext& ctx,
+                                const TBlob& grad_val,
+                                const TBlob& ograd,
+                                const TBlob& t_ind,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
+                                const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
+                                const int tail_size, const int ind_num, const int ind_ndim,
+                                const int ndim) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  int seg = MXNET_SPECIAL_MAX_NDIM - ndim;
+  MSHADOW_TYPE_SWITCH(grad_val.type_flag_, DType, {
+    Kernel<IndexUpdateBackwardValGPUKernel, gpu>::Launch(
+        s, ind_num, grad_val.dptr<DType>(), ograd.dptr<DType>(), t_ind.dptr<int>(),
+        ograd_tail_shape, ograd_pre_stride, val_stride, val_shape, tail_size,
+        ind_num, ind_ndim, ndim, seg);
+  });
+}
+
+template<typename DType>
+struct IndexUpdateBackwardAGPUKernel {
+  MSHADOW_XINLINE static void Map(size_t i, DType* out_grad,
+                                  const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+                                  const int tail_size, const int ind_num, const int ind_ndim,
+                                  const int32_t* ind, const int seg,
+                                  const int req) {
+    index_t id = 0;
+    for (int dim = 0; dim < ind_ndim; ++dim) {
+      id += grada_pre_stride[seg + dim] * ind[dim * ind_num + i];
+    }
+    id *= tail_size;
+    for (int _i = 0; _i < tail_size; ++_i) {
+      out_grad[id + _i] = 0;
+    }
+  }
+};
+
+template<>
+void IndexUpdateOpBackwardAImpl<gpu>(const OpContext& ctx,
+                                     const TBlob& grad_a,
+                                     const TBlob& ograd,
+                                     const TBlob& ind,
+                                     const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> grada_pre_stride,
+                                     const int tail_size, const int ind_num, const int ind_ndim,
+                                     const int seg, const int req) {
+  using namespace mxnet_op;
+  using namespace mshadow;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  MSHADOW_TYPE_SWITCH(grad_a.type_flag_, DType, {
+    size_t temp_mem_size = ind.shape_.Size() * sizeof(int) +
+                           ograd.shape_.Size() * sizeof(DType);

Review comment:
       Here, we mix the `int` type and `DType`. The second pointer may not be appropriately aligned for `DType`. You may consider revising the code to pad the first segment with `PadBytes`: https://github.com/apache/incubator-mxnet/blob/1b02225fefd8ccc93bc73223f0d3cde103fad661/src/operator/tensor/ordering_op-inl.h#L350-L351
   
   The reason for doing this is that CUDA enforces memory alignment, i.e., the following assertion must always hold:
   
   ```c++
   DType* ptr;
   assert(reinterpret_cast<size_t>(ptr) % sizeof(DType) == 0);
   ```
   
   By default, calling cudaMalloc gives you a pointer whose address is aligned to at least 256 bytes. However, after shifting the pointer with `ptr += ind.shape_.Size() * sizeof(int)`, the new address may no longer be aligned to `sizeof(DType)`.
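   
   For illustration, here is a minimal sketch of what the padded allocation could look like, assuming `PadBytes(num_bytes, alignment)` rounds `num_bytes` up to a multiple of `alignment` as in the linked helper (the `ind_ptr`/`ograd_ptr` names are hypothetical):
   
   ```c++
   // Pad the int segment so that the DType segment placed after it
   // starts at an address that is a multiple of sizeof(DType).
   size_t ind_bytes = PadBytes(ind.shape_.Size() * sizeof(int), sizeof(DType));
   size_t temp_mem_size = ind_bytes + ograd.shape_.Size() * sizeof(DType);
   Tensor<gpu, 1, char> temp_mem =
       ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_mem_size), s);
   int* ind_ptr = reinterpret_cast<int*>(temp_mem.dptr_);
   DType* ograd_ptr = reinterpret_cast<DType*>(temp_mem.dptr_ + ind_bytes);
   ```
   
   With the padding in place, the alignment assertion above holds for `ograd_ptr` for any `DType`.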




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]

