[SYSTEMML-1969] Support single-precision operations on GPU backend

- Since single-precision operations are faster on most GPUs, we should allow our users to execute GPU instructions in single precision.
- The GPU backend has been refactored to support arbitrary precision.
- This feature can be enabled via the configuration property sysml.floating.point.precision.
- The valid values for this property are double and float. Half/mixed precision can be supported in a separate PR.
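As an illustration, a minimal sketch of how a user would enable the feature, assuming the property name introduced in the SystemML-config.xml.template change below (note the template comment lists the supported values as "double, single" while this message says "double and float", so the exact accepted token should be verified against DMLConfig.java):

    <!-- run GPU instructions in single precision; the default of double preserves the previous behavior -->
    <sysml.floating.point.precision>single</sysml.floating.point.precision>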
Closes #688. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/abbffc55 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/abbffc55 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/abbffc55 Branch: refs/heads/master Commit: abbffc55ef8f47f10b6e59b0ae5e1f311f4a8f3e Parents: 881caa9 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Wed Oct 25 19:25:20 2017 -0700 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Wed Oct 25 19:26:50 2017 -0700 ---------------------------------------------------------------------- conf/SystemML-config.xml.template | 3 + src/main/cpp/kernels/SystemML.cu | 1874 ++-- src/main/cpp/kernels/SystemML.ptx | 9579 ++++++++++++++---- .../java/org/apache/sysml/api/DMLScript.java | 1 + .../apache/sysml/api/ScriptExecutorUtils.java | 4 + .../java/org/apache/sysml/conf/DMLConfig.java | 4 +- .../controlprogram/caching/CacheableData.java | 4 +- .../instructions/gpu/context/CSRPointer.java | 52 +- .../instructions/gpu/context/GPUContext.java | 31 +- .../instructions/gpu/context/GPUObject.java | 91 +- .../instructions/gpu/context/JCudaKernels.java | 9 +- .../matrix/data/CudaSupportFunctions.java | 87 + .../DoublePrecisionCudaSupportFunctions.java | 175 + .../runtime/matrix/data/LibMatrixCUDA.java | 144 +- .../runtime/matrix/data/LibMatrixCuDNN.java | 38 +- .../LibMatrixCuDNNConvolutionAlgorithm.java | 5 +- .../data/LibMatrixCuDNNInputRowFetcher.java | 8 +- .../data/LibMatrixCuDNNPoolingDescriptors.java | 3 +- .../runtime/matrix/data/LibMatrixCuMatMult.java | 34 +- .../sysml/runtime/matrix/data/MatrixBlock.java | 5 +- .../SinglePrecisionCudaSupportFunctions.java | 208 + .../org/apache/sysml/test/gpu/GPUTests.java | 20 +- .../test/gpu/MatrixMultiplicationOpTest.java | 22 +- 23 files changed, 9423 insertions(+), 2978 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/conf/SystemML-config.xml.template ---------------------------------------------------------------------- diff --git a/conf/SystemML-config.xml.template b/conf/SystemML-config.xml.template index 511e215..8452e75 100644 --- a/conf/SystemML-config.xml.template +++ b/conf/SystemML-config.xml.template @@ -93,6 +93,9 @@ <!-- whether to perform eager CUDA free on rmvar instruction --> <sysml.gpu.eager.cudaFree>false</sysml.gpu.eager.cudaFree> + <!-- the floating point precision. 
supported values are double, single --> + <sysml.floating.point.precision>double</sysml.floating.point.precision> + <!-- maximum wrap length for instruction and miscellaneous timer column of statistics --> <sysml.stats.maxWrapLength>30</sysml.stats.maxWrapLength> </root> http://git-wip-us.apache.org/repos/asf/systemml/blob/abbffc55/src/main/cpp/kernels/SystemML.cu ---------------------------------------------------------------------- diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu index c243564..d176f8f 100644 --- a/src/main/cpp/kernels/SystemML.cu +++ b/src/main/cpp/kernels/SystemML.cu @@ -26,11 +26,28 @@ nvcc -ptx -arch=sm_30 SystemML.cu #include <cfloat> #include <cmath> +extern "C" __global__ void double2float_f(double *A, float *ret, int N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) { + // TODO: Use __double2float_rd or __double2float_rn or __double2float_ru or __double2float_rz after + ret[tid] = (float)A[tid]; + } +} + +extern "C" __global__ void float2double_f(float *A, double *ret, int N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < N) { + ret[tid] = (double)A[tid]; + } +} + /** - * Performs a slice operation where the input matrix is sparse and the output matrix is dense. - * This function avoids unnecessary sparse to dense conversion of the input matrix. + * Performs a slice operation where the input matrix is sparse and the output + * matrix is dense. + * This function avoids unnecessary sparse to dense conversion of the input + * matrix. * Parallelization: rows of output matrix. - * + * * @params inVal input val pointer * @params inRowPtr input row pointer * @params colInd input col index pointer @@ -41,49 +58,73 @@ nvcc -ptx -arch=sm_30 SystemML.cu * @param cu column upper * @param retClen number of columns of output matrix */ -extern "C" -__global__ void slice_sparse_dense_row(double* inVal, int* inRowPtr, int* colInd, double* ret, - int rl, int ru, int cl, int cu, int retClen) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int rowIndex = index + rl; - if (rowIndex <= ru){ - /* - * TODO: Alternative approach: use dynamic parallelism. We are skipping this for now to avoid - * the complexity of two-step separate compilation and linking process. - * - * extern "C" - * __global__ void slice_sparse_dense_row_helper(double* inVal, int* inRowPtr, int* colInd, double* ret, - * int rl, int ru, int cl, int cu, int retClen, int start, int end, int index) { - * int i = blockIdx.x * blockDim.x + threadIdx.x + start; - * // Only slice if the index falls into the given range - * if(i < end && cl <= colInd[i] && colInd[i] <= cu) { - * ret[ index*retClen + (colInd[i] - cl) ] = inVal[i]; - * } - * } - * - * int size = inRowPtr[rowIndex+1] - inRowPtr[rowIndex]; - * double numThreads = (double)min(size, MAX_NUM_THREADS_CHILD_KERNEL); - * slice_sparse_dense_row_helper<<< ceil(numThreads/ MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>(inVal, inRowPtr, colInd, ret, - * rl, ru, cl, cu, retClen, inRowPtr[rowIndex], inRowPtr[rowIndex+1], index); - * - * Two-step compilation and linking process in JCudaKernels's constructor: - * cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY, "/usr/local/cuda/lib64/libcudadevrt.a", jitOptions); - */ - // Iterate over elements of the row 'rowIndex'. 
- for(int i = inRowPtr[rowIndex]; i < inRowPtr[rowIndex+1]; i++) { - // Only slice if the index falls into the given range - if(cl <= colInd[i] && colInd[i] <= cu) { - ret[ index*retClen + (colInd[i] - cl) ] = inVal[i]; - } - } +template <typename T> +__device__ void slice_sparse_dense_row(T *inVal, int *inRowPtr, int *colInd, + T *ret, int rl, int ru, int cl, int cu, + int retClen) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int rowIndex = index + rl; + if (rowIndex <= ru) { + /* + * TODO: Alternative approach: use dynamic parallelism. We are skipping this +*for now to avoid + * the complexity of two-step separate compilation and linking process. + * + * extern "C" + * __global__ void slice_sparse_dense_row_helper(double* inVal, int* +*inRowPtr, int* colInd, double* ret, + * int rl, int ru, int cl, int cu, int retClen, int start, int end, int +*index) { + * int i = blockIdx.x * blockDim.x + threadIdx.x + start; + * // Only slice if the index falls into the given range + * if(i < end && cl <= colInd[i] && colInd[i] <= cu) { + * ret[ index*retClen + (colInd[i] - cl) ] = inVal[i]; + * } + * } + * + * int size = inRowPtr[rowIndex+1] - inRowPtr[rowIndex]; + * double numThreads = (double)min(size, MAX_NUM_THREADS_CHILD_KERNEL); + * slice_sparse_dense_row_helper<<< ceil(numThreads/ +*MAX_NUM_THREADS_CHILD_KERNEL), MAX_NUM_THREADS_CHILD_KERNEL>>>(inVal, inRowPtr, +*colInd, ret, +* rl, ru, cl, cu, retClen, inRowPtr[rowIndex], +*inRowPtr[rowIndex+1], index); +* +* Two-step compilation and linking process in JCudaKernels's constructor: +* cuLinkAddFile(linkState, CUjitInputType.CU_JIT_INPUT_LIBRARY, +*"/usr/local/cuda/lib64/libcudadevrt.a", jitOptions); + */ + // Iterate over elements of the row 'rowIndex'. + for (int i = inRowPtr[rowIndex]; i < inRowPtr[rowIndex + 1]; i++) { + // Only slice if the index falls into the given range + if (cl <= colInd[i] && colInd[i] <= cu) { + ret[index * retClen + (colInd[i] - cl)] = inVal[i]; + } } + } +} + +extern "C" __global__ void slice_sparse_dense_row_d(double *inVal, int *inRowPtr, + int *colInd, double *ret, + int rl, int ru, int cl, + int cu, int retClen) { + slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen); +} + +extern "C" __global__ void slice_sparse_dense_row_f(float *inVal, int *inRowPtr, + int *colInd, float *ret, + int rl, int ru, int cl, + int cu, int retClen) { + slice_sparse_dense_row(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen); } /** - * Performs a slice operation where the input matrix is sparse and the output matrix is dense. - * This function avoids unnecessary sparse to dense conversion of the input matrix. + * Performs a slice operation where the input matrix is sparse and the output + * matrix is dense. + * This function avoids unnecessary sparse to dense conversion of the input + * matrix. * Parallelization: subset of number of non-zeroes of input matrix. 
- * + * * @params inVal input val pointer * @params inRowPtr input row pointer * @params colInd input col index pointer @@ -94,26 +135,42 @@ __global__ void slice_sparse_dense_row(double* inVal, int* inRowPtr, int* colInd * @param cu column upper * @param retClen number of columns of output matrix */ -extern "C" -__global__ void slice_sparse_dense_nnz(double* inVal, int* inRowPtr, int* colInd, double* ret, - int rl, int ru, int cl, int cu, int retClen) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int i = tid + inRowPtr[rl]; - - // Only slice if the index falls into the given range - if(i < inRowPtr[ru+1] && cl <= colInd[i] && colInd[i] <= cu) { - // Find the row index for corresponding non-zero value 'i'. - int rowIndex = rl; - while(inRowPtr[rowIndex+1] <= i) { - rowIndex++; - } - ret[ (rowIndex-rl)*retClen + (colInd[i] - cl) ] = inVal[i]; +template <typename T> +__device__ void slice_sparse_dense_nnz(T *inVal, int *inRowPtr, int *colInd, + T *ret, int rl, int ru, int cl, int cu, + int retClen) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int i = tid + inRowPtr[rl]; + + // Only slice if the index falls into the given range + if (i < inRowPtr[ru + 1] && cl <= colInd[i] && colInd[i] <= cu) { + // Find the row index for corresponding non-zero value 'i'. + int rowIndex = rl; + while (inRowPtr[rowIndex + 1] <= i) { + rowIndex++; } + ret[(rowIndex - rl) * retClen + (colInd[i] - cl)] = inVal[i]; + } +} + +extern "C" __global__ void slice_sparse_dense_nnz_d(double *inVal, int *inRowPtr, + int *colInd, double *ret, + int rl, int ru, int cl, + int cu, int retClen) { + slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen); +} + +extern "C" __global__ void slice_sparse_dense_nnz_f(float *inVal, int *inRowPtr, + int *colInd, float *ret, + int rl, int ru, int cl, + int cu, int retClen) { + slice_sparse_dense_nnz(inVal, inRowPtr, colInd, ret, rl, ru, cl, cu, retClen); } /** - * Performs a slice operation where the input matrix is dense and the output matrix is dense. - * + * Performs a slice operation where the input matrix is dense and the output + * matrix is dense. 
+ * * @params in dense input pointer * @params ret dense output pointer * @param rl row lower @@ -124,17 +181,31 @@ __global__ void slice_sparse_dense_nnz(double* inVal, int* inRowPtr, int* colInd * @param retRlen number of rows of output matrix * @param retClen number of columns of output matrix */ -extern "C" -__global__ void slice_dense_dense(double* in, double* ret, int rl, int ru, int cl, int cu, int inClen, int retRlen, int retClen) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / retClen; - int iy = tid % retClen; - if(ix < retRlen && iy < retClen) { - int inIndex = (ix + rl)*inClen + cl + iy; - ret[tid] = in[inIndex]; - } +template <typename T> +__device__ void slice_dense_dense(T *in, T *ret, int rl, int ru, int cl, int cu, + int inClen, int retRlen, int retClen) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / retClen; + int iy = tid % retClen; + if (ix < retRlen && iy < retClen) { + int inIndex = (ix + rl) * inClen + cl + iy; + ret[tid] = in[inIndex]; + } +} + +extern "C" __global__ void slice_dense_dense_d(double *in, double *ret, int rl, + int ru, int cl, int cu, + int inClen, int retRlen, + int retClen) { + slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen); } +extern "C" __global__ void slice_dense_dense_f(float *in, float *ret, int rl, + int ru, int cl, int cu, + int inClen, int retRlen, + int retClen) { + slice_dense_dense(in, ret, rl, ru, cl, cu, inClen, retRlen, retClen); +} /** * Does a copy of upper to lower triangle of the given matrix @@ -142,95 +213,161 @@ __global__ void slice_dense_dense(double* in, double* ret, int rl, int ru, int c * @param dim the number of rows of the square matrix ret * @param N total number of elements of the matrix */ -extern "C" -__global__ void copy_u2l_dense(double* ret, int dim, int N) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / dim; - int iy = tid % dim; - int id_dest = iy * dim + ix; - if(iy > ix && id_dest < N) { - // TODO: Potential to reduce the number of threads by half - int id_src = tid; - ret[id_dest] = ret[id_src]; - } -} - -extern "C" -__forceinline__ __device__ double getBoolean(int val) { - if(val == 0) - return 0.0; - else - return 1.0; +template <typename T> +__device__ void copy_u2l_dense(T *ret, int dim, int N) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / dim; + int iy = tid % dim; + int id_dest = iy * dim + ix; + if (iy > ix && id_dest < N) { + // TODO: Potential to reduce the number of threads by half + int id_src = tid; + ret[id_dest] = ret[id_src]; + } +} + +extern "C" __global__ void copy_u2l_dense_d(double *ret, int dim, int N) { + copy_u2l_dense(ret, dim, N); +} + +extern "C" __global__ void copy_u2l_dense_f(float *ret, int dim, int N) { + copy_u2l_dense(ret, dim, N); +} + +// Use this method in templates to fetch the maximum value for a given datatype +template <typename T> +__forceinline__ __device__ T T_MAX(T x) { + return (T)DBL_MAX; +} +template <> +__forceinline__ __device__ float T_MAX(float x) { + return FLT_MAX; +} +template <> +__forceinline__ __device__ double T_MAX(double x) { + return DBL_MAX; } // op = {0=plus, 1=minus, 2=multiply, 3=divide, 4=power, // 5=less, 6=lessequal, 7=greater, 8=greaterequal, 9=equal, 10=notequal, // 11=min, 12=max, 13=and, 14=or, 15=minus1multiply, 16=minusnz, // 17=modulus, 18=integer division} -extern "C" -__forceinline__ __device__ double binaryOp(double x, double y, int op) { - switch(op) { - case 0 : return x + y; - case 1 : return x - y; - case 2 : return x * 
y; - case 3 : return x / y; - case 4 : return pow(x, y); - case 5 : return getBoolean(x < y); - case 6 : return getBoolean(x <= y); - case 7 : return getBoolean(x > y); - case 8 : return getBoolean(x >= y); - case 9 : return getBoolean(x == y); - case 10 : return getBoolean(x != y); - case 11 : return min(x, y); - case 12 : return max(x, y); - case 13 : return getBoolean((int)llrint(x) & (int)llrint(y)); - case 14 : return getBoolean((int)llrint(x) | (int)llrint(y)); - case 15 : return 1 - x * y; - case 16 : return (x != 0.0 ? x - y : 0.0); - case 17 : { - if (y == 0.0 || y == -0.0){ - return nan(""); - } - double v = x / y; - // Check for v being NaN (v != v) or if it is infinity - if (isnan(v) || isinf(v)){ - return v; - } else { - v = floor(v); - } - return x - v * y; - } - case 18:{ - double v = x / y; - if (isnan(v) || isinf(v)){ - return v; - } else { - return floor(v); - } - } - default : return DBL_MAX; +template <typename T> +__forceinline__ __device__ T binaryOp(T x, T y, int op) { + switch (op) { + case 0: + return x + y; + case 1: + return x - y; + case 2: + return x * y; + case 3: + return x / y; + case 4: + return pow(x, y); + case 5: + return (x < y) == 0 ? 0.0 : 1.0; + case 6: + return (x <= y) == 0 ? 0.0 : 1.0; + case 7: + return (x > y) == 0 ? 0.0 : 1.0; + case 8: + return (x >= y) == 0 ? 0.0 : 1.0; + case 9: + return (x == y) == 0 ? 0.0 : 1.0; + case 10: + return (x != y) == 0 ? 0.0 : 1.0; + case 11: + return min(x, y); + case 12: + return max(x, y); + case 13: + return ((int)llrint(x) & (int)llrint(y)) == 0 ? 0.0 : 1.0; + case 14: + return ((int)llrint(x) | (int)llrint(y)) == 0 ? 0.0 : 1.0; + case 15: + return 1 - x * y; + case 16: + return (x != 0.0 ? x - y : 0.0); + case 17: { + if (y == 0.0 || y == -0.0) { + return nan(""); + } + T v = x / y; + // Check for v being NaN (v != v) or if it is infinity + if (isnan(v) || isinf(v)) { + return v; + } else { + v = floor(v); + } + return x - v * y; } + case 18: { + T v = x / y; + if (isnan(v) || isinf(v)) { + return v; + } else { + return floor(v); + } + } + default: + return T_MAX(x); + } } -extern "C" -__global__ void relu(double* A, double* ret, int rlen, int clen) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - if(ix < rlen && iy < clen) { - ret[tid] = max(0.0, A[tid]); - } +/** + * Performs forward pass for relu: ret = max(A, 0) + * + * @param A input array allocated on the GPU + * @param ret output array allocated on the GPU + * @param rlen the number of rows + * @param clen the number of columns + */ +template <typename T> +__device__ void relu(T *A, T *ret, int rlen, int clen) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clen; + int iy = tid % clen; + if (ix < rlen && iy < clen) { + ret[tid] = max(0.0, A[tid]); + } +} + +extern "C" __global__ void relu_d(double *A, double *ret, int rlen, int clen) { + relu(A, ret, rlen, clen); } -// This method computes the backpropagation errors for previous layer of relu operation -extern "C" -__global__ void relu_backward(double* X, double* dout, double* ret, int rlen, int clen) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - if(ix < rlen && iy < clen) { - ret[tid] = X[tid] > 0 ? 
dout[tid] : 0; - } +extern "C" __global__ void relu_f(float *A, float *ret, int rlen, int clen) { + relu(A, ret, rlen, clen); +} + +/** + * This method computes the backpropagation errors for previous layer of relu operation + * + * @param X input activation array allocated on the GPU + * @param dout errors from previous layer + * @param ret output array allocated on the GPU + * @param rlen the number of rows + * @param clen the number of columns + */ +template <typename T> +__device__ void relu_backward(T *X, T *dout, T *ret, int rlen, int clen) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clen; + int iy = tid % clen; + if (ix < rlen && iy < clen) { + ret[tid] = X[tid] > 0 ? dout[tid] : 0; + } +} + +extern "C" __global__ void relu_backward_d(double *X, double *dout, double *ret, + int rlen, int clen) { + relu_backward(X, dout, ret, rlen, clen); +} + +extern "C" __global__ void relu_backward_f(float *X, float *dout, float *ret, + int rlen, int clen) { + relu_backward(X, dout, ret, rlen, clen); } /** @@ -241,81 +378,113 @@ __global__ void relu_backward(double* X, double* dout, double* ret, int rlen, i * @param rlen the number of rows * @param clen the number of columns */ -extern "C" -__global__ void inplace_add(double* input, double* ret, int rlen, int clen) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - if(ix < rlen && iy < clen) { - ret[tid] += input[tid]; - } +template <typename T> +__device__ void inplace_add(T *input, T *ret, int rlen, int clen) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clen; + int iy = tid % clen; + if (ix < rlen && iy < clen) { + ret[tid] += input[tid]; + } +} + +extern "C" __global__ void inplace_add_d(double *input, double *ret, int rlen, + int clen) { + inplace_add(input, ret, rlen, clen); +} + +extern "C" __global__ void inplace_add_f(float *input, float *ret, int rlen, + int clen) { + inplace_add(input, ret, rlen, clen); } // Performs the operation corresponding to the DML script: // ones = matrix(1, rows=1, cols=Hout*Wout) // output = input + matrix(bias %*% ones, rows=1, cols=F*Hout*Wout) -// This operation is often followed by conv2d and hence we have introduced bias_add(input, bias) built-in function -extern "C" -__global__ void bias_add(double* input, double* bias, double* ret, int rlen, int clen, int PQ) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - if(ix < rlen && iy < clen) { - int biasIndex = iy / PQ; - ret[tid] = input[tid] + bias[biasIndex]; - } +// This operation is often followed by conv2d and hence we have introduced +// bias_add(input, bias) built-in function +template <typename T> +__device__ void bias_add(T *input, T *bias, T *ret, int rlen, int clen, + int PQ) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clen; + int iy = tid % clen; + if (ix < rlen && iy < clen) { + int biasIndex = iy / PQ; + ret[tid] = input[tid] + bias[biasIndex]; + } +} + +extern "C" __global__ void bias_add_d(double *input, double *bias, double *ret, + int rlen, int clen, int PQ) { + bias_add(input, bias, ret, rlen, clen, PQ); +} + +extern "C" __global__ void bias_add_f(float *input, float *bias, float *ret, + int rlen, int clen, int PQ) { + bias_add(input, bias, ret, rlen, clen, PQ); } // Performs the operation "ret <- A + alpha*B", where B is a vector -extern "C" -__global__ void daxpy_matrix_vector(double* A, double* B, double alpha, double* ret, int rlenA, int clenA, int rlenB, int 
clenB) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clenA; - int iy = tid % clenA; - if(ix < rlenA && iy < clenA) { - int index = ix * clenA + iy; - if(rlenB == 1) { - ret[index] = A[index] + alpha*B[iy]; - } - else { - ret[index] = A[index] + alpha*B[ix]; - } - } -} - -// Performs similar operation as bias_add except elementwise multiplication instead of add -extern "C" -__global__ void bias_multiply(double* input, double* bias, double* ret, int rlen, int clen, int PQ) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - if(ix < rlen && iy < clen) { - int biasIndex = iy / PQ; - ret[tid] = input[tid] * bias[biasIndex]; - } -} - -// Compares the value and set -extern "C" -__global__ void compare_and_set(double* A, double* ret, int rlen, int clen, double compareVal, double tol, double ifEqualsVal, double ifLessThanVal, double ifGreaterThanVal) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / clen; - int iy = tid % clen; - int index = ix * clen + iy; - if(ix < rlen && iy < clen) { - if(abs(A[index]-compareVal) < tol) - ret[index] = ifEqualsVal; - else if(A[index] < compareVal) - ret[index] = ifLessThanVal; - else - ret[index] = ifGreaterThanVal; - } +template <typename T> +__device__ void daxpy_matrix_vector(T *A, T *B, double alpha, T *ret, int rlenA, + int clenA, int rlenB, int clenB) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clenA; + int iy = tid % clenA; + if (ix < rlenA && iy < clenA) { + int index = ix * clenA + iy; + if (rlenB == 1) { + ret[index] = A[index] + alpha * B[iy]; + } else { + ret[index] = A[index] + alpha * B[ix]; + } + } } +extern "C" __global__ void daxpy_matrix_vector_d(double *A, double *B, + double alpha, double *ret, + int rlenA, int clenA, int rlenB, + int clenB) { + daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB); +} + +extern "C" __global__ void daxpy_matrix_vector_f(float *A, float *B, + double alpha, float *ret, + int rlenA, int clenA, int rlenB, + int clenB) { + daxpy_matrix_vector(A, B, alpha, ret, rlenA, clenA, rlenB, clenB); +} + +// Performs similar operation as bias_add except elementwise multiplication +// instead of add +template <typename T> +__device__ void bias_multiply(T *input, T *bias, T *ret, int rlen, int clen, + int PQ) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / clen; + int iy = tid % clen; + if (ix < rlen && iy < clen) { + int biasIndex = iy / PQ; + ret[tid] = input[tid] * bias[biasIndex]; + } +} + +extern "C" __global__ void bias_multiply_d(double *input, double *bias, + double *ret, int rlen, int clen, + int PQ) { + bias_multiply(input, bias, ret, rlen, clen, PQ); +} + +extern "C" __global__ void bias_multiply_f(float *input, float *bias, float *ret, + int rlen, int clen, int PQ) { + bias_multiply(input, bias, ret, rlen, clen, PQ); +} /** * Performs a binary cellwise arithmetic operation on 2 matrices. - * Either both matrices are of equal size or one of them is a vector or both are. + * Either both matrices are of equal size or one of them is a vector or both + * are. 
* @param A first input matrix allocated on GPU * @param B second input matrix allocated on GPU * @param C output allocated on GPU @@ -323,37 +492,55 @@ __global__ void compare_and_set(double* A, double* ret, int rlen, int clen, dou * @param maxClen maximum of the column lengths of A and B * @param vectorAStatus if A is a row vector, column vector or neither * @param vectorBStatus if B is a row vector, column vector or neither - * @param op the numeric code of the arithmetic operation to perform + * @param op the numeric code of the arithmetic operation to + * perform * */ -extern "C" -__global__ void matrix_matrix_cellwise_op(double* A, double* B, double* C, - int maxRlen, int maxClen, int vectorAStatus, int vectorBStatus, int op) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / maxClen; - int iy = tid % maxClen; - - if(ix < maxRlen && iy < maxClen) { - int outIndex = ix * maxClen + iy; - int aIndex = outIndex; - int bIndex = outIndex; - if(vectorAStatus == 1) - aIndex = ix; // clen == 1 - else if(vectorAStatus == 2) - aIndex = iy; // rlen == 1 - if(vectorBStatus == 1) - bIndex = ix; // clen == 1 - else if(vectorBStatus == 2) - bIndex = iy; // rlen == 1 - C[outIndex] = binaryOp(A[aIndex], B[bIndex], op); - //printf("C[%d] = A[%d](%f) B[%d](%f) (%d %d)\n", outIndex, aIndex, A[aIndex], bIndex, B[bIndex], (ix+1), (iy+1)); - __syncthreads(); - } +template <typename T> +__device__ void matrix_matrix_cellwise_op(T *A, T *B, T *C, int maxRlen, + int maxClen, int vectorAStatus, + int vectorBStatus, int op) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / maxClen; + int iy = tid % maxClen; + + if (ix < maxRlen && iy < maxClen) { + int outIndex = ix * maxClen + iy; + int aIndex = outIndex; + int bIndex = outIndex; + if (vectorAStatus == 1) + aIndex = ix; // clen == 1 + else if (vectorAStatus == 2) + aIndex = iy; // rlen == 1 + if (vectorBStatus == 1) + bIndex = ix; // clen == 1 + else if (vectorBStatus == 2) + bIndex = iy; // rlen == 1 + C[outIndex] = binaryOp(A[aIndex], B[bIndex], op); + // printf("C[%d] = A[%d](%f) B[%d](%f) (%d %d)\n", outIndex, aIndex, + // A[aIndex], bIndex, B[bIndex], (ix+1), (iy+1)); + __syncthreads(); + } +} + +extern "C" __global__ void matrix_matrix_cellwise_op_d( + double *A, double *B, double *C, int maxRlen, int maxClen, + int vectorAStatus, int vectorBStatus, int op) { + matrix_matrix_cellwise_op(A, B, C, maxRlen, maxClen, vectorAStatus, + vectorBStatus, op); +} + +extern "C" __global__ void matrix_matrix_cellwise_op_f( + float *A, float *B, float *C, int maxRlen, int maxClen, int vectorAStatus, + int vectorBStatus, int op) { + matrix_matrix_cellwise_op(A, B, C, maxRlen, maxClen, vectorAStatus, + vectorBStatus, op); } /** * Performs an arithmetic operation between a matrix and a scalar. 
- * C = s op A or C = A op s (where A is the matrix, s is the scalar and op is the operation) + * C = s op A or C = A op s (where A is the matrix, s is the scalar and op is + * the operation) * @param A input matrix allocated on GPU * @param scalar scalar input * @param C output matrix allocated on GPU @@ -361,32 +548,53 @@ __global__ void matrix_matrix_cellwise_op(double* A, double* B, double* C, * @param op number code of the arithmetic operation to perform * @param isLeftScalar whether the scalar is on the left side */ -extern "C" -__global__ void matrix_scalar_op(double* A, double scalar, double* C, int size, int op, int isLeftScalar) { - int index = blockIdx.x *blockDim.x + threadIdx.x; - if(index < size) { - if(isLeftScalar) { - C[index] = binaryOp(scalar, A[index], op); - } else { - C[index] = binaryOp(A[index], scalar, op); - } - } - __syncthreads(); +template <typename T> +__device__ void matrix_scalar_op(T *A, T scalar, T *C, int size, int op, + int isLeftScalar) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < size) { + if (isLeftScalar) { + C[index] = binaryOp(scalar, A[index], op); + } else { + C[index] = binaryOp(A[index], scalar, op); + } + } + __syncthreads(); } +extern "C" __global__ void matrix_scalar_op_d(double *A, double scalar, + double *C, int size, int op, + int isLeftScalar) { + matrix_scalar_op(A, scalar, C, size, op, isLeftScalar); +} + +extern "C" __global__ void matrix_scalar_op_f(float *A, double scalar, float *C, + int size, int op, + int isLeftScalar) { + matrix_scalar_op(A, (float)scalar, C, size, op, isLeftScalar); +} /** - * Sets all elements (fills) of a double array of given length with a given scalar value + * Sets all elements (fills) of a double array of given length with a given + * scalar value * @param A array to be filled * @param scalar value to fill array with * @param lenA length of array A */ -extern "C" -__global__ void fill(double* A, double scalar, int lenA) { +template <typename T> +__device__ void fill(T *A, T scalar, int lenA) { int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < lenA){ - A[index] = scalar; - } + if (index < lenA) { + A[index] = scalar; + } +} + +extern "C" __global__ void fill_d(double *A, double scalar, int lenA) { + fill(A, scalar, lenA); +} + +extern "C" __global__ void fill_f(float *A, double scalar, int lenA) { + fill(A, (float)scalar, lenA); } /** @@ -402,29 +610,39 @@ __global__ void fill(double* A, double scalar, int lenA) { * @param rowsB rows in B * @param colsB columns in B */ -extern "C" -__global__ void cbind(double *A, double *B, double *C, int rowsA, int colsA, int rowsB, int colsB) { - int maxClen = max(colsA, colsB); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / maxClen; - int iy = tid % maxClen; - - int colsC = colsA + colsB; - int rowsC = rowsA; - - // Copy an element of A into C into the appropriate location - if (ix < rowsA && iy < colsA) { - double elemA = A[ix * colsA + iy]; - C[ix * colsC + iy] = elemA; - } +template <typename T> +__device__ void cbind(T *A, T *B, T *C, int rowsA, int colsA, int rowsB, + int colsB) { + int maxClen = max(colsA, colsB); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / maxClen; + int iy = tid % maxClen; + + int colsC = colsA + colsB; + int rowsC = rowsA; + + // Copy an element of A into C into the appropriate location + if (ix < rowsA && iy < colsA) { + T elemA = A[ix * colsA + iy]; + C[ix * colsC + iy] = elemA; + } + + // Copy an element of B into C into the appropriate location + if (ix < 
rowsB && iy < colsB) { + T elemB = B[ix * colsB + iy]; + C[ix * colsC + (iy + colsA)] = elemB; + } +} - // Copy an element of B into C into the appropriate location - if (ix < rowsB && iy < colsB) { - double elemB = B[ix * colsB + iy]; - C[ix * colsC + (iy + colsA)] = elemB; - } +extern "C" __global__ void cbind_d(double *A, double *B, double *C, int rowsA, + int colsA, int rowsB, int colsB) { + cbind(A, B, C, rowsA, colsA, rowsB, colsB); } +extern "C" __global__ void cbind_f(float *A, float *B, float *C, int rowsA, + int colsA, int rowsB, int colsB) { + cbind(A, B, C, rowsA, colsA, rowsB, colsB); +} /** * Appends Matrix B to the bottom of Matrix A into a new matrix C @@ -441,176 +659,263 @@ __global__ void cbind(double *A, double *B, double *C, int rowsA, int colsA, int * @param rowsB rows in B * @param colsB columns in B */ -extern "C" -__global__ void rbind(double *A, double *B, double *C, int rowsA, int colsA, int rowsB, int colsB) { - int maxClen = max(colsA, colsB); - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int ix = tid / maxClen; - int iy = tid % maxClen; - - int rowsC = rowsA + rowsB; - int colsC = colsA; - - // Copy an element of A into C into the appropriate location - if (ix < rowsA && iy < colsA) { - double elemA = A[ix * colsA + iy]; - C[ix * colsC + iy] = elemA; - } +template <typename T> +__device__ void rbind(T *A, T *B, T *C, int rowsA, int colsA, int rowsB, + int colsB) { + int maxClen = max(colsA, colsB); + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int ix = tid / maxClen; + int iy = tid % maxClen; + + int rowsC = rowsA + rowsB; + int colsC = colsA; + + // Copy an element of A into C into the appropriate location + if (ix < rowsA && iy < colsA) { + T elemA = A[ix * colsA + iy]; + C[ix * colsC + iy] = elemA; + } + + // Copy an element of B into C into the appropriate location + if (ix < rowsB && iy < colsB) { + T elemB = B[ix * colsB + iy]; + C[(ix + rowsA) * colsC + iy] = elemB; + } +} - // Copy an element of B into C into the appropriate location - if (ix < rowsB && iy < colsB) { - double elemB = B[ix * colsB + iy]; - C[(ix + rowsA) * colsC + iy] = elemB; - } +extern "C" __global__ void rbind_d(double *A, double *B, double *C, int rowsA, + int colsA, int rowsB, int colsB) { + rbind(A, B, C, rowsA, colsA, rowsB, colsB); } +extern "C" __global__ void rbind_f(float *A, float *B, float *C, int rowsA, + int colsA, int rowsB, int colsB) { + rbind(A, B, C, rowsA, colsA, rowsB, colsB); +} /** * Does a reduce operation over all elements of the array. - * This method has been adapted from the Reduction sample in the NVIDIA CUDA Samples (v8.0) + * This method has been adapted from the Reduction sample in the NVIDIA CUDA + * Samples (v8.0) * and the Reduction example available through jcuda.org - * When invoked initially, all blocks partly compute the reduction operation over the entire array - * and writes it to the output/temporary array. A second invokation needs to happen to get the + * When invoked initially, all blocks partly compute the reduction operation + * over the entire array + * and writes it to the output/temporary array. A second invokation needs to + * happen to get the * reduced value. - * The number of threads, blocks and amount of shared memory is calculated in a specific way. - * Please refer to the NVIDIA CUDA Sample or the SystemML code that invokes this method to see + * The number of threads, blocks and amount of shared memory is calculated in a + * specific way. 
+ * Please refer to the NVIDIA CUDA Sample or the SystemML code that invokes this + * method to see * how its done. - * The template-ized version of this function is similar to what is found in NVIDIA CUB + * The template-ized version of this function is similar to what is found in + * NVIDIA CUB * - * @param ReductionOp Type of the functor object that implements the reduction operation + * @param ReductionOp Type of the functor object that implements the + * reduction operation */ -template <typename ReductionOp> +template <typename ReductionOp, typename T> __device__ void reduce( - double *g_idata, ///< input data stored in device memory (of size n) - double *g_odata, ///< output/temporary array stored in device memory (of size n) - unsigned int n, ///< size of the input and temporary/output arrays - ReductionOp reduction_op, ///< Reduction operation to perform (functor object) - double initialValue) ///< initial value for the reduction variable + T *g_idata, ///< input data stored in device memory (of size n) + T *g_odata, ///< output/temporary array stored in device memory (of size n) + unsigned int n, ///< size of the input and temporary/output arrays + ReductionOp + reduction_op, ///< Reduction operation to perform (functor object) + T initialValue) ///< initial value for the reduction variable { - extern __shared__ double sdata[]; - - // perform first level of reduction, - // reading from global memory, writing to shared memory - unsigned int tid = threadIdx.x; - unsigned int i = blockIdx.x*blockDim.x*2 + threadIdx.x; - unsigned int gridSize = blockDim.x*2*gridDim.x; - - double v = initialValue; - - // we reduce multiple elements per thread. The number is determined by the - // number of active thread blocks (via gridDim). More blocks will result - // in a larger gridSize and therefore fewer elements per thread - while (i < n) - { - v = reduction_op(v, g_idata[i]); - // ensure we don't read out of bounds - if (i + blockDim.x < n) - v = reduction_op(v, g_idata[i+blockDim.x]); - i += gridSize; + // extern __shared__ T sdata[]; + extern __shared__ __align__(sizeof(T)) unsigned char my_sdata[]; + T *sdata = reinterpret_cast<T *>(my_sdata); + + // perform first level of reduction, + // reading from global memory, writing to shared memory + unsigned int tid = threadIdx.x; + unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x; + unsigned int gridSize = blockDim.x * 2 * gridDim.x; + + T v = initialValue; + + // we reduce multiple elements per thread. The number is determined by the + // number of active thread blocks (via gridDim). 
More blocks will result + // in a larger gridSize and therefore fewer elements per thread + while (i < n) { + v = reduction_op(v, g_idata[i]); + // ensure we don't read out of bounds + if (i + blockDim.x < n) v = reduction_op(v, g_idata[i + blockDim.x]); + i += gridSize; + } + + // each thread puts its local sum into shared memory + sdata[tid] = v; + __syncthreads(); + + // do reduction in shared mem + if (blockDim.x >= 1024) { + if (tid < 512) { + sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } - - // each thread puts its local sum into shared memory - sdata[tid] = v; __syncthreads(); - - - // do reduction in shared mem - if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } __syncthreads(); } - if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); } - if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); } - if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = reduction_op(v, sdata[tid + 64]); } __syncthreads(); } - - if (tid < 32) - { - // now that we are using warp-synchronous programming (below) - // we need to declare our shared memory volatile so that the compiler - // doesn't reorder stores to it and induce incorrect behavior. - volatile double* smem = sdata; - if (blockDim.x >= 64) { smem[tid] = v = reduction_op(v, smem[tid + 32]); } - if (blockDim.x >= 32) { smem[tid] = v = reduction_op(v, smem[tid + 16]); } - if (blockDim.x >= 16) { smem[tid] = v = reduction_op(v, smem[tid + 8]); } - if (blockDim.x >= 8) { smem[tid] = v = reduction_op(v, smem[tid + 4]); } - if (blockDim.x >= 4) { smem[tid] = v = reduction_op(v, smem[tid + 2]); } - if (blockDim.x >= 2) { smem[tid] = v = reduction_op(v, smem[tid + 1]); } + } + if (blockDim.x >= 512) { + if (tid < 256) { + sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } + __syncthreads(); + } + if (blockDim.x >= 256) { + if (tid < 128) { + sdata[tid] = v = reduction_op(v, sdata[tid + 128]); + } + __syncthreads(); + } + if (blockDim.x >= 128) { + if (tid < 64) { + sdata[tid] = v = reduction_op(v, sdata[tid + 64]); + } + __syncthreads(); + } + + if (tid < 32) { + // now that we are using warp-synchronous programming (below) + // we need to declare our shared memory volatile so that the compiler + // doesn't reorder stores to it and induce incorrect behavior. + volatile T *smem = sdata; + if (blockDim.x >= 64) { + smem[tid] = v = reduction_op(v, smem[tid + 32]); + } + if (blockDim.x >= 32) { + smem[tid] = v = reduction_op(v, smem[tid + 16]); + } + if (blockDim.x >= 16) { + smem[tid] = v = reduction_op(v, smem[tid + 8]); + } + if (blockDim.x >= 8) { + smem[tid] = v = reduction_op(v, smem[tid + 4]); + } + if (blockDim.x >= 4) { + smem[tid] = v = reduction_op(v, smem[tid + 2]); + } + if (blockDim.x >= 2) { + smem[tid] = v = reduction_op(v, smem[tid + 1]); + } + } - // write result for this block to global mem - if (tid == 0) - g_odata[blockIdx.x] = sdata[0]; + // write result for this block to global mem + if (tid == 0) g_odata[blockIdx.x] = sdata[0]; } - - /** * Does a reduce (sum) over each row of the array. * This kernel must be launched with as many blocks as there are rows. - * The intuition for this kernel is that each block does a reduction over a single row. 
- * The maximum number of blocks that can launched (as of compute capability 3.0) is 2^31 - 1 - * This works out fine for SystemML, since the maximum elements in a Java array can be 2^31 - c (some small constant) - * If the matrix is "fat" and "short", i.e. there are small number of rows and a large number of columns, + * The intuition for this kernel is that each block does a reduction over a + * single row. + * The maximum number of blocks that can launched (as of compute capability 3.0) + * is 2^31 - 1 + * This works out fine for SystemML, since the maximum elements in a Java array + * can be 2^31 - c (some small constant) + * If the matrix is "fat" and "short", i.e. there are small number of rows and a + * large number of columns, * there could be under-utilization of the hardware. - * The template-ized version of this function is similar to what is found in NVIDIA CUB - * @param ReductionOp Type of the functor object that implements the reduction operation - * @param AssignmentOp Type of the functor object that is used to modify the value before writing it to its final location in global memory for each row + * The template-ized version of this function is similar to what is found in + * NVIDIA CUB + * @param ReductionOp Type of the functor object that implements the + * reduction operation + * @param AssignmentOp Type of the functor object that is used to modify + * the value before writing it to its final location in global memory for each + * row */ -template <typename ReductionOp, - typename AssignmentOp> +template <typename ReductionOp, typename AssignmentOp, typename T> __device__ void reduce_row( - double *g_idata, ///< input data stored in device memory (of size rows*cols) - double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) - unsigned int rows, ///< rows in input and temporary/output arrays - unsigned int cols, ///< columns in input and temporary/output arrays - ReductionOp reduction_op, ///< Reduction operation to perform (functor object) - AssignmentOp assignment_op, ///< Operation to perform before assigning this to its final location in global memory for each row - double initialValue){ ///< initial value for the reduction variable - extern __shared__ double sdata[]; - - // one block per row - if (blockIdx.x >= rows) { - return; + T *g_idata, ///< input data stored in device memory (of size rows*cols) + T *g_odata, ///< output/temporary array store in device memory (of size + ///rows*cols) + unsigned int rows, ///< rows in input and temporary/output arrays + unsigned int cols, ///< columns in input and temporary/output arrays + ReductionOp + reduction_op, ///< Reduction operation to perform (functor object) + AssignmentOp assignment_op, ///< Operation to perform before assigning this + ///to its final location in global memory for + ///each row + T initialValue) { ///< initial value for the reduction variable + // extern __shared__ T sdata[]; + extern __shared__ __align__(sizeof(T)) unsigned char my_sdata[]; + T *sdata = reinterpret_cast<T *>(my_sdata); + + // one block per row + if (blockIdx.x >= rows) { + return; + } + + unsigned int block = blockIdx.x; + unsigned int tid = threadIdx.x; + unsigned int i = tid; + unsigned int block_offset = block * cols; + + T v = initialValue; + while (i < cols) { + v = reduction_op(v, g_idata[block_offset + i]); + i += blockDim.x; + } + + // each thread puts its local sum into shared memory + sdata[tid] = v; + __syncthreads(); + + // do reduction in shared mem + if (blockDim.x >= 1024) { + if (tid < 
512) { + sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } - - unsigned int block = blockIdx.x; - unsigned int tid = threadIdx.x; - unsigned int i = tid; - unsigned int block_offset = block * cols; - - double v = initialValue; - while (i < cols){ - v = reduction_op(v, g_idata[block_offset + i]); - i += blockDim.x; + __syncthreads(); + } + if (blockDim.x >= 512) { + if (tid < 256) { + sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } - - // each thread puts its local sum into shared memory - sdata[tid] = v; __syncthreads(); - - // do reduction in shared mem - if (blockDim.x >= 1024){ if (tid < 512) { sdata[tid] = v = reduction_op(v, sdata[tid + 512]); } __syncthreads(); } - if (blockDim.x >= 512) { if (tid < 256) { sdata[tid] = v = reduction_op(v, sdata[tid + 256]); } __syncthreads(); } - if (blockDim.x >= 256) { if (tid < 128) { sdata[tid] = v = reduction_op(v, sdata[tid + 128]); } __syncthreads(); } - if (blockDim.x >= 128) { if (tid < 64) { sdata[tid] = v = reduction_op(v, sdata[tid + 64]); } __syncthreads(); } - - if (tid < 32) - { - // now that we are using warp-synchronous programming (below) - // we need to declare our shared memory volatile so that the compiler - // doesn't reorder stores to it and induce incorrect behavior. - volatile double* smem = sdata; - if (blockDim.x >= 64) { smem[tid] = v = reduction_op(v, smem[tid + 32]); } - if (blockDim.x >= 32) { smem[tid] = v = reduction_op(v, smem[tid + 16]); } - if (blockDim.x >= 16) { smem[tid] = v = reduction_op(v, smem[tid + 8]); } - if (blockDim.x >= 8) { smem[tid] = v = reduction_op(v, smem[tid + 4]); } - if (blockDim.x >= 4) { smem[tid] = v = reduction_op(v, smem[tid + 2]); } - if (blockDim.x >= 2) { smem[tid] = v = reduction_op(v, smem[tid + 1]); } + } + if (blockDim.x >= 256) { + if (tid < 128) { + sdata[tid] = v = reduction_op(v, sdata[tid + 128]); + } + __syncthreads(); + } + if (blockDim.x >= 128) { + if (tid < 64) { + sdata[tid] = v = reduction_op(v, sdata[tid + 64]); + } + __syncthreads(); + } + + if (tid < 32) { + // now that we are using warp-synchronous programming (below) + // we need to declare our shared memory volatile so that the compiler + // doesn't reorder stores to it and induce incorrect behavior. + volatile T *smem = sdata; + if (blockDim.x >= 64) { + smem[tid] = v = reduction_op(v, smem[tid + 32]); + } + if (blockDim.x >= 32) { + smem[tid] = v = reduction_op(v, smem[tid + 16]); + } + if (blockDim.x >= 16) { + smem[tid] = v = reduction_op(v, smem[tid + 8]); + } + if (blockDim.x >= 8) { + smem[tid] = v = reduction_op(v, smem[tid + 4]); + } + if (blockDim.x >= 4) { + smem[tid] = v = reduction_op(v, smem[tid + 2]); } + if (blockDim.x >= 2) { + smem[tid] = v = reduction_op(v, smem[tid + 1]); + } + } - // write result for this block to global mem, modify it with assignment op - if (tid == 0) - g_odata[block] = assignment_op(sdata[0]); + // write result for this block to global mem, modify it with assignment op + if (tid == 0) g_odata[block] = assignment_op(sdata[0]); } - /** * Does a column wise reduction. * The intuition is that there are as many global threads as there are columns @@ -618,57 +923,59 @@ __device__ void reduce_row( * This of course leads to a under-utilization of the GPU resources. 
* For cases, where the number of columns is small, there can be unused SMs * - * The template-ized version of this function is similar to what is found in NVIDIA CUB - * @param ReductionOp Type of the functor object that implements the reduction operation - * @param AssignmentOp Type of the functor object that is used to modify the value before writing it to its final location in global memory for each column + * The template-ized version of this function is similar to what is found in + * NVIDIA CUB + * @param ReductionOp Type of the functor object that implements the + * reduction operation + * @param AssignmentOp Type of the functor object that is used to modify + * the value before writing it to its final location in global memory for each + * column */ -template <typename ReductionOp, - typename AssignmentOp> +template <typename ReductionOp, typename AssignmentOp, typename T> __device__ void reduce_col( - double *g_idata, ///< input data stored in device memory (of size rows*cols) - double *g_odata, ///< output/temporary array store in device memory (of size rows*cols) - unsigned int rows, ///< rows in input and temporary/output arrays - unsigned int cols, ///< columns in input and temporary/output arrays - ReductionOp reduction_op, ///< Reduction operation to perform (functor object) - AssignmentOp assignment_op, ///< Operation to perform before assigning this to its final location in global memory for each column - double initialValue) ///< initial value for the reduction variable + T *g_idata, ///< input data stored in device memory (of size rows*cols) + T *g_odata, ///< output/temporary array store in device memory (of size + ///rows*cols) + unsigned int rows, ///< rows in input and temporary/output arrays + unsigned int cols, ///< columns in input and temporary/output arrays + ReductionOp + reduction_op, ///< Reduction operation to perform (functor object) + AssignmentOp assignment_op, ///< Operation to perform before assigning this + ///to its final location in global memory for + ///each column + T initialValue) ///< initial value for the reduction variable { - unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x; - if (global_tid >= cols) { - return; - } - - unsigned int i = global_tid; - unsigned int grid_size = cols; - double val = initialValue; - - while (i < rows * cols) { - val = reduction_op(val, g_idata[i]); - i += grid_size; - } - g_odata[global_tid] = assignment_op(val); + unsigned int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + if (global_tid >= cols) { + return; + } + + unsigned int i = global_tid; + unsigned int grid_size = cols; + T val = initialValue; + + while (i < rows * cols) { + val = reduction_op(val, g_idata[i]); + i += grid_size; + } + g_odata[global_tid] = assignment_op(val); } /** * Functor op for assignment op. This is a dummy/identity op. 
*/ -typedef struct { - __device__ __forceinline__ - double operator()(double a) const { - return a; - } -} IdentityOp; +template <typename T> +struct IdentityOp { + __device__ __forceinline__ T operator()(T a) const { return a; } +}; /** * Functor op for summation operation */ -typedef struct { - __device__ __forceinline__ - double operator()(double a, double b) const { - return a + b; - } -} SumOp; - +template <typename T> +struct SumOp { + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } +}; /** * Do a summation over all elements of an array/matrix @@ -676,10 +983,20 @@ typedef struct { * @param g_odata output/temporary array stored in device memory (of size n) * @param n size of the input and temporary/output arrays */ -extern "C" -__global__ void reduce_sum(double *g_idata, double *g_odata, unsigned int n){ - SumOp op; - reduce<SumOp>(g_idata, g_odata, n, op, 0.0); +template <typename T> +__device__ void reduce_sum(T *g_idata, T *g_odata, unsigned int n) { + SumOp<T> op; + reduce<SumOp<T>, T>(g_idata, g_odata, n, op, (T)0.0); +} + +extern "C" __global__ void reduce_sum_d(double *g_idata, double *g_odata, + unsigned int n) { + reduce_sum(g_idata, g_odata, n); +} + +extern "C" __global__ void reduce_sum_f(float *g_idata, float *g_odata, + unsigned int n) { + reduce_sum(g_idata, g_odata, n); } /** @@ -689,11 +1006,25 @@ __global__ void reduce_sum(double *g_idata, double *g_odata, unsigned int n){ * @param rows number of rows in input matrix * @param cols number of columns in input matrix */ -extern "C" -__global__ void reduce_row_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ - SumOp op; - IdentityOp aop; - reduce_row<SumOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, 0.0); +template <typename T> +__device__ void reduce_row_sum(T *g_idata, T *g_odata, unsigned int rows, + unsigned int cols) { + SumOp<T> op; + IdentityOp<T> aop; + reduce_row<SumOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop, + 0.0); +} + +extern "C" __global__ void reduce_row_sum_d(double *g_idata, double *g_odata, + unsigned int rows, + unsigned int cols) { + reduce_row_sum(g_idata, g_odata, rows, cols); +} + +extern "C" __global__ void reduce_row_sum_f(float *g_idata, float *g_odata, + unsigned int rows, + unsigned int cols) { + reduce_row_sum(g_idata, g_odata, rows, cols); } /** @@ -703,23 +1034,39 @@ __global__ void reduce_row_sum(double *g_idata, double *g_odata, unsigned int ro * @param rows number of rows in input matrix * @param cols number of columns in input matrix */ -extern "C" -__global__ void reduce_col_sum(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){ - SumOp op; - IdentityOp aop; - reduce_col<SumOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, 0.0); +template <typename T> +__device__ void reduce_col_sum(T *g_idata, T *g_odata, unsigned int rows, + unsigned int cols) { + SumOp<T> op; + IdentityOp<T> aop; + reduce_col<SumOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop, + (T)0.0); +} + +extern "C" __global__ void reduce_col_sum_d(double *g_idata, double *g_odata, + unsigned int rows, + unsigned int cols) { + reduce_col_sum(g_idata, g_odata, rows, cols); } +extern "C" __global__ void reduce_col_sum_f(float *g_idata, float *g_odata, + unsigned int rows, + unsigned int cols) { + reduce_col_sum(g_idata, g_odata, rows, cols); +} /** * Functor op for max operation */ -typedef struct { - __device__ __forceinline__ - double operator()(double a, double b) const { - return fmax(a, b); - } -} 
MaxOp;
+template <typename T>
+struct MaxOp {
+  __device__ __forceinline__ T operator()(T a, T b) const { return fmax(a, b); }
+};
+
+template<>
+struct MaxOp<float> {
+  __device__ __forceinline__ float operator()(float a, float b) const { return fmaxf(a, b); }
+};
 /**
@@ -728,10 +1075,20 @@ typedef struct {
 * @param g_odata output/temporary array stored in device memory (of size n)
 * @param n size of the input and temporary/output arrays
 */
-extern "C"
-__global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){
-  MaxOp op;
-  reduce<MaxOp>(g_idata, g_odata, n, op, -DBL_MAX);
+template <typename T>
+__device__ void reduce_max(T *g_idata, T *g_odata, unsigned int n) {
+  MaxOp<T> op;
+  reduce<MaxOp<T>, T>(g_idata, g_odata, n, op, -T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_max_d(double *g_idata, double *g_odata,
+    unsigned int n) {
+  reduce_max(g_idata, g_odata, n);
+}
+
+extern "C" __global__ void reduce_max_f(float *g_idata, float *g_odata,
+    unsigned int n) {
+  reduce_max(g_idata, g_odata, n);
 }
 /**
@@ -741,11 +1098,25 @@ __global__ void reduce_max(double *g_idata, double *g_odata, unsigned int n){
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  MaxOp op;
-  IdentityOp aop;
-  reduce_row<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, -DBL_MAX);
+template <typename T>
+__device__ void reduce_row_max(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  MaxOp<T> op;
+  IdentityOp<T> aop;
+  reduce_row<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      -T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_row_max_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_max(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_row_max_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_max(g_idata, g_odata, rows, cols);
 }
 /**
@@ -755,22 +1126,34 @@ __global__ void reduce_row_max(double *g_idata, double *g_odata, unsigned int ro
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_col_max(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  MaxOp op;
-  IdentityOp aop;
-  reduce_col<MaxOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, -DBL_MAX);
+template <typename T>
+__device__ void reduce_col_max(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  MaxOp<T> op;
+  IdentityOp<T> aop;
+  reduce_col<MaxOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      (T)-T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_col_max_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_max(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_col_max_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_max(g_idata, g_odata, rows, cols);
 }
 /**
 * Functor op for min operation
 */
-typedef struct {
-  __device__ __forceinline__
-  double operator()(double a, double b) const {
-    return fmin(a, b);
-  }
-} MinOp;
+template <typename T>
+struct MinOp {
+  __device__ __forceinline__ T operator()(T a, T b) const { return fmin(a, b); }
+};
 /**
 * Do a min over all elements of an array/matrix
@@ -778,10 +1161,20 @@ typedef struct {
 * @param g_odata output/temporary array stored in device memory (of size n)
 * @param n size of the input and temporary/output arrays
 */
-extern "C"
-__global__ void reduce_min(double *g_idata, double *g_odata, unsigned int n){
-  MinOp op;
-  reduce<MinOp>(g_idata, g_odata, n, op, DBL_MAX);
+template <typename T>
+__device__ void reduce_min(T *g_idata, T *g_odata, unsigned int n) {
+  MinOp<T> op;
+  reduce<MinOp<T>, T>(g_idata, g_odata, n, op, T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_min_d(double *g_idata, double *g_odata,
+    unsigned int n) {
+  reduce_min(g_idata, g_odata, n);
+}
+
+extern "C" __global__ void reduce_min_f(float *g_idata, float *g_odata,
+    unsigned int n) {
+  reduce_min(g_idata, g_odata, n);
 }
 /**
@@ -791,11 +1184,25 @@ __global__ void reduce_min(double *g_idata, double *g_odata, unsigned int n){
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_row_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  MinOp op;
-  IdentityOp aop;
-  reduce_row<MinOp, IdentityOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MAX);
+template <typename T>
+__device__ void reduce_row_min(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  MinOp<T> op;
+  IdentityOp<T> aop;
+  reduce_row<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_row_min_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_min(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_row_min_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_min(g_idata, g_odata, rows, cols);
 }
 /**
@@ -805,22 +1212,34 @@ __global__ void reduce_row_min(double *g_idata, double *g_odata, unsigned int ro
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_col_min(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  MinOp op;
-  IdentityOp aop;
-  reduce_col<MinOp>(g_idata, g_odata, rows, cols, op, aop, DBL_MAX);
+template <typename T>
+__device__ void reduce_col_min(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  MinOp<T> op;
+  IdentityOp<T> aop;
+  reduce_col<MinOp<T>, IdentityOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      T_MAX(g_idata[0]));
+}
+
+extern "C" __global__ void reduce_col_min_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_min(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_col_min_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_min(g_idata, g_odata, rows, cols);
 }
 /**
 * Functor op for product operation
 */
-typedef struct {
-  __device__ __forceinline__
-  double operator()(double a, double b) const {
-    return a * b;
-  }
-} ProductOp;
+template <typename T>
+struct ProductOp {
+  __device__ __forceinline__ T operator()(T a, T b) const { return a * b; }
+};
 /**
 * Do a product over all elements of an array/matrix
@@ -828,26 +1247,35 @@ typedef struct {
 * @param g_odata output/temporary array stored in device memory (of size n)
 * @param n size of the input and temporary/output arrays
 */
-extern "C"
-__global__ void reduce_prod(double *g_idata, double *g_odata, unsigned int n){
-  ProductOp op;
-  reduce<ProductOp>(g_idata, g_odata, n, op, 1.0);
+template <typename T>
+__device__ void reduce_prod(T *g_idata, T *g_odata, unsigned int n) {
+  ProductOp<T> op;
+  reduce<ProductOp<T>, T>(g_idata, g_odata, n, op, (T)1.0);
+}
+
+extern "C" __global__ void reduce_prod_d(double *g_idata, double *g_odata,
+    unsigned int n) {
+  reduce_prod(g_idata, g_odata, n);
+}
+
+extern "C" __global__ void reduce_prod_f(float *g_idata, float *g_odata,
+    unsigned int n) {
  reduce_prod(g_idata, g_odata, n);
 }
 /**
 * Functor op for mean operation
 */
+template <typename T>
 struct MeanOp {
-  const long _size; ///< Number of elements by which to divide to calculate mean
-  __device__ __forceinline__
-  MeanOp(long size): _size(size) {}
-  __device__ __forceinline__
-  double operator()(double total) const {
-    return total / _size;
-  }
+  const long
+      _size; ///< Number of elements by which to divide to calculate mean
+  __device__ __forceinline__ MeanOp(long size) : _size(size) {}
+  __device__ __forceinline__ T operator()(T total) const {
+    return total / _size;
+  }
 };
-
 /**
 * Do a mean over all rows of a matrix
 * @param g_idata input matrix stored in device memory (of size rows * cols)
 * @param g_odata output vector stored in device memory (of size rows)
@@ -855,11 +1283,25 @@ struct MeanOp {
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_row_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  SumOp op;
-  MeanOp aop(cols);
-  reduce_row<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
+template <typename T>
+__device__ void reduce_row_mean(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  SumOp<T> op;
+  MeanOp<T> aop(cols);
+  reduce_row<SumOp<T>, MeanOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      (T)0.0);
+}
+
+extern "C" __global__ void reduce_row_mean_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_mean(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_row_mean_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_row_mean(g_idata, g_odata, rows, cols);
 }
 /**
@@ -869,13 +1311,26 @@ __global__ void reduce_row_mean(double *g_idata, double *g_odata, unsigned int r
 * @param rows number of rows in input matrix
 * @param cols number of columns in input matrix
 */
-extern "C"
-__global__ void reduce_col_mean(double *g_idata, double *g_odata, unsigned int rows, unsigned int cols){
-  SumOp op;
-  MeanOp aop(rows);
-  reduce_col<SumOp, MeanOp>(g_idata, g_odata, rows, cols, op, aop, 0.0);
+template <typename T>
+__device__ void reduce_col_mean(T *g_idata, T *g_odata, unsigned int rows,
+    unsigned int cols) {
+  SumOp<T> op;
+  MeanOp<T> aop(rows);
+  reduce_col<SumOp<T>, MeanOp<T>, T>(g_idata, g_odata, rows, cols, op, aop,
+      0.0);
 }
+extern "C" __global__ void reduce_col_mean_d(double *g_idata, double *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_mean(g_idata, g_odata, rows, cols);
+}
+
+extern "C" __global__ void reduce_col_mean_f(float *g_idata, float *g_odata,
+    unsigned int rows,
+    unsigned int cols) {
+  reduce_col_mean(g_idata, g_odata, rows, cols);
+}
 /**
 * Do an exp over all the elements of a matrix
@@ -883,12 +1338,21 @@ __global__ void reduce_col_mean(double *g_idata, double *g_odata, unsigned int r
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_exp(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = exp(A[index]);
-  }
+template <typename T>
+__device__ void matrix_exp(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = exp(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_exp_d(double *A, double *C,
+    unsigned int size) {
+  matrix_exp(A, C, size);
+}
+
+extern "C" __global__ void matrix_exp_f(float *A, float *C, unsigned int size) {
+  matrix_exp(A, C, size);
 }
 /**
@@ -897,12 +1361,21 @@ __global__ void matrix_exp(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_sqrt(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = sqrt(A[index]);
-  }
+template <typename T>
+__device__ void matrix_sqrt(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = sqrt(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_sqrt_d(double *A, double *C,
+    unsigned int size) {
+  matrix_sqrt(A, C, size);
+}
+
+extern "C" __global__ void matrix_sqrt_f(float *A, float *C, unsigned int size) {
+  matrix_sqrt(A, C, size);
 }
 /**
@@ -911,12 +1384,22 @@ __global__ void matrix_sqrt(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_round(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = (double)llround(A[index]);
-  }
+template <typename T>
+__device__ void matrix_round(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = (T)llround(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_round_d(double *A, double *C,
+    unsigned int size) {
+  matrix_round(A, C, size);
+}
+
+extern "C" __global__ void matrix_round_f(float *A, float *C,
+    unsigned int size) {
+  matrix_round(A, C, size);
 }
 /**
@@ -925,12 +1408,21 @@ __global__ void matrix_round(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_abs(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = (double)fabs(A[index]);
-  }
+template <typename T>
+__device__ void matrix_abs(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = (T)fabs(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_abs_d(double *A, double *C,
+    unsigned int size) {
+  matrix_abs(A, C, size);
+}
+
+extern "C" __global__ void matrix_abs_f(float *A, float *C, unsigned int size) {
+  matrix_abs(A, C, size);
 }
 /**
@@ -939,12 +1431,21 @@ __global__ void matrix_abs(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_log(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = log(A[index]);
-  }
+template <typename T>
+__device__ void matrix_log(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = log(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_log_d(double *A, double *C,
+    unsigned int size) {
+  matrix_log(A, C, size);
+}
+
+extern "C" __global__ void matrix_log_f(float *A, float *C, unsigned int size) {
+  matrix_log(A, C, size);
 }
 /**
@@ -953,12 +1454,22 @@ __global__ void matrix_log(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_floor(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = floor(A[index]);
-  }
+template <typename T>
+__device__ void matrix_floor(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = floor(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_floor_d(double *A, double *C,
+    unsigned int size) {
+  matrix_floor(A, C, size);
+}
+
+extern "C" __global__ void matrix_floor_f(float *A, float *C,
+    unsigned int size) {
+  matrix_floor(A, C, size);
 }
 /**
@@ -967,12 +1478,21 @@ __global__ void matrix_floor(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_ceil(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = ceil(A[index]);
-  }
+template <typename T>
+__device__ void matrix_ceil(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = ceil(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_ceil_d(double *A, double *C,
+    unsigned int size) {
+  matrix_ceil(A, C, size);
+}
+
+extern "C" __global__ void matrix_ceil_f(float *A, float *C, unsigned int size) {
+  matrix_ceil(A, C, size);
 }
 /**
@@ -981,12 +1501,21 @@ __global__ void matrix_ceil(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_sin(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = sin(A[index]);
-  }
+template <typename T>
+__device__ void matrix_sin(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = sin(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_sin_d(double *A, double *C,
+    unsigned int size) {
+  matrix_sin(A, C, size);
+}
+
+extern "C" __global__ void matrix_sin_f(float *A, float *C, unsigned int size) {
+  matrix_sin(A, C, size);
 }
 /**
@@ -995,12 +1524,21 @@ __global__ void matrix_sin(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_sinh(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = sinh(A[index]);
-  }
+template <typename T>
+__device__ void matrix_sinh(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = sinh(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_sinh_d(double *A, double *C,
+    unsigned int size) {
+  matrix_sinh(A, C, size);
+}
+
+extern "C" __global__ void matrix_sinh_f(float *A, float *C, unsigned int size) {
+  matrix_sinh(A, C, size);
 }
 /**
@@ -1009,12 +1547,21 @@ __global__ void matrix_sinh(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_cos(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = cos(A[index]);
-  }
+template <typename T>
+__device__ void matrix_cos(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = cos(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_cos_d(double *A, double *C,
+    unsigned int size) {
+  matrix_cos(A, C, size);
+}
+
+extern "C" __global__ void matrix_cos_f(float *A, float *C, unsigned int size) {
+  matrix_cos(A, C, size);
 }
 /**
@@ -1023,12 +1570,21 @@ __global__ void matrix_cos(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_cosh(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = cosh(A[index]);
-  }
+template <typename T>
+__device__ void matrix_cosh(T *A, T *C, unsigned int size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = cosh(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_cosh_d(double *A, double *C,
+    unsigned int size) {
+  matrix_cosh(A, C, size);
+}
+
+extern "C" __global__ void matrix_cosh_f(float *A, float *C, unsigned int size) {
+  matrix_cosh(A, C, size);
 }
 /**
@@ -1037,12 +1593,21 @@ __global__ void matrix_cosh(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_tan(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = tan(A[index]);
-  }
+template <typename T>
+__device__ void matrix_tan(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = tan(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_tan_d(double *A, double *C,
+    unsigned int size) {
+  matrix_tan(A, C, size);
+}
+
+extern "C" __global__ void matrix_tan_f(float *A, float *C, unsigned int size) {
+  matrix_tan(A, C, size);
 }
 /**
@@ -1051,12 +1616,21 @@ __global__ void matrix_tan(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_tanh(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = tanh(A[index]);
-  }
+template <typename T>
+__device__ void matrix_tanh(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = tanh(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_tanh_d(double *A, double *C,
+    unsigned int size) {
+  matrix_tanh(A, C, size);
+}
+
+extern "C" __global__ void matrix_tanh_f(float *A, float *C, unsigned int size) {
+  matrix_tanh(A, C, size);
 }
 /**
@@ -1065,12 +1639,21 @@ __global__ void matrix_tanh(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_asin(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = asin(A[index]);
-  }
+template <typename T>
+__device__ void matrix_asin(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = asin(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_asin_d(double *A, double *C,
+    unsigned int size) {
+  matrix_asin(A, C, size);
+}
+
+extern "C" __global__ void matrix_asin_f(float *A, float *C, unsigned int size) {
+  matrix_asin(A, C, size);
 }
 /**
@@ -1079,12 +1662,21 @@ __global__ void matrix_asin(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_acos(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = acos(A[index]);
-  }
+template <typename T>
+__device__ void matrix_acos(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = acos(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_acos_d(double *A, double *C,
+    unsigned int size) {
+  matrix_acos(A, C, size);
+}
+
+extern "C" __global__ void matrix_acos_f(float *A, float *C, unsigned int size) {
+  matrix_acos(A, C, size);
 }
 /**
@@ -1093,12 +1685,21 @@ __global__ void matrix_acos(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_atan(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    C[index] = atan(A[index]);
-  }
+template <typename T>
+__device__ void matrix_atan(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    C[index] = atan(A[index]);
+  }
+}
+
+extern "C" __global__ void matrix_atan_d(double *A, double *C,
+    unsigned int size) {
+  matrix_atan(A, C, size);
+}
+
+extern "C" __global__ void matrix_atan_f(float *A, float *C, unsigned int size) {
+  matrix_atan(A, C, size);
 }
 /**
@@ -1108,14 +1709,23 @@ __global__ void matrix_atan(double *A, double *C, unsigned int size) {
 * @param C the pre-allocated output matrix (of length = size)
 * @param size the length of the input and output matrices
 */
-extern "C"
-__global__ void matrix_sign(double *A, double *C, unsigned int size) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < size){
-    if (A[index] == 0.0) {
-      C[index] = 0.0;
-    } else {
-      C[index] = copysign(1.0, A[index]);
-    }
+template <typename T>
+__device__ void matrix_sign(T *A, T *C, unsigned int size) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < size) {
+    if (A[index] == 0.0) {
+      C[index] = 0.0;
+    } else {
+      C[index] = copysign(1.0, A[index]);
    }
+  }
+}
+
+extern "C" __global__ void matrix_sign_d(double *A, double *C,
+    unsigned int size) {
+  matrix_sign(A, C, size);
 }
+
+extern "C" __global__ void matrix_sign_f(float *A, float *C, unsigned int size) {
+  matrix_sign(A, C, size);
+}
\ No newline at end of file
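The templated reductions above seed the accumulator with -T_MAX(g_idata[0]) / T_MAX(g_idata[0]) rather than the old -DBL_MAX / DBL_MAX literals, so the identity value follows the instantiated element type. The T_MAX helper itself is not part of the hunks shown here; presumably it is defined earlier in SystemML.cu and resolves to the largest finite value of the argument's type, with the argument used only to select the overload. A minimal sketch of one way such a helper could look (illustrative only; the real definition may be a macro or differ in shape):

#include <cfloat>

// Illustrative sketch, not the actual SystemML.cu definition: return the
// largest finite value of the element type; the argument is unused except
// for overload selection.
__device__ __forceinline__ double T_MAX(double unused) { return DBL_MAX; }
__device__ __forceinline__ float T_MAX(float unused) { return FLT_MAX; }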
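Every kernel in this file now follows the same shape: a templated __device__ body plus thin extern "C" __global__ wrappers whose names carry a _d (double) or _f (float) suffix, which lets the host side pick an entry point by name once sysml.floating.point.precision is known. A hedged host-side sketch of that name-based selection with the CUDA driver API (the getKernel helper is hypothetical and not part of this commit; error handling is omitted):

#include <cuda.h>
#include <string>

// Hypothetical helper: resolve the precision-specific entry point from the
// loaded SystemML.ptx module by appending the _d / _f suffix used above.
CUfunction getKernel(CUmodule module, const std::string &baseName,
                     bool singlePrecision) {
  std::string name = baseName + (singlePrecision ? "_f" : "_d");
  CUfunction fn;
  cuModuleGetFunction(&fn, module, name.c_str());
  return fn;
}

For example, getKernel(module, "reduce_row_max", true) would resolve reduce_row_max_f, while the default double-precision configuration resolves reduce_row_max_d.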
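The element-wise kernels all reduce to the same template-plus-wrappers boilerplate, so any future unary op can be added the same way. A toy example of the pattern (matrix_neg is illustrative only and not part of this commit):

template <typename T>
__device__ void matrix_neg(T *A, T *C, unsigned int size) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < size) {
    // Negate each element; for T = float the arithmetic stays in single precision.
    C[index] = -A[index];
  }
}

extern "C" __global__ void matrix_neg_d(double *A, double *C, unsigned int size) {
  matrix_neg(A, C, size);
}

extern "C" __global__ void matrix_neg_f(float *A, float *C, unsigned int size) {
  matrix_neg(A, C, size);
}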