sxjscience commented on a change in pull request #18319: URL: https://github.com/apache/incubator-mxnet/pull/18319#discussion_r440633446
########## File path: src/operator/numpy/np_indexing_op.cu ########## @@ -0,0 +1,574 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file np_indexing_op.cu +*/ + +#include "./np_indexing_op.h" +#include <cub/cub.cuh> + +namespace mxnet { +namespace op { + +/*! \brief If there are out-of-bound indices, out will be assigned to 1. 
+ */ +struct is_valid_check { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, char* out, const DType* data, + const DType min, const DType max) { + if (data[i] < min || data[i] > max) *out = 1; + } +}; + +template<typename DType> +bool CheckIndexOutOfBound(mshadow::Stream<gpu> *s, const DType* data_ptr, size_t data_size, + const DType min, const DType max, char* is_valid_ptr) { + using namespace mxnet_op; + int32_t is_valid = 0; + Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr); + Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max); + CUDA_CALL(cudaMemcpyAsync(&is_valid, is_valid_ptr, sizeof(char), + cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + return is_valid == 0; +} + +struct AdvancedIndexingTakeGPU { + // assume that idx have been flattened to a 1-D tensor (N,) + // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M) + // M is the number of columns of in_data and out_data + // K is the number of rows of in_data + // i is the index of out_data + template<typename DType, typename IType> + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, + const IType* idx, const int64_t M, const int64_t K) { + int64_t j = static_cast<int64_t>(idx[i]); + j = j % K; + j += (j < 0) ? 
K : 0; + + for (int64_t k = 0; k < M; k++){ + out_data[i * M + k] = in_data[j * M + k]; + } + } +}; + +struct AdvancedIndexingTakeMultiDimensionGPU { + // assume that idx have been flattened to a 1-D tensor (N,) + // assume that out_data and in_data have been flattened to 2-D tensors, (N, M) and (K, M) + // M is the number of columns of in_data and out_data + // K is the number of rows of in_data + // i is the index of out_data + template<typename DType, typename IType> + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, + const IType* idx, const int64_t M, const int64_t K) { + int64_t j = static_cast<int64_t>(idx[i]); + j = j % K; + j += (j < 0) ? K : 0; + + for (int64_t k = 0; k < M; k++){ + out_data[i * M + k] = in_data[(i * k + j) * M + k]; + } + } +}; + +template<> +inline void AdvancedIndexingOpForward<gpu>(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector<NDArray> &inputs, + const std::vector<OpReqType> &req, + const std::vector<NDArray> &outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + if (inputs[np_indexing_::kIdx].dtype() == mshadow::kBool) { + CHECK(req[0] == kWriteTo || req[0] == kWriteInplace); + const int axis = 0; + const NDArray &data = inputs[0]; + const NDArray &idx = inputs[1]; + const NDArray &out = outputs[0]; + CHECK_EQ(axis, 0) << "Not supported yet"; + CHECK_EQ(data.shape()[axis], idx.shape()[0]); + CHECK_EQ(idx.shape().ndim(), 1U); + Stream<gpu>* s = ctx.get_stream<gpu>(); + cudaStream_t stream = Stream<gpu>::GetStream(s); + // count the number of 1s in `idx`, so that we could know the output dimension + size_t idx_size = idx.shape()[0]; + int32_t valid_num = 0; + int32_t* prefix_sum = nullptr; + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + // Calculate total temporary memory size + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + idx_size, + stream); + size_t 
buffer_size = idx_size * sizeof(int32_t); + temp_storage_bytes += buffer_size; + // Allocate memory on GPU and allocate pointer + Tensor<gpu, 1, char> workspace = + ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), s); + prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_); + d_temp_storage = workspace.dptr_ + buffer_size; + MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { + mxnet_op::Kernel<mshadow_op::identity_with_cast, gpu>::Launch( + s, idx.shape()[0], prefix_sum, idx.data().dptr<IType>()); + }); + // Calculate prefix sum + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + idx_size, + stream); + CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), + cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + + // Set the output shape forcefully + mxnet::TShape data_shape = data.shape(); + data_shape[axis] = valid_num; + const_cast<NDArray &>(out).Init(data_shape); + size_t input_size = data.shape().Size(); + size_t col_size = input_size / idx.shape()[0]; + // Do the copy + MSHADOW_TYPE_SWITCH_WITH_BOOL(out.dtype(), DType, { + if (valid_num > 0) { + mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch( + s, input_size, out.data().dptr<DType>(), + data.data().dptr<DType>(), prefix_sum, col_size); + } + }); +} else if (inputs[np_indexing_::kIdx].dtype() == mshadow::kInt8 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt16 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt32 || + inputs[np_indexing_::kIdx].dtype() == mshadow::kInt64) { + using namespace mxnet_op; + const mxnet::TShape& idxshape = inputs[np_indexing_::kIdx].shape(); + const mxnet::TShape& arrshape = inputs[np_indexing_::kArr].shape(); + + if (idxshape.Size() == 0) { + return; + } + + mxnet::TShape oshape(idxshape.ndim() + arrshape.ndim() - 1, -1); + for (index_t i = 0; i < idxshape.ndim(); ++i) { + oshape[i] = idxshape[i]; + } + for (index_t i = 0; i < 
arrshape.ndim(); i++) { + if (i < 0) { + oshape[i] = arrshape[i]; + } else if (i > 0) { + oshape[i + idxshape.ndim() - 1] = arrshape[i]; + } + } + + const NDArray &out = outputs[0]; + const_cast<NDArray &>(out).Init(oshape); + + Stream<gpu> *s = ctx.get_stream<gpu>(); + + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[np_indexing_::kOut].dtype(), DType, { // output data type Review comment: I think we can use the `MSHADOW_TYPE_SWITCH` version without bool here, because the bool type won't be used. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
