Repository: incubator-systemml Updated Branches: refs/heads/master 76f3ca5d3 -> 2c5c3b14e
[HOTFIX] Bug fix for solve, removed warnings and added instrumentation Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2c5c3b14 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2c5c3b14 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2c5c3b14 Branch: refs/heads/master Commit: 2c5c3b14e1906cda70ae1581b19a5e908b3ab329 Parents: 76f3ca5 Author: Nakul Jindal <[email protected]> Authored: Thu May 4 16:26:47 2017 -0700 Committer: Nakul Jindal <[email protected]> Committed: Thu May 4 16:26:47 2017 -0700 ---------------------------------------------------------------------- .../instructions/GPUInstructionParser.java | 4 +- .../gpu/BuiltinBinaryGPUInstruction.java | 2 + .../instructions/gpu/GPUInstruction.java | 28 ++++--- .../gpu/MatrixMatrixBuiltinGPUInstruction.java | 1 + .../instructions/gpu/context/GPUContext.java | 2 + .../instructions/gpu/context/GPUObject.java | 3 +- .../runtime/matrix/data/LibMatrixCUDA.java | 77 +++++++++++++++----- 7 files changed, 86 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index ef0412c..4a45521 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -35,9 +35,9 @@ import org.apache.sysml.runtime.instructions.gpu.AggregateUnaryGPUInstruction; public class GPUInstructionParser extends InstructionParser { - public static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType; + static final HashMap<String, GPUINSTRUCTION_TYPE> String2GPUInstructionType; static { - String2GPUInstructionType = new HashMap<String, GPUINSTRUCTION_TYPE>(); + String2GPUInstructionType = new HashMap<>(); // Neural Network Operators String2GPUInstructionType.put( "relu_backward", GPUINSTRUCTION_TYPE.Convolution); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java index 372f883..24e9e79 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java @@ -30,7 +30,9 @@ import org.apache.sysml.runtime.matrix.operators.Operator; public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction { + @SuppressWarnings("unused") private int _arity; + CPOperand output; CPOperand input1, input2; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java index 9eef072..f4c523b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUInstruction.java @@ -35,16 +35,20 @@ public abstract class GPUInstruction extends Instruction public enum GPUINSTRUCTION_TYPE { AggregateUnary, AggregateBinary, Convolution, MMTSJ, Reorg, ArithmeticBinary, BuiltinUnary, BuiltinBinary, Builtin }; // Memory/conversions - public final static String MISC_TIMER_HOST_TO_DEVICE = "H2D"; // time spent in bringing data to gpu (from host) - public final static String MISC_TIMER_DEVICE_TO_HOST = "D2H"; // time spent in bringing data from gpu (to host) - public final static String MISC_TIMER_DEVICE_TO_DEVICE = "D2D"; // time spent in copying data from one region on the device to another - public final static String MISC_TIMER_SPARSE_TO_DENSE = "s2d"; // time spent in converting data from sparse to dense - public final static String MISC_TIMER_DENSE_TO_SPARSE = "d2s"; // time spent in converting data from dense to sparse - public final static String MISC_TIMER_CUDA_FREE = "f"; // time spent in calling cudaFree - public final static String MISC_TIMER_ALLOCATE = "a"; // time spent to allocate memory on gpu - public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ao"; // time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE) - public final static String MISC_TIMER_SET_ZERO = "az"; // time spent to allocate - public final static String MISC_TIMER_REUSE = "r"; // time spent in reusing already allocated memory on GPU (mainly for the count) + public final static String MISC_TIMER_HOST_TO_DEVICE = "H2D"; // time spent in bringing data to gpu (from host) + public final static String MISC_TIMER_DEVICE_TO_HOST = "D2H"; // time spent in bringing data from gpu (to host) + public final static String MISC_TIMER_DEVICE_TO_DEVICE = "D2D"; // time spent in copying data from one region on the device to another + public final static String MISC_TIMER_SPARSE_TO_DENSE = "s2d"; // time spent in converting data from sparse to dense + public final static String MISC_TIMER_DENSE_TO_SPARSE = "d2s"; // time spent in converting data from dense to sparse + public final static String MISC_TIMER_ROW_TO_COLUMN_MAJOR = "r2c"; // time spent in converting data from row major to column major + public final static String MISC_TIMER_COLUMN_TO_ROW_MAJOR = "c2r"; // time spent in converting data from column major to row major + public final static String MISC_TIMER_OBJECT_CLONE = "clone";// time spent in cloning (deep copying) a GPUObject instance + + public final static String MISC_TIMER_CUDA_FREE = "f"; // time spent in calling cudaFree + public final static String MISC_TIMER_ALLOCATE = "a"; // time spent to allocate memory on gpu + public final static String MISC_TIMER_ALLOCATE_DENSE_OUTPUT = "ao"; // time spent to allocate dense output (recorded differently than MISC_TIMER_ALLOCATE) + public final static String MISC_TIMER_SET_ZERO = "az"; // time spent to allocate + public final static String MISC_TIMER_REUSE = "r"; // time spent in reusing already allocated memory on GPU (mainly for the count) // Matmult instructions public final static String MISC_TIMER_SPARSE_ALLOCATE_LIB = "Msao"; // time spend in allocating for sparse matrix output @@ -58,6 +62,10 @@ public abstract class GPUInstruction extends Instruction // Other BLAS instructions public final static String MISC_TIMER_DAXPY_LIB = "daxpy"; // time spent in daxpy + public final static String MISC_TIMER_QR_BUFFER = "qr_buffer"; // time spent in calculating buffer needed to perform QR + public final static String MISC_TIMER_QR = "qr"; // time spent in doing QR + public final static String MISC_TIMER_ORMQR = "ormqr"; // time spent in ormqr + public final static String MISC_TIMER_TRSM = "trsm"; // time spent in cublas Dtrsm // Transpose public final static String MISC_TIMER_SPARSE_DGEAM_LIB = "sdgeaml"; // time spent in sparse transpose (and other ops of type a*op(A) + b*op(B)) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java index f492b6e..8936735 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java @@ -45,6 +45,7 @@ public class MatrixMatrixBuiltinGPUInstruction extends BuiltinBinaryGPUInstructi MatrixObject mat2 = getMatrixInputForGPUInstruction(ec, input2.getName()); if(opcode.equals("solve")) { + ec.setMetaData(output.getName(), mat1.getNumColumns(), 1); LibMatrixCUDA.solve(ec, ec.getGPUContext(), getExtendedOpcode(), mat1, mat2, output.getName()); } else { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java index d71f725..673601f 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java @@ -307,6 +307,8 @@ public class GPUContext { freeList = new LinkedList<Pointer>(); freeCUDASpaceMap.put(size, freeList); } + if (freeList.contains(toFree)) + throw new RuntimeException("GPU : Internal state corrupted, double free"); freeList.add(toFree); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java index 1d2285d..d735e38 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java @@ -26,7 +26,6 @@ import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE; import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW; import static jcuda.jcusparse.JCusparse.cusparseDdense2csr; import static jcuda.jcusparse.JCusparse.cusparseDnnz; -import static jcuda.runtime.JCuda.cudaMalloc; import static jcuda.runtime.JCuda.cudaMemcpy; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToDevice; import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost; @@ -343,7 +342,7 @@ public class GPUObject { /** * Convenience method. Converts Column Major Dense Matrix to Row Major Dense Matrix - * @throws DMLRuntimeException + * @throws DMLRuntimeException if error */ public void denseColumnMajorToRowMajor() throws DMLRuntimeException { LOG.trace("GPU : dense Ptr row-major -> col-major on " + this + ", GPUContext=" + getGPUContext()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2c5c3b14/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index 23304b5..a99571a 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -329,6 +329,7 @@ public class LibMatrixCUDA { * @return a sparse matrix pointer * @throws DMLRuntimeException if error occurs */ + @SuppressWarnings("unused") private static CSRPointer getSparsePointer(GPUContext gCtx, MatrixObject input, String instName) throws DMLRuntimeException { if(!isInSparseFormat(gCtx, input)) { input.getGPUObject(gCtx).denseToSparse(); @@ -2754,6 +2755,25 @@ public class LibMatrixCUDA { Pointer betaPtr = pointerTo(beta); int transa = isLeftTransposed ? CUBLAS_OP_T : CUBLAS_OP_N; int transb = isRightTransposed ? CUBLAS_OP_T : CUBLAS_OP_N; + + int lda = (int) in1.getNumColumns(); + int ldb = (int) in2.getNumColumns(); + int m = (int) in1.getNumColumns(); + int n = (int) in2.getNumRows(); + if (isLeftTransposed && isRightTransposed) { + m = (int) in1.getNumRows(); + n = (int) in2.getNumColumns(); + } + else if (isLeftTransposed) { + m = (int) in1.getNumRows(); + } else if (isRightTransposed) { + n = (int) in2.getNumColumns(); + } + int ldc = m; + + + + /** int m = (int) in1.getNumRows(); int n = (int) in1.getNumColumns(); if(!isLeftTransposed && isRightTransposed) { @@ -2763,6 +2783,7 @@ public class LibMatrixCUDA { int lda = isLeftTransposed ? n : m; int ldb = isRightTransposed ? n : m; int ldc = m; + **/ MatrixObject out = ec.getMatrixObject(outputName); boolean isSparse1 = isInSparseFormat(gCtx, in1); @@ -2963,8 +2984,10 @@ public class LibMatrixCUDA { throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function"); // x = solve(A, b) + LOG.trace("GPU : solve" + ", GPUContext=" + gCtx); + + long t0 = -1; - // Both Sparse if (!isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) { // Both dense GPUObject Aobj = in1.getGPUObject(gCtx); GPUObject bobj = in2.getGPUObject(gCtx); @@ -2980,55 +3003,75 @@ public class LibMatrixCUDA { // convert dense matrices to row major // Operation in cuSolver and cuBlas are for column major dense matrices // and are destructive to the original input + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); GPUObject ATobj = (GPUObject) Aobj.clone(); - ATobj.denseRowMajorToColumnMajor(); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + ATobj.denseRowMajorToColumnMajor(); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0); Pointer A = ATobj.getJcudaDenseMatrixPtr(); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); GPUObject bTobj = (GPUObject) bobj.clone(); - bTobj.denseRowMajorToColumnMajor(); - Pointer b = bTobj.getJcudaDenseMatrixPtr(); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + bTobj.denseRowMajorToColumnMajor(); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0); + + Pointer b = bTobj.getJcudaDenseMatrixPtr(); // The following set of operations is done following the example in the cusolver documentation // http://docs.nvidia.com/cuda/cusolver/#ormqr-example1 // step 3: query working space of geqrf and ormqr - int[] lwork = {0}; + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + int[] lwork = {0}; JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR_BUFFER, System.nanoTime() - t0); + - // step 4: compute QR factorization - Pointer work = gCtx.allocate(lwork[0] * Sizeof.DOUBLE); - Pointer tau = gCtx.allocate(Math.max(m, m) * Sizeof.DOUBLE); + // step 4: compute QR factorization + Pointer work = gCtx.allocate(instName, lwork[0] * Sizeof.DOUBLE); + Pointer tau = gCtx.allocate(instName, Math.max(m, m) * Sizeof.DOUBLE); Pointer devInfo = gCtx.allocate(Sizeof.INT); - JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR, System.nanoTime() - t0); - int[] qrError = {-1}; + + int[] qrError = {-1}; cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost); if (qrError[0] != 0) { throw new DMLRuntimeException("GPU : Error in call to geqrf (QR factorization) as part of solve, argument " + qrError[0] + " was wrong"); } // step 5: compute Q^T*B - JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo); - cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ORMQR, System.nanoTime() - t0); + cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost); if (qrError[0] != 0) { throw new DMLRuntimeException("GPU : Error in call to ormqr (to compuete Q^T*B after QR factorization) as part of solve, argument " + qrError[0] + " was wrong"); } // step 6: compute x = R \ Q^T*B - JCublas2.cublasDtrsm(gCtx.getCublasHandle(), + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + JCublas2.cublasDtrsm(gCtx.getCublasHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT, n, 1, pointerTo(1.0), A, m, b, m); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_TRSM, System.nanoTime() - t0); - bTobj.denseColumnMajorToRowMajor(); + if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime(); + bTobj.denseColumnMajorToRowMajor(); + if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_COLUMN_TO_ROW_MAJOR, System.nanoTime() - t0); // TODO : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash // There is an avoidable copy happening here MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName); cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice); - gCtx.cudaFreeHelper(work); - gCtx.cudaFreeHelper(tau); - gCtx.cudaFreeHelper(tau); + gCtx.cudaFreeHelper(instName, work); + gCtx.cudaFreeHelper(instName, tau); ATobj.clearData(); bTobj.clearData();
