Repository: systemml
Updated Branches:
  refs/heads/master 428f3aa21 -> de0513415
[SYSTEMML-1847] bug fixes for gpu from ml algos

- Fixed errors in -gpu force arguments
- Fix to GPU solve - converts sparse matrices to dense
- Bug fix in GPUContext::clearTemporaryMemory
- Fix for removing recorded GPUObjects
- Estimate memory for each parfor body and set degree of parallelism
- Setting cuda pointers to null after freeing
- Fix after rebase with master for SOLVE on GPU

Closes #626


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/de051341
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/de051341
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/de051341

Branch: refs/heads/master
Commit: de0513415e8fb6e9b9f289bc261612091bd4e664
Parents: 428f3aa
Author: Nakul Jindal <[email protected]>
Authored: Mon Aug 28 13:38:49 2017 -0700
Committer: Nakul Jindal <[email protected]>
Committed: Mon Aug 28 13:38:49 2017 -0700

----------------------------------------------------------------------
 scripts/perftest/python/utils_misc.py           |  12 +-
 .../java/org/apache/sysml/hops/BinaryOp.java    |   5 +-
 .../controlprogram/ParForProgramBlock.java      |   3 -
 .../parfor/opt/OptimizerRuleBased.java          |  59 ++++++-
 .../instructions/gpu/context/CSRPointer.java    |   3 +
 .../instructions/gpu/context/GPUContext.java    |  36 +++--
 .../instructions/gpu/context/GPUObject.java     |   8 +
 .../runtime/matrix/data/LibMatrixCUDA.java      | 162 +++++++++----------
 8 files changed, 176 insertions(+), 112 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/scripts/perftest/python/utils_misc.py
----------------------------------------------------------------------
diff --git a/scripts/perftest/python/utils_misc.py b/scripts/perftest/python/utils_misc.py
index f9904c5..da9dbcb 100755
--- a/scripts/perftest/python/utils_misc.py
+++ b/scripts/perftest/python/utils_misc.py
@@ -211,20 +211,18 @@ def exec_dml_and_parse_time(exec_type, dml_file_name, args, backend_args_dict, s
     """

     algorithm = dml_file_name + '.dml'
-
-    sup_args = ''.join(['{} {}'.format(k, v) for k, v in systemml_args_dict.items()])
-
+    sup_args = ' '.join(['{} {}'.format(k, v) for k, v in systemml_args_dict.items()])
     if exec_type == 'singlenode':
         exec_script = join(os.environ.get('SYSTEMML_HOME'), 'bin', 'systemml-standalone.py')

-        singlenode_pre_args = ''.join([' {} {} '.format(k, v) for k, v in backend_args_dict.items()])
-        args = ''.join(['{} {}'.format(k, v) for k, v in args.items()])
+        singlenode_pre_args = ' '.join(['{} {}'.format(k, v) for k, v in backend_args_dict.items()])
+        args = ' '.join(['{} {}'.format(k, v) for k, v in args.items()])

         cmd = [exec_script, singlenode_pre_args, '-f', algorithm, args, sup_args]
         cmd_string = ' '.join(cmd)

     if exec_type == 'hybrid_spark':
         exec_script = join(os.environ.get('SYSTEMML_HOME'), 'bin', 'systemml-spark-submit.py')

-        spark_pre_args = ''.join([' {} {} '.format(k, v) for k, v in backend_args_dict.items()])
-        args = ''.join(['{} {}'.format(k, v) for k, v in args.items()])
+        spark_pre_args = ' '.join([' {} {} '.format(k, v) for k, v in backend_args_dict.items()])
+        args = ' '.join(['{} {}'.format(k, v) for k, v in args.items()])

         cmd = [exec_script, spark_pre_args, '-f', algorithm, args, sup_args]
         cmd_string = ' '.join(cmd)

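The utils_misc.py change above is a separator bug: joining '{key} {value}' pairs with the empty string fuses the last value of one pair into the next flag, producing an unparseable command line. A minimal sketch of the failure mode, transliterated to Java (the buildArgs helper and the flag values are illustrative, not part of the patch):

    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.stream.Collectors;

    public class ArgJoinDemo {
        // Mirrors the Python list comprehension: one "key value" string per entry
        static String buildArgs(Map<String, String> args, String sep) {
            return args.entrySet().stream()
                    .map(e -> e.getKey() + " " + e.getValue())
                    .collect(Collectors.joining(sep));
        }

        public static void main(String[] args) {
            Map<String, String> opts = new LinkedHashMap<>();
            opts.put("-stats", "10");
            opts.put("-explain", "runtime");
            System.out.println(buildArgs(opts, ""));  // "-stats 10-explain runtime" (broken)
            System.out.println(buildArgs(opts, " ")); // "-stats 10 -explain runtime" (fixed)
        }
    }
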
http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/hops/BinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index ad9f0ad..cd1f715 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -1058,7 +1058,10 @@ public class BinaryOp extends Hop

 		//ensure cp exec type for single-node operations
 		if ( op == OpOp2.SOLVE ) {
-			_etype = ExecType.CP;
+			if (isGPUEnabled())
+				_etype = ExecType.GPU;
+			else
+				_etype = ExecType.CP;
 		}

 		return _etype;


http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
index 1968c26..3a9bf51 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ParForProgramBlock.java
@@ -619,9 +619,6 @@ public class ParForProgramBlock extends ForProgramBlock
 		switch( _execMode )
 		{
 			case LOCAL: //create parworkers as local threads
-				if (DMLScript.USE_ACCELERATOR) {
-					setDegreeOfParallelism(ec.getNumGPUContexts());
-				}
 				executeLocalParFor(ec, iterVar, from, to, incr);
 				break;


http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 018364f..154109a 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -32,6 +32,7 @@ import java.util.Set;

 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.hops.AggBinaryOp;
@@ -48,6 +49,7 @@ import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.MemoTable;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.ReorgOp;
@@ -98,6 +100,8 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
@@ -1175,7 +1179,42 @@ public class OptimizerRuleBased extends Optimizer
 		_numEvaluatedPlans++;
 		LOG.debug(getOptMode()+" OPT: rewrite 'set export replication factor' - result="+apply+((apply)?" ("+replication+")":"") );
 	}
("+replication+")":"") ); } - + + /** + * Calculates the maximum memory needed in a CP only Parfor + * based on the {@link Hop#computeMemEstimate(MemoTable)} } function + * called recursively for the "children" of the parfor {@link OptNode}. + * + * @param n the parfor {@link OptNode} + * @return the maximum memory needed for any operation inside a parfor in CP execution mode + * @throws DMLRuntimeException if error + */ + protected double getMaxCPOnlyBudget(OptNode n) throws DMLRuntimeException { + ExecType et = n.getExecType(); + double ret = 0; + + if (n.isLeaf() && et != getRemoteExecType()) { + Hop h = OptTreeConverter.getAbstractPlanMapping().getMappedHop(n.getID()); + if (h.getForcedExecType() != LopProperties.ExecType.MR //e.g., -exec=hadoop + && h.getForcedExecType() != LopProperties.ExecType.SPARK) { + double mem = _cost.getLeafNodeEstimate(TestMeasure.MEMORY_USAGE, n, LopProperties.ExecType.CP); + if (mem >= OptimizerUtils.DEFAULT_SIZE) { + // memory estimate for worst case scenario. + // optimistically ignoring this + } else { + ret = Math.max(ret, mem); + } + } + } + + if (!n.isLeaf()) { + for (OptNode c : n.getChilds()) { + ret = Math.max(ret, getMaxCPOnlyBudget(c)); + } + } + return ret; + } + /////// //REWRITE set degree of parallelism /// @@ -1204,6 +1243,24 @@ public class OptimizerRuleBased extends Optimizer //constrain max parfor parallelism by problem size int parforK = (int)((_N<kMax)? _N : kMax); + + + // if gpu mode is enabled, the amount of parallelism is set to + // the smaller of the number of iterations and the number of GPUs + // otherwise it default to the number of CPU cores and the + // operations are run in CP mode + if (DMLScript.USE_ACCELERATOR) { + long perGPUBudget = GPUContextPool.initialGPUMemBudget(); + double maxMemUsage = getMaxCPOnlyBudget(n); + if (maxMemUsage < perGPUBudget){ + parforK = GPUContextPool.getDeviceCount(); + parforK = Math.min(parforK, (int)_N); + LOG.debug("Setting degree of parallelism + [" + parforK + "] for GPU; per GPU budget :[" + + perGPUBudget + "], parfor budget :[" + maxMemUsage + "], max parallelism per GPU : [" + + parforK + "]"); + } + } + //set parfor degree of parallelism pfpb.setDegreeOfParallelism(parforK); http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java index a5bc299..a4147a3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java @@ -528,6 +528,9 @@ public class CSRPointer { cudaFreeHelper(val, eager); cudaFreeHelper(rowPtr, eager); cudaFreeHelper(colInd, eager); + val = null; + rowPtr = null; + colInd = null; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java index 84d181b..c6b82c4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java +++ 
http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
index a5bc299..a4147a3 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/CSRPointer.java
@@ -528,6 +528,9 @@ public class CSRPointer {
 			cudaFreeHelper(val, eager);
 			cudaFreeHelper(rowPtr, eager);
 			cudaFreeHelper(colInd, eager);
+			val = null;
+			rowPtr = null;
+			colInd = null;
 		}
 	}


http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 84d181b..c6b82c4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -528,7 +528,7 @@ public class GPUContext {
 	 * Records that a block is not used anymore
 	 */
 	public void removeRecordedUsage(GPUObject o) {
-		allocatedGPUObjects.remove(o);
+		allocatedGPUObjects.removeIf(a -> a.equals(o));
 	}

 	/**
@@ -735,21 +735,27 @@ public class GPUContext {
 		// To record the cuda block sizes needed by allocatedGPUObjects, others are cleared up.
 		HashMap<Pointer, Long> tmpCudaBlockSizeMap = new HashMap<>();
 		for (GPUObject o : allocatedGPUObjects) {
-			if (o.isSparse()) {
-				CSRPointer p = o.getSparseMatrixCudaPointer();
-				if (p.rowPtr != null && cudaBlockSizeMap.containsKey(p.rowPtr)) {
-					tmpCudaBlockSizeMap.put(p.rowPtr, cudaBlockSizeMap.get(p.rowPtr));
-				}
-				if (p.colInd != null && cudaBlockSizeMap.containsKey(p.colInd)) {
-					tmpCudaBlockSizeMap.put(p.colInd, cudaBlockSizeMap.get(p.colInd));
-				}
-				if (p.val != null && cudaBlockSizeMap.containsKey(p.val)) {
-					tmpCudaBlockSizeMap.put(p.val, cudaBlockSizeMap.get(p.val));
-				}
+			if (o.isDirty()) {
+				if (o.isSparse()) {
+					CSRPointer p = o.getSparseMatrixCudaPointer();
+					if (p == null)
+						throw new RuntimeException("CSRPointer is null in clearTemporaryMemory");
+					if (p.rowPtr != null && cudaBlockSizeMap.containsKey(p.rowPtr)) {
+						tmpCudaBlockSizeMap.put(p.rowPtr, cudaBlockSizeMap.get(p.rowPtr));
+					}
+					if (p.colInd != null && cudaBlockSizeMap.containsKey(p.colInd)) {
+						tmpCudaBlockSizeMap.put(p.colInd, cudaBlockSizeMap.get(p.colInd));
+					}
+					if (p.val != null && cudaBlockSizeMap.containsKey(p.val)) {
+						tmpCudaBlockSizeMap.put(p.val, cudaBlockSizeMap.get(p.val));
+					}

-			} else {
-				Pointer p = o.getJcudaDenseMatrixPtr();
-				tmpCudaBlockSizeMap.put(p, cudaBlockSizeMap.get(p));
+				} else {
+					Pointer p = o.getJcudaDenseMatrixPtr();
+					if (p == null)
+						throw new RuntimeException("Pointer is null in clearTemporaryMemory");
+					tmpCudaBlockSizeMap.put(p, cudaBlockSizeMap.get(p));
+				}
 			}
 		}
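The CSRPointer and clearTemporaryMemory changes above are two halves of the same discipline: device pointers are nulled immediately after cudaFreeHelper, so later passes see null (and can fail fast) instead of dereferencing or double-freeing a stale pointer. A generic sketch of the free-then-null pattern, with stand-in types rather than the jcuda classes:

    import java.util.function.Consumer;

    public class FreeThenNullDemo {
        static final class DevPtr { } // stand-in for a device pointer such as jcuda.Pointer

        private DevPtr buf;
        private final Consumer<DevPtr> freeDevice; // stand-in for a free helper like cudaFreeHelper

        FreeThenNullDemo(DevPtr buf, Consumer<DevPtr> freeDevice) {
            this.buf = buf;
            this.freeDevice = freeDevice;
        }

        /** Idempotent deallocate: nulling makes a repeated call a harmless no-op. */
        void deallocate() {
            if (buf != null) {
                freeDevice.accept(buf);
                buf = null; // without this, a second deallocate() would free a stale pointer
            }
        }
    }
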
http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index c3e23f3..1bed42a 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -301,6 +301,9 @@ public class GPUObject {
 	 * @throws DMLRuntimeException ?
 	 */
 	public void setSparseMatrixCudaPointer(CSRPointer sparseMatrixPtr) throws DMLRuntimeException {
+		if (this.jcudaSparseMatrixPtr != null) {
+			throw new DMLRuntimeException("jcudaSparseMatrixPtr was already allocated for " + this + ", this will cause a memory leak on the GPU");
+		}
 		this.jcudaSparseMatrixPtr = sparseMatrixPtr;
 		this.isSparse = true;
 		if (getJcudaDenseMatrixPtr() != null) {
@@ -317,6 +320,9 @@ public class GPUObject {
 	 * @throws DMLRuntimeException ?
 	 */
 	public void setDenseMatrixCudaPointer(Pointer densePtr) throws DMLRuntimeException {
+		if (this.jcudaDenseMatrixPtr != null) {
+			throw new DMLRuntimeException("jcudaDenseMatrixPtr was already allocated for " + this + ", this will cause a memory leak on the GPU");
+		}
 		this.jcudaDenseMatrixPtr = densePtr;
 		this.isSparse = false;
 		if (getJcudaSparseMatrixPtr() != null) {
@@ -373,6 +379,7 @@ public class GPUObject {

 		Pointer tmp = transpose(getGPUContext(), getJcudaDenseMatrixPtr(), m, n, lda, ldc);
 		cudaFreeHelper(getJcudaDenseMatrixPtr());
+		jcudaDenseMatrixPtr = null;
 		setDenseMatrixCudaPointer(tmp);
 	}

@@ -394,6 +401,7 @@ public class GPUObject {

 		Pointer tmp = transpose(getGPUContext(), getJcudaDenseMatrixPtr(), m, n, lda, ldc);
 		cudaFreeHelper(getJcudaDenseMatrixPtr());
+		jcudaDenseMatrixPtr = null;
 		setDenseMatrixCudaPointer(tmp);
 	}
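The guards added to GPUObject turn a silent device-memory leak into a hard failure: a GPUObject now refuses a new pointer while it still owns one, which is exactly why the two transpose hunks free and null jcudaDenseMatrixPtr before calling setDenseMatrixCudaPointer. A condensed sketch of that ownership contract, with stand-in types (not the SystemML classes):

    public class OwnershipGuardDemo {
        static final class DevPtr { }

        private DevPtr densePtr;

        /** Refuse to overwrite a live pointer; overwriting would leak device memory. */
        void setDensePointer(DevPtr p) {
            if (densePtr != null)
                throw new IllegalStateException("densePtr already allocated; free it first");
            densePtr = p;
        }

        /** Required call order: free the old buffer, drop ownership, install the new one. */
        void replaceWith(DevPtr fresh) {
            // free(densePtr) would go here (hypothetical device-free helper)
            densePtr = null;        // drop ownership so the guard passes
            setDensePointer(fresh); // install the replacement
        }
    }
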
http://git-wip-us.apache.org/repos/asf/systemml/blob/de051341/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 92a5546..4be5c2d 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -3550,108 +3550,100 @@ public class LibMatrixCUDA {

 		long t0 = -1;
-		if (!isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) { // Both dense
-			GPUObject Aobj = in1.getGPUObject(gCtx);
-			GPUObject bobj = in2.getGPUObject(gCtx);
-			int m = toInt(in1.getNumRows());
-			int n = toInt(in1.getNumColumns());
-			if (in2.getNumRows() != m)
-				throw new DMLRuntimeException("GPU : Incorrect input for solve(), rows in A should be the same as rows in B");
-			if (in2.getNumColumns() != 1)
-				throw new DMLRuntimeException("GPU : Incorrect input for solve(), columns in B should be 1");
-
-			// Copy over matrices and
-			// convert dense matrices to row major
-			// Operation in cuSolver and cuBlas are for column major dense matrices
-			// and are destructive to the original input
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			GPUObject ATobj = (GPUObject) Aobj.clone();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			ATobj.denseRowMajorToColumnMajor();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);
-			Pointer A = ATobj.getJcudaDenseMatrixPtr();
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			GPUObject bTobj = (GPUObject) bobj.clone();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			bTobj.denseRowMajorToColumnMajor();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);
+		GPUObject Aobj = in1.getGPUObject(gCtx);
+		if (isInSparseFormat(gCtx, in1))
+			Aobj.sparseToDense(instName);

-			Pointer b = bTobj.getJcudaDenseMatrixPtr();
+		GPUObject bobj = in2.getGPUObject(gCtx);
+		if (isInSparseFormat(gCtx, in2))
+			bobj.sparseToDense(instName);

-			// The following set of operations is done following the example in the cusolver documentation
-			// http://docs.nvidia.com/cuda/cusolver/#ormqr-example1
+		int m = (int) in1.getNumRows();
+		int n = (int) in1.getNumColumns();
+		if ((int) in2.getNumRows() != m)
+			throw new DMLRuntimeException("GPU : Incorrect input for solve(), rows in A should be the same as rows in B");
+		if ((int) in2.getNumColumns() != 1)
+			throw new DMLRuntimeException("GPU : Incorrect input for solve(), columns in B should be 1");

-			// step 3: query working space of geqrf and ormqr
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			int[] lwork = {0};
-			JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR_BUFFER, System.nanoTime() - t0);
+		// Copy over matrices and
+		// convert dense matrices to row major
+		// Operation in cuSolver and cuBlas are for column major dense matrices
+		// and are destructive to the original input
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		GPUObject ATobj = (GPUObject) Aobj.clone();
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		ATobj.denseRowMajorToColumnMajor();
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);
+		Pointer A = ATobj.getJcudaDenseMatrixPtr();

-			// step 4: compute QR factorization
-			Pointer work = gCtx.allocate(instName, lwork[0] * Sizeof.DOUBLE);
-			Pointer tau = gCtx.allocate(instName, m * Sizeof.DOUBLE);
-			Pointer devInfo = gCtx.allocate(Sizeof.INT);
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR, System.nanoTime() - t0);
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		GPUObject bTobj = (GPUObject) bobj.clone();
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_OBJECT_CLONE, System.nanoTime() - t0);
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		bTobj.denseRowMajorToColumnMajor();
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ROW_TO_COLUMN_MAJOR, System.nanoTime() - t0);

-			int[] qrError = {-1};
-			cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
-			if (qrError[0] != 0) {
-				throw new DMLRuntimeException("GPU : Error in call to geqrf (QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
-			}
+		Pointer b = bTobj.getJcudaDenseMatrixPtr();

-			// step 5: compute Q^T*B
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ORMQR, System.nanoTime() - t0);
-			cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
-			if (qrError[0] != 0) {
-				throw new DMLRuntimeException("GPU : Error in call to ormqr (to compuete Q^T*B after QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
-			}
+		// The following set of operations is done following the example in the cusolver documentation
+		// http://docs.nvidia.com/cuda/cusolver/#ormqr-example1

-			// step 6: compute x = R \ Q^T*B
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
-				cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT,
-				n, 1, pointerTo(1.0), A, m, b, m);
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_TRSM, System.nanoTime() - t0);
-
-			if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
-			bTobj.denseColumnMajorToRowMajor();
-			if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_COLUMN_TO_ROW_MAJOR, System.nanoTime() - t0);
+		// step 3: query working space of geqrf and ormqr
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		int[] lwork = {0};
+		JCusolverDn.cusolverDnDgeqrf_bufferSize(gCtx.getCusolverDnHandle(), m, n, A, m, lwork);
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR_BUFFER, System.nanoTime() - t0);
+
+		// step 4: compute QR factorization
+		Pointer work = gCtx.allocate(instName, lwork[0] * Sizeof.DOUBLE);
+		Pointer tau = gCtx.allocate(instName, m * Sizeof.DOUBLE);
+		Pointer devInfo = gCtx.allocate(Sizeof.INT);
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		JCusolverDn.cusolverDnDgeqrf(gCtx.getCusolverDnHandle(), m, n, A, m, tau, work, lwork[0], devInfo);
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_QR, System.nanoTime() - t0);

-			// TODO : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash
-			// There is an avoidable copy happening here
-			MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in1.getNumColumns(), 1);
-			cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+		int[] qrError = {-1};
+		cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
+		if (qrError[0] != 0) {
+			throw new DMLRuntimeException("GPU : Error in call to geqrf (QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
+		}

-			gCtx.cudaFreeHelper(instName, work);
-			gCtx.cudaFreeHelper(instName, tau);
-			ATobj.clearData();
-			bTobj.clearData();
+		// step 5: compute Q^T*B
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		JCusolverDn.cusolverDnDormqr(gCtx.getCusolverDnHandle(), cublasSideMode.CUBLAS_SIDE_LEFT, cublasOperation.CUBLAS_OP_T, m, 1, n, A, m, tau, b, m, work, lwork[0], devInfo);
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_ORMQR, System.nanoTime() - t0);
+		cudaMemcpy(Pointer.to(qrError), devInfo, Sizeof.INT, cudaMemcpyDeviceToHost);
+		if (qrError[0] != 0) {
+			throw new DMLRuntimeException("GPU : Error in call to ormqr (to compute Q^T*B after QR factorization) as part of solve, argument " + qrError[0] + " was wrong");
+		}

-			//debugPrintMatrix(b, n, 1);
+		// step 6: compute x = R \ Q^T*B
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		JCublas2.cublasDtrsm(gCtx.getCublasHandle(),
+			cublasSideMode.CUBLAS_SIDE_LEFT, cublasFillMode.CUBLAS_FILL_MODE_UPPER, cublasOperation.CUBLAS_OP_N, cublasDiagType.CUBLAS_DIAG_NON_UNIT,
+			n, 1, pointerTo(1.0), A, m, b, m);
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_TRSM, System.nanoTime() - t0);
+
+		if (GPUStatistics.DISPLAY_STATISTICS) t0 = System.nanoTime();
+		bTobj.denseColumnMajorToRowMajor();
+		if (GPUStatistics.DISPLAY_STATISTICS) GPUStatistics.maintainCPMiscTimes(instName, GPUInstruction.MISC_TIMER_COLUMN_TO_ROW_MAJOR, System.nanoTime() - t0);

-		} else if (isInSparseFormat(gCtx, in1) && isInSparseFormat(gCtx, in2)) { // Both sparse
-			throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
-		} else if (!isInSparseFormat(gCtx, in1) && isInSparseFormat(gCtx, in2)) { // A is dense, b is sparse
-			// Pointer A = getDensePointer(gCtx, in1, instName);
-			// Pointer B = getDensePointer(gCtx, in2, instName);
-			throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
-		} else if (isInSparseFormat(gCtx, in1) && !isInSparseFormat(gCtx, in2)) { // A is sparse, b is dense
-			throw new DMLRuntimeException("GPU : solve on sparse inputs not supported");
-		}
+		// TODO : Find a way to assign bTobj directly to the output and set the correct flags so as to not crash
+		// There is an avoidable copy happening here
+		MatrixObject out = getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, in1.getNumColumns(), 1);
+		cudaMemcpy(out.getGPUObject(gCtx).getJcudaDenseMatrixPtr(), bTobj.getJcudaDenseMatrixPtr(), n * 1 * Sizeof.DOUBLE, cudaMemcpyDeviceToDevice);
+		gCtx.cudaFreeHelper(instName, work);
+		gCtx.cudaFreeHelper(instName, tau);
+		ATobj.clearData();
+		bTobj.clearData();
-	}
+		//debugPrintMatrix(b, n, 1);
+	}

 	//********************************************************************/
 	//*****************  END OF Builtin Functions ************************/
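For reference, steps 4-6 of the rewritten solve() implement least squares via QR factorization; the math below is the standard derivation (with Q the square orthogonal factor of the full factorization), summarized here rather than taken from the patch:

    % geqrf (step 4): factor A = Q R, R upper triangular
    % ormqr (step 5): form c = Q^T b without materializing Q
    % trsm  (step 6): back-substitute the triangular system R x = c
    \begin{aligned}
    A &= Q R, \qquad Q^T Q = I \\
    \min_x \|A x - b\|_2 &= \min_x \|Q R x - b\|_2 = \min_x \|R x - Q^T b\|_2 \\
    R x &= Q^T b \;\Rightarrow\; x = R^{-1} Q^T b
    \end{aligned}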
