apeforest commented on a change in pull request #11825: Fix non-determinism of
dot(csr.T, dns) = dns with tests
URL: https://github.com/apache/incubator-mxnet/pull/11825#discussion_r203883317
##########
File path: src/operator/tensor/dot-inl.cuh
##########
@@ -539,86 +434,120 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
Kernel<set_zero, gpu>::Launch(s, num_threads,
data_out.dptr<DType>());
}
if (trans_lhs) {
- // Different kernel versions are optimized for different matrix
instances
- // TODO: switch between kernel versions depending on input
- // (1) 'Scalar kernel' (one thread computing one output
element )
- // (2) 'Warp kernel' (one warp computing one lhs
column for one rhs column )
- // (3) 'Thread block kernel' (one thread block computing one lhs
column for all rhs columns)
- // (4) 'Warp block kernel' (one warp computing one lhs
column for all rhs columns)
- const int kernel_version = 0;
- switch (kernel_version) {
- case 1:
- num_threads = data_out.Size();
- MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
- Kernel<DotCsrTransDnsDnsScalarKernel<ReqType>, gpu>::Launch(s,
num_threads,
- data_out.dptr<DType>(), data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
- col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_rows_l,
num_cols_r);
- });
- break;
- case 2:
- num_threads = threads_per_warp * num_rows_l * num_cols_r;
- Kernel<DotCsrTransDnsDnsWarpKernel, gpu>::Launch(s, num_threads,
- data_out.dptr<DType>(), data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
- col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
- break;
- case 3:
- num_threads = threads_per_block * num_rows_l;
- Kernel<DotCsrTransDnsDnsThreadBlockKernel, gpu>::Launch(s,
num_threads,
- data_out.dptr<DType>(), data_l.dptr<DType>(),
indptr_l.dptr<IType>(),
- col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
- break;
- case 4:
- num_threads = threads_per_warp * num_rows_l;
- Kernel<DotCsrTransDnsDnsWarpBlockKernel, gpu>::Launch(s,
num_threads,
+ // TODO(haojin2): Switching to deterministic algorithm for now.
+ // Further optimizations to come later.
+ const nnvm::dim_t num_csr_rows = lhs.shape()[0];
+ const nnvm::dim_t num_csr_cols = lhs.shape()[1];
+ const nnvm::dim_t num_dns_rows = rhs.shape_[0];
+ const nnvm::dim_t nnz = lhs.storage_shape().Size();
+
+ IType* original_idx_ptr = nullptr;
+ IType* csc_indices_ptr = nullptr;
+ IType* csc_cols_ptr = nullptr;
+ CType* csr_rows_ptr = nullptr;
+ CType* csc_indptr_ptr = nullptr;
+ DType* csc_data_ptr = nullptr;
+ char* temp_storage_ptr = nullptr;
+ size_t original_idx_bytes = nnz*sizeof(IType);
+ size_t csc_indices_bytes = nnz*sizeof(IType);
+ size_t csc_cols_bytes = nnz*sizeof(IType);
+ size_t csr_rows_bytes = nnz*sizeof(CType);
+ size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType);
+ size_t csc_data_bytes = nnz*sizeof(DType);
+ size_t scan_temp_storage_bytes = 0;
+ size_t temp_storage_bytes = SortByKeyWorkspaceSize<IType, IType,
gpu>(nnz);
+ IType* csr_indices_ptr = col_idx_l.dptr<IType>();
+ cub::DeviceScan::ExclusiveSum(temp_storage_ptr,
+ scan_temp_storage_bytes,
+ csc_indptr_ptr,
+ csc_indptr_ptr,
+ num_csr_cols+1,
+ mshadow::Stream<gpu>::GetStream(s));
+ temp_storage_bytes = std::max(temp_storage_bytes,
scan_temp_storage_bytes);
+ temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes %
sizeof(dim_t));
+ size_t total_workspace_bytes =
+ original_idx_bytes + csc_indices_bytes + csc_cols_bytes +
csr_rows_bytes +
+ csc_indptr_bytes + csc_data_bytes + temp_storage_bytes;
+ total_workspace_bytes += (sizeof(IType) - total_workspace_bytes %
sizeof(IType));
+ Tensor<gpu, 1, char> workspace = ctx.requested[0]
+ .get_space_typed<gpu, 1, char>(Shape1(total_workspace_bytes), s);
+ original_idx_ptr = reinterpret_cast<IType*>(workspace.dptr_);
+ csc_indices_ptr = reinterpret_cast<IType*>(workspace.dptr_ +
original_idx_bytes);
+ csc_cols_ptr = reinterpret_cast<IType*>(workspace.dptr_ +
original_idx_bytes +
+ csc_indices_bytes);
+ csr_rows_ptr = reinterpret_cast<CType*>(workspace.dptr_ +
original_idx_bytes +
+ csc_indices_bytes +
csc_cols_bytes);
+ csc_indptr_ptr = reinterpret_cast<CType*>(workspace.dptr_ +
original_idx_bytes +
+ csc_indices_bytes +
csc_cols_bytes +
+ csr_rows_bytes);
+ temp_storage_ptr = workspace.dptr_ + original_idx_bytes +
csc_indices_bytes +
+ csc_cols_bytes + csr_rows_bytes +
csc_indptr_bytes;
+ csc_data_ptr = reinterpret_cast<DType*>(
+ workspace.dptr_ + total_workspace_bytes -
csc_data_bytes);
+
+ // Fill original_idx
+ mxnet_op::Kernel<range_fwd, gpu>::Launch(
+ s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr);
+ // Fill csc_cols with copy of csr_indices
+ mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity,
kWriteTo>, gpu>::Launch(
+ s, nnz, csc_cols_ptr, csr_indices_ptr);
+ // Allocate the tensors needed for SortByKey
+ Tensor<gpu, 1, IType> original_idx(original_idx_ptr, Shape1(nnz), s);
+ Tensor<gpu, 1, IType> csc_cols(csc_cols_ptr, Shape1(nnz), s);
+ Tensor<gpu, 1, char> temp_storage(temp_storage_ptr,
Shape1(temp_storage_bytes), s);
+
+ int num_bits = log2i(num_csr_cols - 1);
+ SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits);
Review comment:
Q: Based on the SortByKey function signature, should the last argument
(end_bit) be `num_bits - 1` instead of `num_bits`?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services