Repository: systemml
Updated Branches:
  refs/heads/master 0e323ec26 -> 61139e400
http://git-wip-us.apache.org/repos/asf/systemml/blob/61139e40/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 443adf4..ce53aea 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -336,7 +336,7 @@ public class ExecutionContext {
 	public Pair<MatrixObject, Boolean> getSparseMatrixOutputForGPUInstruction(String varName, long numRows, long numCols, long nnz) {
 		MatrixObject mo = allocateGPUMatrixObject(varName, numRows, numCols);
 		mo.getMatrixCharacteristics().setNonZeros(nnz);
-		boolean allocated = mo.getGPUObject(getGPUContext(0)).acquireDeviceModifySparse();
+		boolean allocated = mo.getGPUObject(getGPUContext(0)).acquireDeviceModifySparse();
 		return new Pair<>(mo, allocated);
 	}
http://git-wip-us.apache.org/repos/asf/systemml/blob/61139e40/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index 135e0b1..b3ec497 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -25,6 +25,8 @@ import static jcuda.jcusparse.JCusparse.cusparseSetMatType;
 import static jcuda.jcusparse.JCusparse.cusparseSetPointerMode;
 import static jcuda.jcusparse.JCusparse.cusparseXcsrgeamNnz;
 import static jcuda.jcusparse.JCusparse.cusparseXcsrgemmNnz;
+import static jcuda.jcusparse.JCusparse.cusparseXcsr2coo;
+
 import static jcuda.jcusparse.cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO;
 import static jcuda.jcusparse.cusparseMatrixType.CUSPARSE_MATRIX_TYPE_GENERAL;
 import static jcuda.runtime.JCuda.cudaMemcpy;
@@ -111,6 +113,24 @@ public class CSRPointer {
 		colInd = new Pointer();
 		allocateMatDescrPointer();
 	}
+
+	/**
+	 * Note: the user is expected to free the returned pointer.
+	 *
+	 * @param handle cusparse handle
+	 * @param rows   number of rows of the CSR pointer
+	 * @return integer array of nnz uncompressed row indices (with index base 0)
+	 */
+	public Pointer getCooRowPointer(cusparseHandle handle, int rows) {
+		if(nnz > 0) {
+			Pointer cooRowInd = gpuContext.allocate(null, getIntSizeOf(nnz));
+			cusparseXcsr2coo(handle, rowPtr, LibMatrixCUDA.toInt(nnz), rows, cooRowInd, CUSPARSE_INDEX_BASE_ZERO);
+			return cooRowInd;
+		}
+		else {
+			throw new DMLRuntimeException("csr2coo is only supported when nnz > 0, but instead found " + nnz);
+		}
+	}
 
 	private static long getDataTypeSizeOf(long numElems) {
 		return numElems * ((long) LibMatrixCUDA.sizeOfDataType);
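For background on the new helper: cusparseXcsr2coo expands the compressed CSR row pointer (length rows+1) into one explicit row index per non-zero, which is the COO row array consumed by the sparse_dense_matrix_scalar_op kernel further below. A minimal host-side sketch of that expansion, with hypothetical plain int arrays standing in for the device pointers:

// Illustrative host-side equivalent of cusparseXcsr2coo, for reference only.
// csrRowPtr has length rows+1; the result holds one row index per non-zero
// (index base 0), matching what getCooRowPointer allocates on the device.
public class Csr2CooSketch {
	public static int[] csrRowPtrToCooRowInd(int[] csrRowPtr, int rows, int nnz) {
		int[] cooRowInd = new int[nnz];
		for (int i = 0; i < rows; i++) {
			// non-zeros csrRowPtr[i] .. csrRowPtr[i+1]-1 all belong to row i
			for (int j = csrRowPtr[i]; j < csrRowPtr[i + 1]; j++) {
				cooRowInd[j] = i;
			}
		}
		return cooRowInd;
	}
}

As with getCooRowPointer, the expansion is only meaningful for nnz > 0, hence the guard and the DMLRuntimeException above.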
http://git-wip-us.apache.org/repos/asf/systemml/blob/61139e40/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index e2d5824..46ab3f7 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -830,7 +830,7 @@ public class LibMatrixCUDA {
 
 		// Subtract mean from every element in the matrix
 		ScalarOperator minusOp = new RightScalarOperator(Minus.getMinusFnObject(), mean);
-		matrixScalarOp(gCtx, instName, in, mean, rlen, clen, tmp, minusOp);
+		denseMatrixScalarOp(gCtx, instName, in, mean, rlen, clen, tmp, minusOp);
 
 		squareMatrix(gCtx, instName, tmp, tmp2, rlen, clen);
 
@@ -899,7 +899,7 @@ public class LibMatrixCUDA {
 		reduceCol(gCtx, instName, "reduce_col_sum", tmp2, tmpCol, rlen, clen);
 
 		ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), rlen - 1);
-		matrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
+		denseMatrixScalarOp(gCtx, instName, tmpCol, rlen - 1, 1, clen, out, divideOp);
 
 		gCtx.cudaFreeHelper(instName, tmpCol, gCtx.EAGER_CUDA_FREE);
 		gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
@@ -922,7 +922,7 @@ public class LibMatrixCUDA {
 		reduceRow(gCtx, instName, "reduce_row_sum", tmp2, tmpRow, rlen, clen);
 
 		ScalarOperator divideOp = new RightScalarOperator(Divide.getDivideFnObject(), clen - 1);
-		matrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
+		denseMatrixScalarOp(gCtx, instName, tmpRow, clen - 1, rlen, 1, out, divideOp);
 
 		gCtx.cudaFreeHelper(instName, tmpRow, gCtx.EAGER_CUDA_FREE);
 		gCtx.cudaFreeHelper(instName, tmp, gCtx.EAGER_CUDA_FREE);
@@ -940,7 +940,7 @@ public class LibMatrixCUDA {
 	 */
 	private static void squareMatrix(GPUContext gCtx, String instName, Pointer in, Pointer out, int rlen, int clen) {
 		ScalarOperator power2op = new RightScalarOperator(Power.getPowerFnObject(), 2);
-		matrixScalarOp(gCtx, instName, in, 2, rlen, clen, out, power2op);
+		denseMatrixScalarOp(gCtx, instName, in, 2, rlen, clen, out, power2op);
 	}
 
 	/**
@@ -1134,21 +1134,13 @@ public class LibMatrixCUDA {
 			LOG.trace("GPU : matrixScalarRelational, scalar: " + constant + ", GPUContext=" + gCtx);
 		}
 
-		Pointer A, C;
 		if (isSparseAndEmpty(gCtx, in)) {
 			setOutputToConstant(ec, gCtx, instName, op.executeScalar(0.0), outputName, in.getNumRows(), in.getNumColumns());
 			return;
 		} else {
-			A = getDensePointer(gCtx, in, instName);
-			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in.getNumRows(), in.getNumColumns());	// Allocated the dense output matrix
-			C = getDensePointer(gCtx, out, instName);
+			matrixScalarOp(ec, gCtx, instName, in, outputName, false, op);
 		}
-
-		int rlenA = toInt(in.getNumRows());
-		int clenA = toInt(in.getNumColumns());
-
-		matrixScalarOp(gCtx, instName, A, constant, rlenA, clenA, C, op);
 	}
 
 	/**
@@ -1307,6 +1299,10 @@ public class LibMatrixCUDA {
 			dgeam(ec, gCtx, instName, in1, in2, outputName, isLeftTransposed, isRightTransposed, alpha, beta);
 		}
 	}
+	
+	private static long getIntSizeOf(long numElems) {
+		return numElems * ((long) Sizeof.INT);
+	}
 
 	/**
 	 * Utility to do matrix-scalar operation kernel
@@ -1328,12 +1324,49 @@ public class LibMatrixCUDA {
 		int rlenA = toInt(in.getNumRows());
 		int clenA = toInt(in.getNumColumns());
 
-		Pointer A = getDensePointer(gCtx, in, instName); // TODO: FIXME: Implement sparse binCellSparseScalarOp kernel
 		double scalar = op.getConstant();
-		// MatrixObject out = ec.getMatrixObject(outputName);
-		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, rlenA, clenA);	// Allocated the dense output matrix
-		Pointer C = getDensePointer(gCtx, out, instName);
-		matrixScalarOp(gCtx, instName, A, scalar, rlenA, clenA, C, op);
+
+		if(isInSparseFormat(gCtx, in)) {
+			double zeroVal = op.executeScalar(0.0);
+			CSRPointer sparseA = getSparsePointer(gCtx, in, instName);
+			long nnz = sparseA.nnz;
+			if(zeroVal == 0.0) {
+				// op(sparse input, scalar) -> sparse output
+				MatrixObject out = getSparseMatrixOutputForGPUInstruction(ec, rlenA, clenA, nnz, instName, outputName);
+				CSRPointer sparseC = getSparsePointer(gCtx, out, instName);
+				if(nnz > 0) {
+					// Since the operator is sparse-safe, perform the matrix-scalar operation only on the val
+					// pointer, treating it as a dense matrix of size [nnz, 1].
+					denseMatrixScalarOp(gCtx, instName, sparseA.val, scalar, toInt(nnz), 1, sparseC.val, op);
+					cudaMemcpy(sparseC.rowPtr, sparseA.rowPtr, getIntSizeOf(rlenA+1), cudaMemcpyDeviceToDevice);
+					cudaMemcpy(sparseC.colInd, sparseA.colInd, getIntSizeOf(nnz), cudaMemcpyDeviceToDevice);
+				}
+			}
+			else {
+				// op(sparse input, scalar) -> dense output
+				MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, rlenA, clenA);	// Allocated the dense output matrix
+				Pointer C = getDensePointer(gCtx, out, instName);
+				setOutputToConstant(gCtx, instName, zeroVal, C, toInt(rlenA*clenA));
+				if(nnz > 0) {
+					long t0 = ConfigurationManager.isFinegrainedStatistics() ? System.nanoTime() : 0;
+					Pointer cooRowPtrA = sparseA.getCooRowPointer(getCusparseHandle(gCtx), rlenA);
+					int isLeftScalar = (op instanceof LeftScalarOperator) ? 1 : 0;
+					getCudaKernels(gCtx).launchKernel("sparse_dense_matrix_scalar_op",
+						ExecutionConfig.getConfigForSimpleVectorOperations(toInt(nnz)),
+						cooRowPtrA, sparseA.colInd, sparseA.val, scalar, C, toInt(nnz), toInt(clenA), getBinaryOp(op.fn), isLeftScalar);
+					if (ConfigurationManager.isFinegrainedStatistics()) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_MATRIX_SCALAR_OP_KERNEL, System.nanoTime() - t0);
+					gCtx.cudaFreeHelper(instName, cooRowPtrA, gCtx.EAGER_CUDA_FREE);
+				}
+			}
+		}
+		else {
+			// op(dense input, scalar) -> dense output
+			Pointer A = getDensePointer(gCtx, in, instName);
+			// MatrixObject out = ec.getMatrixObject(outputName);
+			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, rlenA, clenA);	// Allocated the dense output matrix
+			Pointer C = getDensePointer(gCtx, out, instName);
+			denseMatrixScalarOp(gCtx, instName, A, scalar, rlenA, clenA, C, op);
+		}
 	}
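The zeroVal == 0.0 check above separates sparse-safe operators (op(0) == 0, e.g. multiplication) from operators that turn zeros into non-zeros (e.g. addition of a non-zero scalar). A rough CPU analogue of the sparse-safe branch, using a hypothetical array-backed Csr holder in place of CSRPointer purely for illustration:

// CPU analogue of the sparse-safe branch: apply the scalar op to the values only
// and reuse the CSR structure, since op(0) == 0 keeps the sparsity pattern intact.
// The Csr class below is a hypothetical stand-in for CSRPointer.
import java.util.function.DoubleUnaryOperator;

public class SparseSafeScalarOpSketch {
	static class Csr {
		int[] rowPtr;
		int[] colInd;
		double[] val;
	}

	static Csr apply(Csr in, DoubleUnaryOperator op) {
		Csr out = new Csr();
		out.rowPtr = in.rowPtr.clone();   // mirrors the cudaMemcpy of rowPtr
		out.colInd = in.colInd.clone();   // mirrors the cudaMemcpy of colInd
		out.val = new double[in.val.length];
		for (int i = 0; i < in.val.length; i++)   // denseMatrixScalarOp over an [nnz, 1] view
			out.val[i] = op.applyAsDouble(in.val[i]);
		return out;
	}
}

Because the sparsity pattern is unchanged, only the values are rewritten; the two cudaMemcpy calls in the diff play the role of the rowPtr/colInd clones here.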
 
 	/**
@@ -1350,7 +1383,7 @@ public class LibMatrixCUDA {
 	 * @param c the dense output matrix
 	 * @param op operation to perform
 	 */
-	private static void matrixScalarOp(GPUContext gCtx, String instName, Pointer a, double scalar, int rlenA, int clenA, Pointer c, ScalarOperator op) {
+	private static void denseMatrixScalarOp(GPUContext gCtx, String instName, Pointer a, double scalar, int rlenA, int clenA, Pointer c, ScalarOperator op) {
 		if(LOG.isTraceEnabled()) {
 			LOG.trace("GPU : matrix_scalar_op" + ", GPUContext=" + gCtx);
 		}
@@ -1550,17 +1583,27 @@ public class LibMatrixCUDA {
 		} else {
 			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, numRows, numCols);	// Allocated the dense output matrix
 			Pointer A = getDensePointer(gCtx, out, instName);
-			int rlen = toInt(out.getNumRows());
-			int clen = toInt(out.getNumColumns());
-			long t0 = 0;
-			if (ConfigurationManager.isFinegrainedStatistics())
-				t0 = System.nanoTime();
-			int size = rlen * clen;
-			getCudaKernels(gCtx).launchKernel("fill", ExecutionConfig.getConfigForSimpleVectorOperations(size), A, constant, size);
-			if (ConfigurationManager.isFinegrainedStatistics())
-				GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_FILL_KERNEL, System.nanoTime() - t0);
+			int size = toInt(out.getNumRows() * out.getNumColumns());
+			setOutputToConstant(gCtx, instName, constant, A, size);
 		}
 	}
+	
+	/**
+	 * Fills an array on the GPU with a given scalar value
+	 * @param gCtx a valid {@link GPUContext}
+	 * @param instName name of the invoking instruction to record {@link Statistics}.
+	 * @param constant scalar value with which to fill the matrix
+	 * @param A pointer to the input/output array
+	 * @param size length of A (not in bytes, but in number of float/double values)
+	 */
+	private static void setOutputToConstant(GPUContext gCtx, String instName, double constant, Pointer A, int size) {
+		long t0 = 0;
+		if (ConfigurationManager.isFinegrainedStatistics())
+			t0 = System.nanoTime();
+		getCudaKernels(gCtx).launchKernel("fill", ExecutionConfig.getConfigForSimpleVectorOperations(size), A, constant, size);
+		if (ConfigurationManager.isFinegrainedStatistics())
+			GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_FILL_KERNEL, System.nanoTime() - t0);
+	}
 
 	/**
 	 * Performs a deep copy of input device double pointer corresponding to matrix
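For the non-sparse-safe branch, setOutputToConstant and the sparse_dense_matrix_scalar_op kernel combine into a two-step pattern: fill the dense output with op(0), then overwrite only the non-zero cells. A hedged CPU sketch of that pattern, assuming a row-major dense output and plain arrays in place of the device-side COO row indices, column indices, and values:

// CPU analogue of the non-sparse-safe path: fill the dense output with op(0)
// (the job of setOutputToConstant / the "fill" kernel), then overwrite each
// non-zero cell via its COO row index and column index
// (the job of the sparse_dense_matrix_scalar_op kernel).
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

public class SparseToDenseScalarOpSketch {
	static double[] apply(int[] cooRowInd, int[] colInd, double[] val,
			int rlen, int clen, DoubleUnaryOperator op) {
		double[] out = new double[rlen * clen];
		Arrays.fill(out, op.applyAsDouble(0.0));
		for (int k = 0; k < val.length; k++)
			out[cooRowInd[k] * clen + colInd[k]] = op.applyAsDouble(val[k]);
		return out;
	}
}

The isLeftScalar flag in the kernel launch covers the distinction between op(scalar, value) and op(value, scalar), which this sketch folds into the single DoubleUnaryOperator.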