misc. changes and further abstraction of some cudnn code
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/75f9a0e3 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/75f9a0e3 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/75f9a0e3 Branch: refs/heads/master Commit: 75f9a0e39520fe86f6e774f5295d65830bd274ab Parents: 26101ee Author: Vaan Ng <[email protected]> Authored: Thu May 10 18:34:44 2018 +0800 Committer: Vaan Ng <[email protected]> Committed: Thu May 10 18:34:44 2018 +0800 ---------------------------------------------------------------------- include/singa/core/tensor.h | 21 +-- src/core/tensor/tensor.cc | 12 +- src/core/tensor/tensor_math_cpp.h | 31 ++-- src/core/tensor/tensor_math_cuda.h | 309 +++++++++++++------------------- 4 files changed, 152 insertions(+), 221 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/75f9a0e3/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 2c28e0f..b94a982 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -105,12 +105,13 @@ class Tensor { } /* - cudnn requires tensor dimensions to fulfill 2 requirements: - 1.) dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) - 2.) dimensions have to be set to multiples of 8 + cudnn requires tensor dimensions to fulfill 1 requirement: + 1.) Dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors + if input tensor is 5d, cudnn will take a 5d tensor as input. Beyond 5d, certain operations are not supported. + (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) - for e.g. Tensor A has shape {3,3}, cudnn requires shape of {1,1,24,24} to be the input - Tensor B has shape (2,3,4), cudnn requires shape of {1,16,24,32} to be the input + for e.g. Tensor A has shape {3,3}, cudnn requires shape of {1,1,3,3} to be the input + Tensor B has shape (2,3,4), cudnn requires shape of {1,2,3,4} to be the input */ vector<int> generate_shape_cuda() const { vector<int> shape_arr; @@ -151,11 +152,11 @@ class Tensor { /* cudnn requires stride dimensions to conform to the format of the shape input as well - 1.) stride dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) - 2.) stride dimensions have to be set to powers of 8, depending on the stride order (outer stride = higher power) + 1.) Stride dimensions to be set to a minimum of 4 for 4d and lower dimensional tensors + If input tensor is 5d, cudnn will take a 5d tensor as input. Beyond 5d, certain operations are not supported. + (cudnnOp supports up to 5d, cudnnReduce supports up to 8d) - for e.g. Tensor A has shape {3,3}, stride {3,1}, cudnn requires shape {1,1,24,24} and stride {576, 576, 24, 1} to be the inputs, - if A is transposed with stride {1,3}, then the new cudnn stride becomes {576, 576, 8, 3} + for e.g. 
Tensor A has shape {3,3}, stride {3,1}, cudnn requires shape {1,1,3,3} and stride {9, 9, 3, 1} or {9, 9, 1, 3} to be the inputs */ vector<int> generate_strides_cuda() const { vector<int> strides_arr; @@ -177,7 +178,7 @@ class Tensor { } return strides_arr; } else { - LOG(FATAL) << "Dimensions (strides) beyond 3 are currently not supported" ; + LOG(FATAL) << "Dimensions (strides) beyond 5 are currently not supported" ; } } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/75f9a0e3/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 48751ef..9067242 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -132,10 +132,8 @@ void Tensor::ResetLike(const Tensor &in) { shape_multipliers_ = in.shape_multipliers_; } -//yisen todo //if tensor is not transposed yet i.e strides == 1, then we simply change the shape and generate new default strides //if tensor is already transposed i.e strides != 1, it should be copied to a new tensor with newly generated default strides - void Tensor::Reshape(const Shape &shape) { if(strides_.size()==0) strides_.push_back(1); @@ -144,9 +142,8 @@ void Tensor::Reshape(const Shape &shape) { if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_); block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_))); - } else if (strides_[0] != 1) { - std::cout << "Reshape Error: Tranposed tensor must return new tensor. Not implemented yet." << std::endl; - return void(); + } else if (transpose()) { + LOG(FATAL) << "Reshape Error: Reshape called on transposed tensor. Not implemented yet." ; } shape_ = shape; Generate_Strides(); @@ -161,9 +158,8 @@ void Tensor::Reshape(Shape &&shape) { if (block_ != nullptr && block_->DecRefCount() == 0) device_->FreeBlock(block_); block_ = device_->NewBlock((int)(Product(shape) * SizeOf(data_type_))); - } else if (strides_[0] != 1) { - std::cout << "Reshape Error: Tranposed tensor must return new tensor. Not implemented yet." << std::endl; - return void(); + } else if (transpose()) { + LOG(FATAL) << "Reshape Error: Reshape called on transposed tensor. Not implemented yet." ; } shape_ = std::move(shape); Generate_Strides(); http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/75f9a0e3/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h index 01d9fe3..d4cd5da 100644 --- a/src/core/tensor/tensor_math_cpp.h +++ b/src/core/tensor/tensor_math_cpp.h @@ -724,7 +724,7 @@ void Uniform<float, lang::Cpp>(const float low, // ====================Blas operations====================================== -//yisen todo, this function has block M overwritting to block M itself +//warning, this function overwrites block M in place template <> void DGMM<float, lang::Cpp>(const bool side_right, const Tensor* M, const Tensor* v, @@ -817,26 +817,26 @@ template <> void Axpy<float, lang::Cpp>(const float alpha, const Tensor *in, Tensor *out, Context *ctx) { //check input tensor for strides first - if((in->strides())[0] == 1){ + if(in->strides() == out->strides()){ const float *inPtr = static_cast<const float *>(in->block()->data()); float *outPtr = static_cast<float *>(out->block()->mutable_data()); cblas_saxpy(in->Size(), alpha, inPtr, 1, outPtr, 1); + } else { + LOG(FATAL) << "Axpy, input and output strides do not match." 
; } - //yisen todo - //else throw error } template <> void Dot<float, lang::Cpp>(const Tensor *in1, const Tensor *in2, float *out, Context *ctx) { //check input tensor for strides first - if(((in1->strides())[0] == 1) && ((in2->strides())[0] == 1)){ + if(!(in1->transpose()) && !(in2->transpose())){ const float *in1Ptr = static_cast<const float *>(in1->block()->data()); const float *in2Ptr = static_cast<const float *>(in2->block()->data()); *out = cblas_sdot(in1->Size(), in1Ptr, 1, in2Ptr, 1); + } else { + LOG(FATAL) << "Dot, one of the inputs is transposed. Not implemented yet." ; } - //yisen todo - //else throw error } template <> @@ -878,15 +878,14 @@ void GEMV<float, lang::Cpp>(const float alpha, const Tensor *A, const Tensor *v, const float *APtr = static_cast<const float *>(A->block()->data()); const float *vPtr = static_cast<const float *>(v->block()->data()); float *outPtr = static_cast<float *>(out->block()->mutable_data()); - auto trans = ((A->strides())[0] != 1) ? true : false; const size_t m = A->shape()[0]; const size_t n = A->shape()[1]; - if (!trans) { - cblas_sgemv(CblasRowMajor, CblasNoTrans, m, n, alpha, APtr, n, vPtr, 1, - beta, outPtr, 1); - } else { + if (A->transpose()) { cblas_sgemv(CblasRowMajor, CblasTrans, n, m, alpha, APtr, m, vPtr, 1, beta, outPtr, 1); + } else { + cblas_sgemv(CblasRowMajor, CblasNoTrans, m, n, alpha, APtr, n, vPtr, 1, + beta, outPtr, 1); } } @@ -915,9 +914,9 @@ template <> void GEMM<float, lang::Cpp>(const float alpha, const Tensor *A, const Tensor *B, const float beta, Tensor *C, Context *ctx) { - auto transA = ((A->strides())[0] != 1) ? true : false; + auto transA = A->transpose(); auto transa = transA ? CblasTrans : CblasNoTrans; - auto transB = ((B->strides())[0] != 1) ? true : false; + auto transB = B->transpose(); auto transb = transB ? CblasTrans : CblasNoTrans; const size_t nrowA = A->shape()[0]; const size_t ncolA = A->shape()[1]; @@ -1088,7 +1087,6 @@ void Scale<float, lang::Cpp>(const float x, Tensor *out, } } -//yisen todo check purpose of sum in this function template <> void Dot<float, lang::Cpp>(const Tensor *in1, const Tensor *in2, float *out, Context *ctx) { @@ -1116,7 +1114,7 @@ void GEMV<float, lang::Cpp>(const float alpha, const Tensor *A, const Tensor *v, float *outPtr = static_cast<float *>(out->block()->mutable_data()); const float *APtr = static_cast<const float *>(A->block()->data()); const float *vPtr = static_cast<const float *>(v->block()->data()); - bool trans = ((A->strides())[0] != 1) ? 
true : false; + bool trans = A->transpose(); const size_t m = A->shape(0); const size_t n = A->shape(1); for (size_t r = 0; r < m; r++) { @@ -1129,7 +1127,6 @@ void GEMV<float, lang::Cpp>(const float alpha, const Tensor *A, const Tensor *v, } } -//yisen todo #endif // USE_CBLAS template <> void ComputeCrossEntropy<float, lang::Cpp>(bool int_target, http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/75f9a0e3/src/core/tensor/tensor_math_cuda.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h index f4839e3..3e36877 100644 --- a/src/core/tensor/tensor_math_cuda.h +++ b/src/core/tensor/tensor_math_cuda.h @@ -32,6 +32,30 @@ namespace singa { +cudnnTensorDescriptor_t generate_tensorND_desc(const Tensor* x){ + cudnnTensorDescriptor_t x_desc; + cudnnCreateTensorDescriptor(&x_desc); + cudnnSetTensorNdDescriptor(x_desc, CUDNN_DATA_FLOAT, + x->generate_dim_cuda(), + x->generate_shape_cuda().data(), + x->generate_strides_cuda().data() + ); + + return x_desc; +} + +cudnnOpTensorDescriptor_t generate_Op_desc(cudnnOpTensorOp_t op){ + cudnnOpTensorDescriptor_t op_desc; + cudnnCreateOpTensorDescriptor(&op_desc); + cudnnSetOpTensorDescriptor(op_desc, op, + CUDNN_DATA_FLOAT, + CUDNN_PROPAGATE_NAN + ); + + return op_desc; +} + + /// out[i] = |in[i]| template <> void Abs<float, lang::Cuda>(const Tensor* in, Tensor* out, @@ -39,41 +63,25 @@ void Abs<float, lang::Cuda>(const Tensor* in, Tensor* out, const float* inPtr = static_cast<const float*>(in->block()->data()); float* outPtr = static_cast<float*>(out->block()->mutable_data()); - cudnnOpTensorOp_t op = CUDNN_OP_TENSOR_MAX; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN; - cudnnOpTensorDescriptor_t op_desc; - cudnnCreateOpTensorDescriptor(&op_desc); - cudnnSetOpTensorDescriptor(op_desc, op, cudnn_dtype, cudnn_propagation); - - float alpha1[1] = {1.0}; - float alpha2[1] = {-1.0}; - float beta[1] = {0.0}; - cudnnTensorDescriptor_t in_desc, out_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnOpTensor(ctx->cudnn_handle, op_desc, (void*)(&alpha1), in_desc, inPtr, - (void*)(&alpha2), in_desc, inPtr, (void*)(&beta), out_desc, outPtr); - + float alpha1 = 1.0; + float alpha2 = -1.0; + float beta = 0.0; + cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_MAX), + (void*)(&alpha1), in_desc, inPtr, + (void*)(&alpha2), in_desc, inPtr, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(out_desc); } template <> void Set<float, lang::Cuda>(const float x, Tensor* out, Context* ctx) { float* outPtr = static_cast<float*>(out->block()->mutable_data()); - //float valuePtr[1] = {x}; - - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t out_desc; - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnSetTensor(ctx->cudnn_handle, out_desc, 
outPtr, (void*)(&x)); - cudnnDestroyTensorDescriptor(out_desc); + cudnnSetTensor(ctx->cudnn_handle, generate_tensorND_desc(out), + outPtr, (void*)(&x)); } template <> @@ -83,17 +91,11 @@ void Add<float, lang::Cuda>(const Tensor* in, const float x, const float* inPtr = static_cast<const float*>(in->block()->data()); float* outPtr = static_cast<float*>(out->block()->mutable_data()); - float alpha = 1.0, beta=1.0; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in_desc, out_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnAddTensor(ctx->cudnn_handle, (void*)(&alpha), in_desc, inPtr, (void*)(&beta), out_desc, outPtr); - - cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(out_desc); + float alpha = 1.0, beta = 1.0; + cudnnAddTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in), inPtr, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } /// out = in1 + in2 @@ -104,34 +106,23 @@ void Add<float, lang::Cuda>(const Tensor* in1, const float* inPtr2 = static_cast<const float*>(in2->block()->data()); float* outPtr = static_cast<float*>(out->block()->mutable_data()); - cudnnOpTensorOp_t op = CUDNN_OP_TENSOR_ADD; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN; - cudnnOpTensorDescriptor_t op_desc; - cudnnCreateOpTensorDescriptor(&op_desc); - cudnnSetOpTensorDescriptor(op_desc, op, cudnn_dtype, cudnn_propagation); - - float alpha1[1] = {1.0}; - float alpha2[1] = {1.0}; - float beta[1] = {0.0}; - cudnnTensorDescriptor_t in1_desc, in2_desc, out_desc; - cudnnCreateTensorDescriptor(&in1_desc); - cudnnCreateTensorDescriptor(&in2_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in1_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + float alpha1 = 1.0; + float alpha2 = 1.0; + float beta = 0.0; + if((in1->nDim() == in2->nDim()) || (in2->nDim() == 1)){ - cudnnSetTensorNdDescriptor(in2_desc, cudnn_dtype, in2->generate_dim_cuda(), in2->generate_shape_cuda().data(), in2->generate_strides_cuda().data()); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } else { - cudnnSetTensorNdDescriptor(in2_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } - - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnOpTensor(ctx->cudnn_handle, op_desc, (void*)(alpha1), in1_desc, inPtr1, - (void*)(alpha2), in2_desc, inPtr2, (void*)(beta), out_desc, outPtr); - - cudnnDestroyTensorDescriptor(in1_desc); - cudnnDestroyTensorDescriptor(in2_desc); - 
cudnnDestroyTensorDescriptor(out_desc); } /// out = in1 - in2 @@ -142,34 +133,23 @@ void Sub<float, lang::Cuda>(const Tensor* in1, const float* inPtr2 = static_cast<const float*>(in2->block()->data()); float* outPtr = static_cast<float*>(out->block()->mutable_data()); - cudnnOpTensorOp_t op = CUDNN_OP_TENSOR_ADD; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN; - cudnnOpTensorDescriptor_t op_desc; - cudnnCreateOpTensorDescriptor(&op_desc); - cudnnSetOpTensorDescriptor(op_desc, op, cudnn_dtype, cudnn_propagation); - - float alpha1[1] = {1.0}; - float alpha2[1] = {-1.0}; - float beta[1] = {0.0}; - cudnnTensorDescriptor_t in1_desc, in2_desc, out_desc; - cudnnCreateTensorDescriptor(&in1_desc); - cudnnCreateTensorDescriptor(&in2_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in1_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + float alpha1 = 1.0; + float alpha2 = -1.0; + float beta = 0.0; + if((in1->nDim() == in2->nDim()) || (in2->nDim() == 1)){ - cudnnSetTensorNdDescriptor(in2_desc, cudnn_dtype, in2->generate_dim_cuda(), in2->generate_shape_cuda().data(), in2->generate_strides_cuda().data()); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in2), inPtr2, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } else { - cudnnSetTensorNdDescriptor(in2_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_ADD), + (void*)(&alpha1), generate_tensorND_desc(in1), inPtr1, + (void*)(&alpha2), generate_tensorND_desc(in1), inPtr2, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } - - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnOpTensor(ctx->cudnn_handle, op_desc, (void*)(alpha1), in1_desc, inPtr1, - (void*)(alpha2), in2_desc, inPtr2, (void*)(beta), out_desc, outPtr); - - cudnnDestroyTensorDescriptor(in1_desc); - cudnnDestroyTensorDescriptor(in2_desc); - cudnnDestroyTensorDescriptor(out_desc); } /// Element-wise operation, clamp every element into [low, high] @@ -193,26 +173,21 @@ void Div<float, lang::Cuda>(const Tensor* in1, float* outPtr = static_cast<float*>(out->block()->mutable_data()); const size_t num = in1->Size(); - if(in1->strides() == in2->strides()){ //if both in1 and in2 strides are the same, we proceed to normal cuda::div + //if both in1 and in2 strides are the same, we proceed to normal cuda::div + if(in1->strides() == in2->strides()){ cuda::div(num, inPtr1, inPtr2, outPtr, ctx->stream); out->Set_Strides(in1->strides()); } else { //else we transform in1 to out to store first - float alpha[1] = {1.0}; - float beta[1] = {0.0}; - - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in1_desc, out_desc; - cudnnCreateTensorDescriptor(&in1_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in1_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + float alpha = 1.0; + float beta = 0.0; + out->Set_Strides(in2->strides()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), 
out->generate_strides_cuda().data()); - cudnnTransformTensor(ctx->cudnn_handle, (void*)(alpha), in1_desc, inPtr1, - (void*)(beta), out_desc, outPtr); + cudnnTransformTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in1), inPtr1, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); cuda::div(num, outPtr, inPtr2, outPtr, ctx->stream); - cudnnDestroyTensorDescriptor(in1_desc); - cudnnDestroyTensorDescriptor(out_desc); } } @@ -234,16 +209,10 @@ void EltwiseMult<float, lang::Cuda>(const Tensor* in, float* outPtr = static_cast<float*>(out->block()->mutable_data()); float alpha = x, beta = 0.0; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in_desc, out_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnAddTensor(ctx->cudnn_handle, (void*)(&alpha), in_desc, inPtr, (void*)(&beta), out_desc, outPtr); - - cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(out_desc); + cudnnAddTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in), inPtr, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } /// out = in1 * in2 @@ -256,27 +225,21 @@ void EltwiseMult<float, lang::Cuda>(const Tensor* in1, float* outPtr = static_cast<float*>(out->block()->mutable_data()); const size_t num = in1->Size(); - if(in1->strides() == in2->strides()){ //if both in1 and in2 strides are the same, we proceed to normal cuda::mult + //if both in1 and in2 strides are the same, we proceed to normal cuda::mult + if(in1->strides() == in2->strides()){ cuda::mult(num, inPtr1, inPtr2, outPtr, ctx->stream); out->Set_Strides(in1->strides()); } else { //else we transform in1 to out to store first - float alpha[1] = {1.0}; - float beta[1] = {0.0}; + float alpha = 1.0; + float beta = 0.0; - - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in1_desc, out_desc; - cudnnCreateTensorDescriptor(&in1_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in1_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); out->Set_Strides(in2->strides()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnTransformTensor(ctx->cudnn_handle, (void*)(alpha), in1_desc, inPtr1, - (void*)(beta), out_desc, outPtr); + cudnnTransformTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in1), inPtr1, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); cuda::mult(num, outPtr, inPtr2, outPtr, ctx->stream); - cudnnDestroyTensorDescriptor(in1_desc); - cudnnDestroyTensorDescriptor(out_desc); } } @@ -404,26 +367,20 @@ void Pow<float, lang::Cuda>(const Tensor* in1, float* outPtr = static_cast<float*>(out->block()->mutable_data()); const size_t num = in1->Size(); - if(in1->strides() == in2->strides()){ //if both in1 and in2 strides are the same, we proceed to normal cuda::pow + if(in1->strides() == in2->strides()){ cuda::pow(num, inPtr1, inPtr2, outPtr, ctx->stream); out->Set_Strides(in1->strides()); } else { //else we transform in1 to out to store first - float alpha[1] = {1.0}; - float beta[1] = {0.0}; - - 
cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in1_desc, out_desc; - cudnnCreateTensorDescriptor(&in1_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in1_desc, cudnn_dtype, in1->generate_dim_cuda(), in1->generate_shape_cuda().data(), in1->generate_strides_cuda().data()); + float alpha = 1.0; + float beta = 0.0; + out->Set_Strides(in2->strides()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnTransformTensor(ctx->cudnn_handle, (void*)(alpha), in1_desc, inPtr1, - (void*)(beta), out_desc, outPtr); + cudnnTransformTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in1), inPtr1, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); cuda::pow(num, outPtr, inPtr2, outPtr, ctx->stream); - cudnnDestroyTensorDescriptor(in1_desc); - cudnnDestroyTensorDescriptor(out_desc); } } @@ -525,27 +482,16 @@ void Sqrt<float, lang::Cuda>(const Tensor* in, Tensor* out, Context* ctx) { const float* inPtr = static_cast<const float*>(in->block()->data()); float* outPtr = static_cast<float*>(out->block()->mutable_data()); - - cudnnOpTensorOp_t op = CUDNN_OP_TENSOR_SQRT; - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnNanPropagation_t cudnn_propagation = CUDNN_PROPAGATE_NAN; - cudnnOpTensorDescriptor_t op_desc; - cudnnCreateOpTensorDescriptor(&op_desc); - cudnnSetOpTensorDescriptor(op_desc, op, cudnn_dtype, cudnn_propagation); - float alpha1[1] = {1.0}; - float alpha2[1] = {0.0}; - float beta[1] = {0.0}; - cudnnTensorDescriptor_t in_desc, out_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&out_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(out_desc, cudnn_dtype, out->generate_dim_cuda(), out->generate_shape_cuda().data(), out->generate_strides_cuda().data()); - cudnnOpTensor(ctx->cudnn_handle, op_desc, (void*)(&alpha1), in_desc, inPtr, - (void*)(&alpha2), in_desc, inPtr, (void*)(&beta), out_desc, outPtr); - - cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(out_desc); + float alpha1 = 1.0; + float alpha2 = 0.0; + float beta = 0.0; + cudnnTensorDescriptor_t in_desc = generate_tensorND_desc(in); + cudnnOpTensor(ctx->cudnn_handle, generate_Op_desc(CUDNN_OP_TENSOR_SQRT), + (void*)(&alpha1), in_desc, inPtr, + (void*)(&alpha2), in_desc, inPtr, + (void*)(&beta), generate_tensorND_desc(out), outPtr + ); } /// Element-wise operation, out[i]=in[i]^2 @@ -593,30 +539,26 @@ void Sum<float, lang::Cuda>(const Tensor* in, float* out, cudnn_propagation, cudnn_indices, cudnn_indices_type); //instantiate 2 new tensors to use new blocks as memory instead of cudaMalloc - Shape reduction_size = {1000}; + size_t reduction_size_int = Product(in->shape()); + Shape reduction_size = {reduction_size_int*100}; Tensor indices(reduction_size, in->device(), in->data_type()); Tensor workspace(reduction_size, in->device(), in->data_type()); - size_t indices_bytes = indices.block()->size()*1000; - size_t workspace_bytes = workspace.block()->size()*1000; + size_t indices_bytes = indices.block()->size()*100; + size_t workspace_bytes = workspace.block()->size()*100; size_t* indicesPtr = static_cast<size_t*>(indices.block()->mutable_data()); float* workspacePtr = static_cast<float*>(workspace.block()->mutable_data()); //void* indicesPtr{nullptr}; void* workspacePtr{nullptr}; 
//cudaMalloc(&indicesPtr, indices_bytes); cudaMalloc(&workspacePtr, workspace_bytes); - float alpha[1] = {1.0}; - float beta[1] = {0.0}; - cudnnTensorDescriptor_t in_desc, t_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&t_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(t_desc, cudnn_dtype, t.generate_dim_cuda(), reduce_all_axes.data(), reduce_all_axes.data()); + float alpha = 1.0; + float beta = 0.0; cudnnReduceTensor(ctx->cudnn_handle, reduce_desc, indicesPtr, indices_bytes, workspacePtr, workspace_bytes, - (void*)(&alpha), in_desc, inPtr, (void*)(&beta), t_desc, tPtr); + (void*)(&alpha), generate_tensorND_desc(in), inPtr, + (void*)(&beta), generate_tensorND_desc(&t), tPtr + ); *out = tPtr[0]; - cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(t_desc); } @@ -922,22 +864,17 @@ void RowMax<float, lang::Cuda>(const Tensor* in, Tensor* out, if(in->transpose()){ Tensor t(in->shape(), in->device(), in->data_type()); float* tPtr = static_cast<float*>(t.block()->mutable_data()); - float alpha[1] = {1.0}; - float beta[1] = {0.0}; - - cudnnDataType_t cudnn_dtype = CUDNN_DATA_FLOAT; - cudnnTensorDescriptor_t in_desc, t_desc; - cudnnCreateTensorDescriptor(&in_desc); - cudnnCreateTensorDescriptor(&t_desc); - cudnnSetTensorNdDescriptor(in_desc, cudnn_dtype, in->generate_dim_cuda(), in->generate_shape_cuda().data(), in->generate_strides_cuda().data()); - cudnnSetTensorNdDescriptor(t_desc, cudnn_dtype, t.generate_dim_cuda(), t.generate_shape_cuda().data(), t.generate_strides_cuda().data()); - cudnnTransformTensor(ctx->cudnn_handle, (void*)(alpha), in_desc, inPtr, - (void*)(beta), t_desc, tPtr); + + float alpha = 1.0; + float beta = 0.0; + + cudnnTransformTensor(ctx->cudnn_handle, + (void*)(&alpha), generate_tensorND_desc(in), inPtr, + (void*)(&beta), generate_tensorND_desc(&t), tPtr + ); const float* tPtr_const = static_cast<const float*>(t.block()->data()); cuda::RowMax(nrow, ncol, tPtr_const, outPtr, ctx->stream); - cudnnDestroyTensorDescriptor(in_desc); - cudnnDestroyTensorDescriptor(t_desc); } else { cuda::RowMax(nrow, ncol, inPtr, outPtr, ctx->stream); }
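----------------------------------------------------------------------
For reference, the 4-d padding rule described in the tensor.h comments above (shape {3,3} -> {1,1,3,3}, strides {3,1} -> {9,9,3,1}) can be sketched as two small standalone helpers. The names pad_shape_for_cudnn/pad_strides_for_cudnn are hypothetical and are not the Tensor member functions in this patch:

#include <cstdio>
#include <vector>

// Prepend 1s until the shape has at least the 4 dimensions cuDNN expects.
std::vector<int> pad_shape_for_cudnn(const std::vector<int>& shape) {
  std::vector<int> padded(shape.begin(), shape.end());
  while (padded.size() < 4) padded.insert(padded.begin(), 1);
  return padded;
}

// Prepend the total element count as the stride of each padded leading axis,
// matching the example {3,3}/{3,1} -> {1,1,3,3}/{9,9,3,1} above.
std::vector<int> pad_strides_for_cudnn(const std::vector<int>& shape,
                                       const std::vector<int>& strides) {
  int total = 1;
  for (int d : shape) total *= d;
  std::vector<int> padded(strides.begin(), strides.end());
  while (padded.size() < 4) padded.insert(padded.begin(), total);
  return padded;
}

int main() {
  const std::vector<int> shape = {3, 3}, strides = {3, 1};
  for (int v : pad_shape_for_cudnn(shape)) std::printf("%d ", v);            // 1 1 3 3
  std::printf("\n");
  for (int v : pad_strides_for_cudnn(shape, strides)) std::printf("%d ", v); // 9 9 3 1
  std::printf("\n");
  return 0;
}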
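A second sketch, also not part of the patch, shows end-to-end what the new generate_tensorND_desc()/generate_Op_desc() helpers wrap: building an Nd descriptor for a padded {1,1,3,3} tensor and running cudnnOpTensor with it. The free function make_nd_desc and the main() scaffolding are hypothetical; error checking is omitted for brevity:

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Hypothetical analogue of generate_tensorND_desc(): wrap a dims/strides pair
// in a cudnnTensorDescriptor_t. The caller owns (and must destroy) the result.
static cudnnTensorDescriptor_t make_nd_desc(const std::vector<int>& dims,
                                            const std::vector<int>& strides) {
  cudnnTensorDescriptor_t d;
  cudnnCreateTensorDescriptor(&d);
  cudnnSetTensorNdDescriptor(d, CUDNN_DATA_FLOAT, (int)dims.size(),
                             dims.data(), strides.data());
  return d;
}

int main() {
  cudnnHandle_t handle;
  cudnnCreate(&handle);

  const int n = 9;  // a 3x3 tensor, padded below to {1,1,3,3}/{9,9,3,1}
  std::vector<float> ha(n, 2.f), hb(n, 3.f), hc(n, 0.f);
  float *a, *b, *c;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));
  cudaMalloc(&c, n * sizeof(float));
  cudaMemcpy(a, ha.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(b, hb.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  cudnnTensorDescriptor_t desc = make_nd_desc({1, 1, 3, 3}, {9, 9, 3, 1});

  // Analogue of generate_Op_desc(CUDNN_OP_TENSOR_ADD).
  cudnnOpTensorDescriptor_t op;
  cudnnCreateOpTensorDescriptor(&op);
  cudnnSetOpTensorDescriptor(op, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT,
                             CUDNN_PROPAGATE_NAN);

  // c = 1*a + 1*b + 0*c, element-wise.
  float alpha1 = 1.f, alpha2 = 1.f, beta = 0.f;
  cudnnOpTensor(handle, op, &alpha1, desc, a, &alpha2, desc, b, &beta, desc, c);

  cudaMemcpy(hc.data(), c, n * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("c[0] = %f\n", hc[0]);  // expect 5.0

  cudnnDestroyOpTensorDescriptor(op);
  cudnnDestroyTensorDescriptor(desc);
  cudaFree(a); cudaFree(b); cudaFree(c);
  cudnnDestroy(handle);
  return 0;
}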
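The Sum<float, lang::Cuda> hunk above sizes its indices and workspace buffers heuristically (Product(in->shape())*100 elements, block size *100 bytes). cuDNN can also report the exact byte counts; a sketch, assuming a live handle and the input/output descriptors produced by generate_tensorND_desc(), and not something this commit does:

#include <cudnn.h>

// Query the exact indices/workspace sizes a full sum-reduction needs.
// in_desc describes the input tensor, out_desc the (single-element) result.
void reduction_buffer_sizes(cudnnHandle_t handle,
                            cudnnTensorDescriptor_t in_desc,
                            cudnnTensorDescriptor_t out_desc,
                            size_t* indices_bytes,
                            size_t* workspace_bytes) {
  cudnnReduceTensorDescriptor_t reduce_desc;
  cudnnCreateReduceTensorDescriptor(&reduce_desc);
  cudnnSetReduceTensorDescriptor(reduce_desc, CUDNN_REDUCE_TENSOR_ADD,
                                 CUDNN_DATA_FLOAT, CUDNN_PROPAGATE_NAN,
                                 CUDNN_REDUCE_TENSOR_NO_INDICES,
                                 CUDNN_32BIT_INDICES);
  cudnnGetReductionIndicesSize(handle, reduce_desc, in_desc, out_desc,
                               indices_bytes);
  cudnnGetReductionWorkspaceSize(handle, reduce_desc, in_desc, out_desc,
                                 workspace_bytes);
  cudnnDestroyReduceTensorDescriptor(reduce_desc);
}

The buffers can then be allocated at exactly those sizes before calling cudnnReduceTensor() as in the hunk above.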
