This is an automated email from the ASF dual-hosted git repository. roywei pushed a commit to branch revert-16790-no_memcpy in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 7e561a3800827121f3d367c3a8252bb228ce5a4b Author: Lai Wei <[email protected]> AuthorDate: Wed Dec 4 10:07:09 2019 -0800 Revert "migrate cudaMemcpy to cudaMemcpyAsync+cudaStreamSynchronize (#16790)" This reverts commit 42d3182e5abd2ebbacb45027a08c793d30d46a50. --- src/kvstore/kvstore_utils.cu | 9 +- src/ndarray/ndarray_function.cu | 10 +-- src/operator/contrib/adamw-inl.h | 11 ++- src/operator/contrib/adamw.cc | 2 +- src/operator/contrib/adamw.cu | 8 +- src/operator/contrib/boolean_mask.cu | 16 ++-- src/operator/contrib/index_array.cu | 19 ++-- src/operator/contrib/multi_proposal.cu | 110 +++++++++++------------- src/operator/contrib/proposal.cu | 33 +++---- src/operator/numpy/np_boolean_mask_assign.cu | 11 +-- src/operator/numpy/np_nonzero_op.cu | 14 ++- src/operator/numpy/np_unique_op.cu | 17 ++-- src/operator/numpy/random/dist_common.cc | 4 +- src/operator/numpy/random/dist_common.cu | 14 +-- src/operator/numpy/random/dist_common.h | 4 +- src/operator/numpy/random/np_bernoulli_op.h | 2 +- src/operator/numpy/random/np_multinomial_op.cu | 8 +- src/operator/numpy/random/np_multinomial_op.h | 4 +- src/operator/numpy/random/np_normal_op.h | 4 +- src/operator/tensor/cast_storage-inl.cuh | 8 +- src/operator/tensor/dot-inl.cuh | 8 +- src/operator/tensor/elemwise_binary_op_basic.cu | 5 +- src/operator/tensor/indexing_op.cu | 15 ++-- src/operator/tensor/matrix_op.cu | 5 +- src/operator/tensor/square_sum.cu | 4 +- 25 files changed, 145 insertions(+), 200 deletions(-) diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu index 92b203c..2dab5bc 100644 --- a/src/kvstore/kvstore_utils.cu +++ b/src/kvstore/kvstore_utils.cu @@ -82,17 +82,16 @@ size_t UniqueImplGPU(NDArray *workspace, mshadow::Stream<gpu> *s, #else thrust::sort(thrust::cuda::par.on(stream), dptr, dptr + size, thrust::greater<IType>()); - CUDA_CALL(cudaMemcpyAsync(sort_output_ptr, dptr, sort_output_bytes, - cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL(cudaMemcpy(sort_output_ptr, dptr, sort_output_bytes, + cudaMemcpyDeviceToDevice)); #endif // execute unique kernel cub::DeviceSelect::Unique(temp_storage, unique_temp_bytes, sort_output_ptr, dptr, num_selected_ptr, size, stream); // retrieve num selected unique values size_t num_selected_out = 0; - CUDA_CALL(cudaMemcpyAsync(&num_selected_out, num_selected_ptr, num_selected_bytes, - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&num_selected_out, num_selected_ptr, num_selected_bytes, + cudaMemcpyDeviceToHost)); return num_selected_out; } diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index 79bc345..da7b60d 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -129,13 +129,12 @@ void ElementwiseSumRspImpl(mshadow::Stream<gpu>* s, IType* row_flg = NULL; void* d_temp_storage = NULL; size_t temp_storage_bytes = 0; - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, row_flg, row_flg, num_rows, - stream); + mshadow::Stream<gpu>::GetStream(s)); mshadow::Tensor<gpu, 1, char> workspace = rsc .get_space_typed<gpu, 1, char>(mshadow::Shape1(num_rows * sizeof(IType) + temp_storage_bytes), s); @@ -159,12 +158,11 @@ void ElementwiseSumRspImpl(mshadow::Stream<gpu>* s, row_flg, row_flg, num_rows, - stream); + mshadow::Stream<gpu>::GetStream(s)); // Get total number of output non-zero rows from GPU and allocate out data and row_idx dim_t nnr_out = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr_out, &row_flg[num_rows-1], sizeof(dim_t), - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg[num_rows-1], sizeof(dim_t), + cudaMemcpyDeviceToHost)); out->CheckAndAlloc({mshadow::Shape1(nnr_out)}); IType* out_row_idx = out->aux_data(kIdx).dptr<IType>(); DType* out_data = out->data().dptr<DType>(); diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h index 6f48333..fd139de 100644 --- a/src/operator/contrib/adamw-inl.h +++ b/src/operator/contrib/adamw-inl.h @@ -442,15 +442,14 @@ static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs, } template<typename xpu> -void GetScaleFloat(mshadow::Stream<xpu> *s, const TBlob &scale_blob, float *pScalef); +void GetScaleFloat(const TBlob &scale_blob, float *pScalef); template<typename xpu> -bool PrepareInputBlobs(const OpContext &ctx, - const std::vector<TBlob> &inputs, +bool PrepareInputBlobs(const std::vector<TBlob> &inputs, std::vector<TBlob> *inputs_wo_scale, float *pScalef) { const size_t num_in = inputs.size() - 1; - GetScaleFloat<xpu>(ctx.get_stream<xpu>(), inputs[num_in], pScalef); + GetScaleFloat<xpu>(inputs[num_in], pScalef); if (!std::isfinite(*pScalef) || *pScalef == 0) return false; @@ -469,7 +468,7 @@ inline void MPUpdate(const nnvm::NodeAttrs& attrs, const std::vector<TBlob> &outputs) { std::vector<TBlob> inputs_wo_scale; float scalef; - if (!PrepareInputBlobs<xpu>(ctx, inputs, &inputs_wo_scale, &scalef)) + if (!PrepareInputBlobs<xpu>(inputs, &inputs_wo_scale, &scalef)) return; F::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef); @@ -483,7 +482,7 @@ inline void multiMPUpdate(const nnvm::NodeAttrs& attrs, const std::vector<TBlob> &outputs) { std::vector<TBlob> inputs_wo_scale; float scalef; - if (!PrepareInputBlobs<xpu>(ctx, inputs, &inputs_wo_scale, &scalef)) + if (!PrepareInputBlobs<xpu>(inputs, &inputs_wo_scale, &scalef)) return; if (!MP) diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc index effae5c..2c730f0 100644 --- a/src/operator/contrib/adamw.cc +++ b/src/operator/contrib/adamw.cc @@ -119,7 +119,7 @@ the update is skipped. .add_arguments(AdamWParam::__FIELDS__()); template<> -void GetScaleFloat<cpu>(mshadow::Stream<cpu> *s, const TBlob &scale_blob, float *pScalef) { +void GetScaleFloat<cpu>(const TBlob &scale_blob, float *pScalef) { MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, *pScalef = static_cast<float>(*scale_blob.dptr<DType>()); ) diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu index 2b0040e..81b13c9 100644 --- a/src/operator/contrib/adamw.cu +++ b/src/operator/contrib/adamw.cu @@ -29,13 +29,11 @@ namespace mxnet { namespace op { template<> -void GetScaleFloat<gpu>(mshadow::Stream<gpu> *s, const TBlob &scale_blob, float *pScalef) { +void GetScaleFloat<gpu>(const TBlob &scale_blob, float *pScalef) { MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, { DType scale = 0; - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(&scale, scale_blob.dptr<DType>(), sizeof(DType), - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&scale, scale_blob.dptr<DType>(), sizeof(DType), + cudaMemcpyDeviceToHost)); *pScalef = static_cast<float>(scale); }) } diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu index 95f5614..a5ef4a7 100644 --- a/src/operator/contrib/boolean_mask.cu +++ b/src/operator/contrib/boolean_mask.cu @@ -46,7 +46,6 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs, CHECK_EQ(data.shape()[axis], idx.shape()[0]); CHECK_EQ(idx.shape().ndim(), 1U); Stream<gpu>* s = ctx.get_stream<gpu>(); - cudaStream_t stream = Stream<gpu>::GetStream(s); // count the number of 1s in `idx`, so that we could know the output dimension size_t idx_size = idx.shape()[0]; int32_t valid_num = 0; @@ -59,7 +58,7 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, idx_size, - stream); + Stream<gpu>::GetStream(s)); size_t buffer_size = idx_size * sizeof(int32_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer @@ -77,11 +76,9 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, idx_size, - stream); - CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); - + Stream<gpu>::GetStream(s)); + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), + cudaMemcpyDeviceToHost)); // Set the output shape forcefully mxnet::TShape data_shape = data.shape(); data_shape[axis] = valid_num; @@ -113,7 +110,6 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs, const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; Stream<gpu>* s = ctx.get_stream<gpu>(); - cudaStream_t stream = Stream<gpu>::GetStream(s); // Count the number of 1s in `idx`, so that we could know the output dimension size_t idx_size = idx.shape()[0]; int32_t* prefix_sum = nullptr; @@ -125,7 +121,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, idx_size, - stream); + Stream<gpu>::GetStream(s)); size_t buffer_size = idx_size * sizeof(int32_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer @@ -143,7 +139,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, idx_size, - stream); + Stream<gpu>::GetStream(s)); size_t input_size = igrad_data.shape().Size(); size_t col_size = input_size / idx_size; // Backward pass diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index dae61ca..ddba6a8 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -41,8 +41,7 @@ void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs, const TShape inshape = in_data.shape_; const int ndim = inshape.ndim(); - Stream<gpu> *s = ctx.get_stream<gpu>(); - cudaStream_t stream = Stream<gpu>::GetStream(s); + Stream<gpu> *stream = ctx.get_stream<gpu>(); using namespace mxnet_op; @@ -56,24 +55,24 @@ void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs, IndexArrayBuildSelectedAxesWorkspace(axes, index_products, cpu_workspace.data(), ndim); Tensor<gpu, 1, int64_t> workspace = - ctx.requested[0].get_space_typed<gpu, 1, int64_t>(Shape1(2 * naxes), s); + ctx.requested[0].get_space_typed<gpu, 1, int64_t>(Shape1(2 * naxes), stream); - CUDA_CALL(cudaMemcpyAsync(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes), - cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaMemcpy(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes), + cudaMemcpyHostToDevice)); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel<IndexArrayKernel<req_type>, gpu>::Launch(s, in_data.Size(), + Kernel<IndexArrayKernel<req_type>, gpu>::Launch(stream, in_data.Size(), out_data.dptr<int64_t>(), naxes, workspace.dptr_); }); } else { Tensor<gpu, 1, dim_t> workspace = - ctx.requested[0].get_space_typed<gpu, 1, dim_t>(Shape1(ndim), s); + ctx.requested[0].get_space_typed<gpu, 1, dim_t>(Shape1(ndim), stream); - CUDA_CALL(cudaMemcpyAsync(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim, - cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim, + cudaMemcpyHostToDevice)); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel<IndexArrayDefaultKernel<req_type>, gpu>::Launch(s, in_data.Size(), + Kernel<IndexArrayDefaultKernel<req_type>, gpu>::Launch(stream, in_data.Size(), out_data.dptr<int64_t>(), ndim, workspace.dptr_); }); } diff --git a/src/operator/contrib/multi_proposal.cu b/src/operator/contrib/multi_proposal.cu index 1aa852a..4552ae4 100644 --- a/src/operator/contrib/multi_proposal.cu +++ b/src/operator/contrib/multi_proposal.cu @@ -324,8 +324,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, } } -void _nms(mshadow::Stream<gpu> *s, - const mshadow::Tensor<gpu, 2>& boxes, +void _nms(const mshadow::Tensor<gpu, 2>& boxes, const float nms_overlap_thresh, const int rpn_post_nms_top_n, int *keep, @@ -350,13 +349,10 @@ void _nms(mshadow::Stream<gpu> *s, mask_dev); FRCNN_CUDA_CHECK(cudaPeekAtLastError()); std::vector<uint64_t> mask_host(boxes_num * col_blocks); - - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - FRCNN_CUDA_CHECK(cudaMemcpyAsync(&mask_host[0], - mask_dev, - sizeof(uint64_t) * boxes_num * col_blocks, - cudaMemcpyDeviceToHost, stream)); - FRCNN_CUDA_CHECK(cudaStreamSynchronize(stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(uint64_t) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); std::vector<uint64_t> remv(col_blocks); memset(&remv[0], 0, sizeof(uint64_t) * col_blocks); @@ -480,12 +476,8 @@ class MultiProposalGPUOp : public Operator{ sizeof(float) * num_images * count_anchors * 5)); Tensor<xpu, 3> workspace_proposals(workspace_proposals_ptr, Shape3(num_images, count_anchors, 5)); - - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - - FRCNN_CUDA_CHECK(cudaMemcpyAsync(workspace_proposals.dptr_, &anchors[0], - sizeof(float) * anchors.size(), - cudaMemcpyHostToDevice, stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(workspace_proposals.dptr_, &anchors[0], + sizeof(float) * anchors.size(), cudaMemcpyHostToDevice)); // Copy proposals to a mesh grid dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock); @@ -537,50 +529,50 @@ class MultiProposalGPUOp : public Operator{ FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * rpn_pre_nms_top_n)); for (int b = 0; b < num_images; b++) { - CheckLaunchParam(dimGrid, dimBlock, "CopyScore"); - CopyScoreKernel << <dimGrid, dimBlock >> >( - count_anchors, workspace_proposals.dptr_ + b * count_anchors * 5, - score.dptr_, order.dptr_); - FRCNN_CUDA_CHECK(cudaPeekAtLastError()); - - // argsort score, save order - thrust::stable_sort_by_key(thrust::device, - score.dptr_, - score.dptr_ + score.size(0), - order.dptr_, - thrust::greater<real_t>()); - FRCNN_CUDA_CHECK(cudaPeekAtLastError()); - - // Reorder proposals according to order - - dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; - CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals"); - ReorderProposalsKernel << <dimGrid, dimBlock >> >( - rpn_pre_nms_top_n, workspace_proposals.dptr_ + b * count_anchors * 5, - order.dptr_, workspace_ordered_proposals.dptr_); - FRCNN_CUDA_CHECK(cudaPeekAtLastError()); - - // perform nms - std::vector<int> _keep(workspace_ordered_proposals.size(0)); - int out_size = 0; - _nms(s, workspace_ordered_proposals, - param_.threshold, - rpn_post_nms_top_n, - &_keep[0], - &out_size); - - // copy nms result to gpu - FRCNN_CUDA_CHECK(cudaMemcpyAsync(keep, &_keep[0], sizeof(int) * _keep.size(), - cudaMemcpyHostToDevice, stream)); - - // copy results after nms - dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; - CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput"); - PrepareOutput << <dimGrid, dimBlock >> >( - param_.rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size, b, - out.dptr_ + b * param_.rpn_post_nms_top_n * 5, - out_score.dptr_ + b * param_.rpn_post_nms_top_n); - FRCNN_CUDA_CHECK(cudaPeekAtLastError()); + CheckLaunchParam(dimGrid, dimBlock, "CopyScore"); + CopyScoreKernel << <dimGrid, dimBlock >> >( + count_anchors, workspace_proposals.dptr_ + b * count_anchors * 5, + score.dptr_, order.dptr_); + FRCNN_CUDA_CHECK(cudaPeekAtLastError()); + + // argsort score, save order + thrust::stable_sort_by_key(thrust::device, + score.dptr_, + score.dptr_ + score.size(0), + order.dptr_, + thrust::greater<real_t>()); + FRCNN_CUDA_CHECK(cudaPeekAtLastError()); + + // Reorder proposals according to order + + dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals"); + ReorderProposalsKernel << <dimGrid, dimBlock >> >( + rpn_pre_nms_top_n, workspace_proposals.dptr_ + b * count_anchors * 5, + order.dptr_, workspace_ordered_proposals.dptr_); + FRCNN_CUDA_CHECK(cudaPeekAtLastError()); + + // perform nms + std::vector<int> _keep(workspace_ordered_proposals.size(0)); + int out_size = 0; + _nms(workspace_ordered_proposals, + param_.threshold, + rpn_post_nms_top_n, + &_keep[0], + &out_size); + + // copy nms result to gpu + FRCNN_CUDA_CHECK(cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(), + cudaMemcpyHostToDevice)); + + // copy results after nms + dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput"); + PrepareOutput << <dimGrid, dimBlock >> >( + param_.rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size, b, + out.dptr_ + b * param_.rpn_post_nms_top_n * 5, + out_score.dptr_ + b * param_.rpn_post_nms_top_n); + FRCNN_CUDA_CHECK(cudaPeekAtLastError()); } // free temporary memory FRCNN_CUDA_CHECK(cudaFree(keep)); diff --git a/src/operator/contrib/proposal.cu b/src/operator/contrib/proposal.cu index b107dfa..446c92b 100644 --- a/src/operator/contrib/proposal.cu +++ b/src/operator/contrib/proposal.cu @@ -305,8 +305,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, } } -void _nms(mshadow::Stream<gpu> *s, - const mshadow::Tensor<gpu, 2>& boxes, +void _nms(const mshadow::Tensor<gpu, 2>& boxes, const float nms_overlap_thresh, const int rpn_post_nms_top_n, int *keep, @@ -331,12 +330,10 @@ void _nms(mshadow::Stream<gpu> *s, mask_dev); FRCNN_CUDA_CHECK(cudaPeekAtLastError()); std::vector<uint64_t> mask_host(boxes_num * col_blocks); - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - FRCNN_CUDA_CHECK(cudaMemcpyAsync(&mask_host[0], - mask_dev, - sizeof(uint64_t) * boxes_num * col_blocks, - cudaMemcpyDeviceToHost, stream)); - FRCNN_CUDA_CHECK(cudaStreamSynchronize(stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(&mask_host[0], + mask_dev, + sizeof(uint64_t) * boxes_num * col_blocks, + cudaMemcpyDeviceToHost)); std::vector<uint64_t> remv(col_blocks); memset(&remv[0], 0, sizeof(uint64_t) * col_blocks); @@ -459,10 +456,9 @@ class ProposalGPUOp : public Operator{ float* workspace_proposals_ptr = NULL; FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr, sizeof(float) * count * 5)); Tensor<xpu, 2> workspace_proposals(workspace_proposals_ptr, Shape2(count, 5)); - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - FRCNN_CUDA_CHECK(cudaMemcpyAsync(workspace_proposals.dptr_, - &anchors[0], sizeof(float) * anchors.size(), - cudaMemcpyHostToDevice, stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(workspace_proposals.dptr_, + &anchors[0], sizeof(float) * anchors.size(), + cudaMemcpyHostToDevice)); // Copy proposals to a mesh grid dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock); @@ -475,10 +471,9 @@ class ProposalGPUOp : public Operator{ // im_info is small, we want to copy them to cpu std::vector<float> cpu_im_info(3); - FRCNN_CUDA_CHECK(cudaMemcpyAsync(&cpu_im_info[0], im_info.dptr_, - sizeof(float) * cpu_im_info.size(), - cudaMemcpyDeviceToHost, stream)); - FRCNN_CUDA_CHECK(cudaStreamSynchronize(stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(&cpu_im_info[0], im_info.dptr_, + sizeof(float) * cpu_im_info.size(), + cudaMemcpyDeviceToHost)); // prevent padded predictions int real_height = static_cast<int>(cpu_im_info[0] / param_.feature_stride); @@ -548,7 +543,7 @@ class ProposalGPUOp : public Operator{ // perform nms std::vector<int> _keep(workspace_ordered_proposals.size(0)); int out_size = 0; - _nms(s, workspace_ordered_proposals, + _nms(workspace_ordered_proposals, param_.threshold, rpn_post_nms_top_n, &_keep[0], @@ -557,8 +552,8 @@ class ProposalGPUOp : public Operator{ // copy nms result to gpu int* keep; FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * _keep.size())); - FRCNN_CUDA_CHECK(cudaMemcpyAsync(keep, &_keep[0], sizeof(int) * _keep.size(), - cudaMemcpyHostToDevice, stream)); + FRCNN_CUDA_CHECK(cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(), + cudaMemcpyHostToDevice)); // copy results after nms dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; diff --git a/src/operator/numpy/np_boolean_mask_assign.cu b/src/operator/numpy/np_boolean_mask_assign.cu index 2ccc4ff..e3b0330 100644 --- a/src/operator/numpy/np_boolean_mask_assign.cu +++ b/src/operator/numpy/np_boolean_mask_assign.cu @@ -113,7 +113,6 @@ size_t* GetValidNumGPU(const OpContext &ctx, const DType *idx, const size_t idx_ void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; Stream<gpu>* s = ctx.get_stream<gpu>(); - cudaStream_t stream = Stream<gpu>::GetStream(s); // Calculate total temporary memory size cub::DeviceScan::ExclusiveSum(d_temp_storage, @@ -121,7 +120,7 @@ size_t* GetValidNumGPU(const OpContext &ctx, const DType *idx, const size_t idx_ prefix_sum, prefix_sum, idx_size + 1, - stream); + Stream<gpu>::GetStream(s)); size_t buffer_size = (idx_size + 1) * sizeof(size_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer @@ -145,7 +144,7 @@ size_t* GetValidNumGPU(const OpContext &ctx, const DType *idx, const size_t idx_ prefix_sum, prefix_sum, idx_size + 1, - stream); + Stream<gpu>::GetStream(s)); return prefix_sum; } @@ -175,10 +174,8 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, { prefix_sum = GetValidNumGPU<MType>(ctx, mask.dptr<MType>(), mask_size); }); - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[mask_size], sizeof(size_t), - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[mask_size], sizeof(size_t), + cudaMemcpyDeviceToHost)); } // If there's no True in mask, return directly if (valid_num == 0) return; diff --git a/src/operator/numpy/np_nonzero_op.cu b/src/operator/numpy/np_nonzero_op.cu index a31222e..c732d2c 100644 --- a/src/operator/numpy/np_nonzero_op.cu +++ b/src/operator/numpy/np_nonzero_op.cu @@ -63,7 +63,6 @@ void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs, } int32_t valid_num = 0; Stream<gpu>* stream = ctx.get_stream<gpu>(); - cudaStream_t cuda_stream = Stream<gpu>::GetStream(stream); int32_t* prefix_sum = nullptr; void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; @@ -73,7 +72,7 @@ void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, in_size, - cuda_stream); + Stream<gpu>::GetStream(stream)); size_t buffer_size = in_size * sizeof(int32_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer @@ -91,18 +90,17 @@ void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs, prefix_sum, prefix_sum, in_size, - cuda_stream); - CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t), - cudaMemcpyDeviceToHost, cuda_stream)); - CUDA_CALL(cudaStreamSynchronize(cuda_stream)); + Stream<gpu>::GetStream(stream)); + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t), + cudaMemcpyDeviceToHost)); // 0-dim if (0 == in.shape().ndim()) { mxnet::TShape s(2, 1); if (valid_num) { const_cast<NDArray &>(out).Init(s); int64_t temp = 0; - CUDA_CALL(cudaMemcpyAsync(out.data().dptr<int64_t>(), &temp, sizeof(int64_t), - cudaMemcpyHostToDevice, cuda_stream)); + CUDA_CALL(cudaMemcpy(out.data().dptr<int64_t>(), &temp, sizeof(int64_t), + cudaMemcpyHostToDevice)); } else { s[0] = 0; const_cast<NDArray &>(out).Init(s); diff --git a/src/operator/numpy/np_unique_op.cu b/src/operator/numpy/np_unique_op.cu index 22fd1d1..4d90a45 100644 --- a/src/operator/numpy/np_unique_op.cu +++ b/src/operator/numpy/np_unique_op.cu @@ -97,7 +97,6 @@ void NumpyUniqueGPUNoneAxisImpl(const NumpyUniqueParam& param, const std::vector<NDArray> &outputs) { MXNET_NO_FLOAT16_TYPE_SWITCH(outputs[0].dtype(), DType, { mshadow::Stream<gpu> *stream = ctx.get_stream<gpu>(); - cudaStream_t cuda_stream = mshadow::Stream<gpu>::GetStream(stream); auto policy = thrust::cuda::par.on(stream->stream_); DType* input_data = inputs[0].data().dptr<DType>(); @@ -121,9 +120,8 @@ void NumpyUniqueGPUNoneAxisImpl(const NumpyUniqueParam& param, thrust::device_vector<int32_t> prefix_sum(input_size, 0); thrust::inclusive_scan(policy, mask.begin(), mask.end(), prefix_sum.begin()); int32_t valid_num = 0; - CUDA_CALL(cudaMemcpyAsync(&valid_num, thrust::raw_pointer_cast(&prefix_sum[input_size - 1]), - sizeof(int32_t), cudaMemcpyDeviceToHost, cuda_stream)); - CUDA_CALL(cudaStreamSynchronize(cuda_stream)); + CUDA_CALL(cudaMemcpy(&valid_num, thrust::raw_pointer_cast(&prefix_sum[input_size - 1]), + sizeof(int32_t), cudaMemcpyDeviceToHost)); // set the output shape forcefully mxnet::TShape s(1, valid_num); const_cast<NDArray &>(outputs[0]).Init(s); @@ -182,7 +180,6 @@ void NumpyUniqueGPUImpl(const NumpyUniqueParam& param, using namespace mshadow; using namespace mshadow::expr; Stream<gpu> *stream = ctx.get_stream<gpu>(); - cudaStream_t cuda_stream = Stream<gpu>::GetStream(stream); auto policy = thrust::cuda::par.on(stream->stream_); const index_t actual_axis = param.axis.value() + ((param.axis.value() < 0) ? inputs[0].shape().ndim() : 0); @@ -217,9 +214,8 @@ void NumpyUniqueGPUImpl(const NumpyUniqueParam& param, thrust::device_vector<int32_t> prefix_sum(temp_shape[0], 0); thrust::inclusive_scan(policy, mask.begin(), mask.end(), prefix_sum.begin()); int32_t valid_num = 0; - CUDA_CALL(cudaMemcpyAsync(&valid_num, thrust::raw_pointer_cast(&prefix_sum[temp_shape[0] - 1]), - sizeof(int32_t), cudaMemcpyDeviceToHost, cuda_stream)); - CUDA_CALL(cudaStreamSynchronize(cuda_stream)); + CUDA_CALL(cudaMemcpy(&valid_num, thrust::raw_pointer_cast(&prefix_sum[temp_shape[0] - 1]), + sizeof(int32_t), cudaMemcpyDeviceToHost)); // store the temp output data, reuse the space of 'input_tensor' Tensor<gpu, 3, DType> temp_tensor(workspace.dptr_, Shape3(valid_num, temp_shape[1], temp_shape[2]), stream); @@ -286,12 +282,11 @@ void NumpyUniqueGPUForward(const nnvm::NodeAttrs& attrs, CHECK(!param.axis.has_value() || param.axis.value() == -1 || param.axis.value() == 0) << "Axis can only be -1 or 0 for scalor tensor"; Stream<gpu> *s = ctx.get_stream<gpu>(); - cudaStream_t stream = Stream<gpu>::GetStream(s); mxnet::TShape shape_1(1, 1); const_cast<NDArray &>(outputs[0]).Init(shape_1); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { - CUDA_CALL(cudaMemcpyAsync(outputs[0].data().dptr<DType>(), inputs[0].data().dptr<DType>(), - sizeof(DType), cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL(cudaMemcpy(outputs[0].data().dptr<DType>(), inputs[0].data().dptr<DType>(), + sizeof(DType), cudaMemcpyDeviceToDevice)); }); int output_flag = 0; if (param.return_index) { diff --git a/src/operator/numpy/random/dist_common.cc b/src/operator/numpy/random/dist_common.cc index 18a2085..9255656 100644 --- a/src/operator/numpy/random/dist_common.cc +++ b/src/operator/numpy/random/dist_common.cc @@ -30,12 +30,12 @@ namespace mxnet { namespace op { template <> -void _copy<cpu>(mshadow::Stream<cpu> *s, float *dst, float *src) { +void _copy<cpu>(float *dst, float *src) { *dst = *src; } template <> -void _copy<cpu>(mshadow::Stream<cpu> *s, double *dst, double *src) { +void _copy<cpu>(double *dst, double *src) { *dst = *src; } diff --git a/src/operator/numpy/random/dist_common.cu b/src/operator/numpy/random/dist_common.cu index dbd313b..7dde012 100644 --- a/src/operator/numpy/random/dist_common.cu +++ b/src/operator/numpy/random/dist_common.cu @@ -30,19 +30,13 @@ namespace mxnet { namespace op { template <> -void _copy<gpu>(mshadow::Stream<gpu> *s, float *dst, float *src) { - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(dst, src, sizeof(float), cudaMemcpyDeviceToHost, - stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); +void _copy<gpu>(float *dst, float *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(float), cudaMemcpyDeviceToHost)); } template <> -void _copy<gpu>(mshadow::Stream<gpu> *s, double *dst, double *src) { - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(dst, src, sizeof(double), cudaMemcpyDeviceToHost, - stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); +void _copy<gpu>(double *dst, double *src) { +CUDA_CALL(cudaMemcpy(dst, src, sizeof(double), cudaMemcpyDeviceToHost)); } } // namespace op diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h index e835829..aafd10e 100644 --- a/src/operator/numpy/random/dist_common.h +++ b/src/operator/numpy/random/dist_common.h @@ -41,10 +41,10 @@ namespace mxnet { namespace op { template <typename xpu> -void _copy(mshadow::Stream<xpu> *s, float *dst, float*src); +void _copy(float *dst, float*src); template <typename xpu> -void _copy(mshadow::Stream<xpu> *s, double *dst, double*src); +void _copy(double *dst, double*src); inline int FillShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape, diff --git a/src/operator/numpy/random/np_bernoulli_op.h b/src/operator/numpy/random/np_bernoulli_op.h index 0df1089..aa8e344 100644 --- a/src/operator/numpy/random/np_bernoulli_op.h +++ b/src/operator/numpy/random/np_bernoulli_op.h @@ -173,7 +173,7 @@ void NumpyBernoulliForward(const nnvm::NodeAttrs &attrs, Kernel<check_legal_prob_kernel<IType>, xpu>::Launch( s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr); }); - _copy<xpu>(s, &indicator_host, indicator_device_ptr); + _copy<xpu>(&indicator_host, indicator_device_ptr); CHECK_GE(indicator_host, 0.0) << "ValueError: expect probs >= 0 && probs <= 1"; } diff --git a/src/operator/numpy/random/np_multinomial_op.cu b/src/operator/numpy/random/np_multinomial_op.cu index 132d67b..6aa1639 100644 --- a/src/operator/numpy/random/np_multinomial_op.cu +++ b/src/operator/numpy/random/np_multinomial_op.cu @@ -28,12 +28,10 @@ namespace mxnet { namespace op { template<typename DType> -void CheckPvalGPU(const OpContext& ctx, DType* input, int prob_length) { +void CheckPvalGPU(DType* input, int prob_length) { std::vector<DType> pvals_(prob_length); - cudaStream_t stream = mshadow::Stream<gpu>::GetStream(ctx.get_stream<gpu>()); - CUDA_CALL(cudaMemcpyAsync(&pvals_[0], input, sizeof(DType) * prob_length, - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&pvals_[0], input, sizeof(DType) * prob_length, + cudaMemcpyDeviceToHost)); DType sum = DType(0.0); for (int i = 0; i < prob_length; ++i) { sum += pvals_[i]; diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h index 9c5c73f..2350d20 100644 --- a/src/operator/numpy/random/np_multinomial_op.h +++ b/src/operator/numpy/random/np_multinomial_op.h @@ -100,7 +100,7 @@ inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs, } template<typename DType> -void CheckPvalGPU(const OpContext& ctx, DType* input, int prob_length); +void CheckPvalGPU(DType* input, int prob_length); template<typename DType> void CheckPval(DType* input, int prob_length) { @@ -188,7 +188,7 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs, if (std::is_same<xpu, cpu>::value) { CheckPval<DType>(inputs[0].dptr<DType>(), prob_length); } else { - CheckPvalGPU<DType>(ctx, inputs[0].dptr<DType>(), prob_length); + CheckPvalGPU<DType>(inputs[0].dptr<DType>(), prob_length); } Kernel<multinomial_kernel, xpu>::Launch( s, num_output, num_exp, prob_length, diff --git a/src/operator/numpy/random/np_normal_op.h b/src/operator/numpy/random/np_normal_op.h index 8cc4288..c74151f 100644 --- a/src/operator/numpy/random/np_normal_op.h +++ b/src/operator/numpy/random/np_normal_op.h @@ -181,7 +181,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, Kernel<check_legal_scale_kernel<IType>, xpu>::Launch( s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr); }); - _copy<xpu>(s, &indicator_host, indicator_device_ptr); + _copy<xpu>(&indicator_host, indicator_device_ptr); CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; } else { scalar_pos = 1; @@ -206,7 +206,7 @@ void NumpyNormalForward(const nnvm::NodeAttrs &attrs, Kernel<check_legal_scale_kernel<IType>, xpu>::Launch( s, inputs[1].Size(), inputs[1].dptr<IType>(), indicator_device_ptr); }); - _copy<xpu>(s, &indicator_host, indicator_device_ptr); + _copy<xpu>(&indicator_host, indicator_device_ptr); CHECK_GE(indicator_host, 0.0) << "ValueError: scale < 0"; int ndim = FillShape(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_hshape, &new_oshape); diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh index 4c5d0d8..ee1531d 100644 --- a/src/operator/tensor/cast_storage-inl.cuh +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -162,9 +162,7 @@ void CastStorageDnsRspGPUImpl_(const OpContext& ctx, // Get total number of non-zero rows from device dim_t nnr = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr, &row_flg[num_rows - 1], sizeof(dim_t), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnr, &row_flg[num_rows - 1], sizeof(dim_t), cudaMemcpyDeviceToHost)); // Allocate rsp tensor row index array and fill rsp->CheckAndAllocAuxData(rowsparse::kIdx, Shape1(nnr)); @@ -557,9 +555,7 @@ inline void CastStorageDnsCsrImpl(const OpContext& ctx, // Receive total number of nnz values from device IType nnz = 0; - CUDA_CALL(cudaMemcpyAsync(&nnz, &(indptr[num_rows]), sizeof(IType), cudaMemcpyDeviceToHost, - mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnz, &(indptr[num_rows]), sizeof(IType), cudaMemcpyDeviceToHost)); // Allocate column index array and data array of the csr matrix csr->CheckAndAllocAuxData(csr::kIdx, Shape1(static_cast<dim_t>(nnz))); diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index b8244d3..d6fed4a 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -702,8 +702,7 @@ inline void DotCsrDnsRspImpl(const OpContext& ctx, nnr_ptr, nnz, stream); // retrieve num non-zero rows size_t nnr = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr, nnr_ptr, nnr_bytes, cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); + CUDA_CALL(cudaMemcpy(&nnr, nnr_ptr, nnr_bytes, cudaMemcpyDeviceToHost)); // allocate data ret->CheckAndAllocData(mshadow::Shape2(nnz, num_cols_r)); // generate lookup table @@ -818,9 +817,8 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, num_cols_l, mshadow::Stream<gpu>::GetStream(s)); dim_t nnr_out = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnr_out, &row_flg_out[num_cols_l-1], sizeof(dim_t), + cudaMemcpyDeviceToHost)); if (0 == nnr_out) { FillZerosRspImpl(s, *ret); return; diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index e39f7e9..f88b8eb 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -115,9 +115,8 @@ void ElemwiseBinaryOp::RspRspOp(mshadow::Stream<gpu> *s, num_rows, mshadow::Stream<gpu>::GetStream(s)); nnvm::dim_t nnr_out = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr_out, &common_row_table[num_rows-1], sizeof(nnvm::dim_t), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))) + CUDA_CALL(cudaMemcpy(&nnr_out, &common_row_table[num_rows-1], sizeof(nnvm::dim_t), + cudaMemcpyDeviceToHost)); output.CheckAndAlloc({mshadow::Shape1(nnr_out)}); Kernel<FillRspRowIdxKernel, gpu>::Launch( s, num_rows, output.aux_data(kIdx).dptr<IType>(), common_row_table, num_rows); diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 8250efb..3ccf1f3 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -154,9 +154,8 @@ bool CheckIndexOutOfBound(mshadow::Stream<gpu> *s, const DType* data_ptr, size_t int32_t is_valid = 0; Kernel<set_zero, gpu>::Launch(s, 1, is_valid_ptr); Kernel<is_valid_check, gpu>::Launch(s, data_size, is_valid_ptr, data_ptr, min, max); - CUDA_CALL(cudaMemcpyAsync(&is_valid, is_valid_ptr, sizeof(char), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&is_valid, is_valid_ptr, sizeof(char), + cudaMemcpyDeviceToHost)); return is_valid == 0; } @@ -308,9 +307,8 @@ void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx, grad_row_idx, grad_row_idx + data_size, data_size, Stream<gpu>::GetStream(s)); dim_t nnr = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr, grad_row_idx + data_size, sizeof(RType), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnr, grad_row_idx + data_size, sizeof(RType), + cudaMemcpyDeviceToHost)); CHECK_EQ(output.shape().ndim(), 2) << "Unexcepted ndim"; output.CheckAndAllocData(Shape2(nnr, output.shape()[1])); output.set_aux_shape(kIdx, Shape1(nnr)); @@ -412,9 +410,8 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const bool deterministic, num_rows, mshadow::Stream<gpu>::GetStream(s)); dim_t nnr = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t), + cudaMemcpyDeviceToHost)); if (nnr == 0) { FillZerosRspImpl(s, output); return; diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 239e42c..b382c55 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -114,9 +114,8 @@ void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape &begin, const mxnet::TShape &en Stream<gpu>::GetStream(s)); // retrieve nnr RType nnr = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr, &out_indptr[indptr_len-1], sizeof(RType), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&nnr, &out_indptr[indptr_len-1], sizeof(RType), + cudaMemcpyDeviceToHost)); // returns zeros in csr format if nnr = 0 if (nnr == 0) { diff --git a/src/operator/tensor/square_sum.cu b/src/operator/tensor/square_sum.cu index 83287e0..0b40786 100644 --- a/src/operator/tensor/square_sum.cu +++ b/src/operator/tensor/square_sum.cu @@ -42,9 +42,7 @@ void CheckSameIdx<gpu>(const OpContext& ctx, mxnet_op::Kernel<mxnet_op::set_zero, gpu>::Launch(s, 1, is_diff_ptr); mxnet_op::Kernel<CheckSameIdxKernel, gpu>::Launch(s, idx_size, ograd_idx, in_idx, is_diff_ptr); - CUDA_CALL(cudaMemcpyAsync(&is_diff, is_diff_ptr, sizeof(int32_t), - cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s))); - CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s))); + CUDA_CALL(cudaMemcpy(&is_diff, is_diff_ptr, sizeof(int32_t), cudaMemcpyDeviceToHost)); CHECK_EQ(is_diff, 0) << "SquareSumRspGradImpl only supports" " equal ograd_row_idx and input_row_idx" " when ograd and input are both"
