Repository: incubator-systemml
Updated Branches:
  refs/heads/master 02ff20045 -> c9b6f02e7


[SYSTEMML-445] Fixed incorrect sparsity handling in GPU backend

- Fixed incorrect sparsity handling
- Fixed memory usage counters
- Added a GPU eviction count to Statistics for better profiling
- Added logic to automatically select the convolution algorithm for conv2d
  (disabled for now).
- Avoided unnecessary Hop creation
- Fixed a NullPointerException caused by recent refactoring, and avoided redundant
  Lop creation

After this fix (and after reducing `SPARSITY_TURN_POINT` to `0.0001`), the LeNet
script runs successfully on the GPU backend without the singlenode flag.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/c9b6f02e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/c9b6f02e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/c9b6f02e

Branch: refs/heads/master
Commit: c9b6f02e733d7def53fa7b5bc281d094acc8a76d
Parents: 02ff200
Author: Niketan Pansare <[email protected]>
Authored: Wed Jun 22 18:03:38 2016 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Wed Jun 22 18:03:38 2016 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 107 +++++++++++++++++--
 .../controlprogram/context/GPUObject.java       |   3 +
 .../controlprogram/context/JCudaContext.java    |  12 ++-
 .../controlprogram/context/JCudaObject.java     |  92 +++++++++-------
 .../runtime/matrix/data/LibMatrixCUDA.java      |  59 ++++++----
 .../sysml/runtime/util/ConvolutionUtils.java    |  52 ++++++---
 .../java/org/apache/sysml/utils/Explain.java    |   3 +-
 .../java/org/apache/sysml/utils/Statistics.java |  18 +++-
 8 files changed, 260 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 07a45b6..28b1f36 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -94,10 +94,12 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                
                Lop ret = ConvolutionUtils.constructConvolutionLops(this, et);
                if(ret != null) {
+                       setLops(ret);
                        return ret;
                }
                ret = 
ConvolutionUtils.constructConvolutionBackwardDataLops(this, et);
                if(ret != null) {
+                       setLops(ret);
                        return ret;
                }
                
@@ -155,6 +157,49 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                this.op = op;
        }
        
+       public static Lop constructFusedConvolutionLops(ExecType et, 
+                       ArrayList<Hop> inputs, 
+                       ConvOp op, ConvolutionOp primaryOp,
+                       long rlen, long clen) throws HopsException, 
LopsException {
+               int expectedNumInputs = 13;
+               if(op == ConvOp.MAX_POOLING_BACKWARD 
+                               || op == ConvOp.DIRECT_CONV2D 
+                               || op == ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
+                               || op == ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
+                       expectedNumInputs = 14;
+               }
+               
+               if(inputs.size() != expectedNumInputs) {
+                       throw new HopsException("Incorrect number of inputs for 
" + op.name());
+               }
+               
+               Lop in = inputs.get(0).constructLops();
+               ConvolutionTransform transform1 = new ConvolutionTransform( in, 
+                               HopsConv2Lops.get(op), primaryOp.getDataType(), 
primaryOp.getValueType(), et, 1);
+               
+               // setOutputDimensions(transform1);
+               transform1.getOutputParameters().setDimensions(
+                               rlen, clen, primaryOp.getRowsInBlock(), 
primaryOp.getColsInBlock(), -1, primaryOp.getUpdateType());
+               
+               // setLineNumbers(transform1);
+               transform1.setAllPositions(primaryOp.getBeginLine(), 
primaryOp.getBeginColumn(), primaryOp.getEndLine(), primaryOp.getEndColumn());
+               
+               in.addOutput(transform1);
+               
+               // stride1, stride2, padding1, padding2  
+               // input_shape1, input_shape2, input_shape3, input_shape4, 
+               // filter_shape1, filter_shape2, filter_shape3, filter_shape4
+               for( int i=1; i < inputs.size(); i++ )
+               {
+                       Lop ltmp = inputs.get(i).constructLops();
+                       transform1.addInput(ltmp);
+                       //if(i == 1 && expectedNumInputs == 14)
+                               ltmp.addOutput(transform1);
+               }
+               transform1.setLevel(); //force order of added lops
+               return transform1;
+       }
+       
        public Lop constructConvolutionLops(ExecType et, ArrayList<Hop> inputs) 
throws HopsException, LopsException {
                int expectedNumInputs = 13;
                if(op == ConvOp.MAX_POOLING_BACKWARD 
@@ -257,13 +302,63 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                                break;
                        }
                        case IM2COL:
-                       case COL2IM: 
-                       case MAX_POOLING: 
+                       {
+                               ret = new long[3];
+                               ret[0] = getExtractedVal(params.C, params.R, 
params.S);
+                               ret[1] = getExtractedVal(params.N, params.P, 
params.Q);
+                               ret[2] = -1;
+                               break;
+                       }
+                       case COL2IM:
+                       {
+                               ret = new long[3];
+                               ret[0] = params.N;
+                               ret[1] = getExtractedVal(params.C, params.H, 
params.W);
+                               ret[2] = -1;
+                               break;
+                       }
+                       case MAX_POOLING:
+                       {
+                               ret = new long[3];
+                               ret[0] = params.N;
+                               ret[1] = getExtractedVal(params.C, params.P, 
params.Q);
+                               ret[2] = -1;
+                               break;
+                       }
                        case MAX_POOLING_BACKWARD:
-                       case DIRECT_CONV2D: 
-                       case DIRECT_CONV2D_BACKWARD_FILTER: 
+                       {
+                               ret = new long[3];
+                               ret[0] = params.N;
+                               ret[1] = getExtractedVal(params.C, params.H, 
params.W);
+                               ret[2] = -1;
+                               break;
+                       }
+                       case DIRECT_CONV2D:
+                       {
+                               ret = new long[3];
+                               ret[0] = params.N;
+                               ret[1] = getExtractedVal(params.K, params.P, 
params.Q);
+                               ret[2] = -1;
+                               break;
+                       }
+                       case DIRECT_CONV2D_BACKWARD_FILTER:
+                       {
+                               ret = new long[3];
+                               ret[0] = params.K;
+                               ret[1] = getExtractedVal(params.C, params.R, 
params.S);
+                               ret[2] = -1;
+                               break;
+                       }
                        case DIRECT_CONV2D_BACKWARD_DATA:
+                       {
+                               ret = new long[3];
+                               ret[0] = params.N;
+                               ret[1] = getExtractedVal(params.C, params.H, 
params.W);
+                               ret[2] = -1;
                                break;
+                       }
+                       default:
+                               throw new RuntimeException("Unsupported op:" + 
op.name());
                }
                
                return ret;
@@ -357,7 +452,7 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                return val1*val2;
        }
        
-       long getExtractedVal(long val1, long val2, long val3) {
+       public static long getExtractedVal(long val1, long val2, long val3) {
                if(val1 == -1 || val2 == -1 || val3 == -1) {
                        return -1;
                }
@@ -447,7 +542,7 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                }
        }
        
-       private long extractValue(Hop hop)  {
+       public static long extractValue(Hop hop)  {
                if(hop instanceof LiteralOp)
                        return (long) 
HopRewriteUtils.getDoubleValueSafe((LiteralOp)hop);
                return -1;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/runtime/controlprogram/context/GPUObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/GPUObject.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/GPUObject.java
index c175c9e..8037b8a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/GPUObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/GPUObject.java
@@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.CacheException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.utils.Statistics;
 
 //FIXME merge JCudaObject into GPUObject to avoid unnecessary complexity
 //FIXME move to gpu instruction package
@@ -74,6 +75,8 @@ public abstract class GPUObject
                        throw new DMLRuntimeException("There is not enough 
memory on device for this matrix!");
                }
                
+               Statistics.cudaEvictionCount.addAndGet(1);
+               
                synchronized(evictionLock) {
                        Collections.sort(GPUContext.allocatedPointers, new 
Comparator<GPUObject>() {
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaContext.java
index 6c8f244..bfb823e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaContext.java
@@ -18,6 +18,8 @@
  */
 package org.apache.sysml.runtime.controlprogram.context;
 
+import java.util.concurrent.atomic.AtomicLong;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.runtime.DMLRuntimeException;
@@ -53,7 +55,7 @@ public class JCudaContext extends GPUContext {
        public static boolean DEBUG = false;
        
        public static long totalNumBytes = 0;
-       public static long availableNumBytesWithoutUtilFactor = 0;
+       public static AtomicLong availableNumBytesWithoutUtilFactor = new 
AtomicLong(0);
        // Fraction of available memory to use. The available memory is 
computer when the JCudaContext is created
        // to handle the tradeoff on calling cudaMemGetInfo too often. 
        public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9; 
@@ -80,13 +82,13 @@ public class JCudaContext extends GPUContext {
                long total [] = { 0 };
                if(cudaMemGetInfo(free, total) == cudaSuccess) {
                        totalNumBytes = total[0];
-                       availableNumBytesWithoutUtilFactor = free[0];
+                       availableNumBytesWithoutUtilFactor.set(free[0]);
                }
                else {
                        throw new RuntimeException("ERROR: Unable to get memory 
information of the GPU.");
                }
                }
-               return (long) 
(availableNumBytesWithoutUtilFactor*GPU_MEMORY_UTILIZATION_FACTOR);
+               return (long) 
(availableNumBytesWithoutUtilFactor.get()*GPU_MEMORY_UTILIZATION_FACTOR);
        }
        
        
@@ -119,13 +121,13 @@ public class JCudaContext extends GPUContext {
         long total [] = { 0 };
         if(cudaMemGetInfo(free, total) == cudaSuccess) {
                totalNumBytes = total[0];
-               availableNumBytesWithoutUtilFactor = free[0];
+               availableNumBytesWithoutUtilFactor.set(free[0]);
         }
         else {
                throw new RuntimeException("ERROR: Unable to get memory 
information of the GPU.");
         }
         LOG.info("Total GPU memory: " + (totalNumBytes*(1e-6)) + " MB");
-        LOG.info("Available GPU memory: " + 
(availableNumBytesWithoutUtilFactor*(1e-6)) + " MB");
+        LOG.info("Available GPU memory: " + 
(availableNumBytesWithoutUtilFactor.get()*(1e-6)) + " MB");
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaObject.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaObject.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaObject.java
index 796c1e9..5d37909 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaObject.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/JCudaObject.java
@@ -36,6 +36,7 @@ import org.apache.sysml.utils.Statistics;
 public class JCudaObject extends GPUObject {
        
        public Pointer jcudaPointer = null;
+       public long numBytes;
 
        JCudaObject(MatrixObject mat2) {
                super(mat2);
@@ -44,24 +45,22 @@ public class JCudaObject extends GPUObject {
        private void prepare(boolean isInput, int numElemsToAllocate) throws 
DMLRuntimeException {
                if(jcudaPointer != null) {
                        // Already allocated on GPU and expected to be in sync
-                       // checkDimensions();
                }
                else {
-                       long GPUSize;
-                       if(numElemsToAllocate != -1)
-                               GPUSize = (Sizeof.DOUBLE) * (long) 
(numElemsToAllocate);
-                       else
-                               GPUSize = getSizeOnDevice();
-                       // Ensure enough memory while allocating the matrix
-                       if(GPUSize > getAvailableMemory()) {
-                               evict(GPUSize);
+                       if(isInput) {
+                               if(numElemsToAllocate != -1)
+                                       throw new DMLRuntimeException("Expected 
numElemsToAllocate to be -1 as it is inferred from the input");
+                               // Copy performs allocation
+                               copyFromHostToDevice();
                        }
-                       allocateMemoryOnDevice(numElemsToAllocate);
-                       synchronized(evictionLock) {
-                               GPUContext.allocatedPointers.add(this);
+                       else {
+                               // Don't copy just allocate
+                               
ensureFreeSpaceForDenseBlock(numElemsToAllocate);
+                               allocateMemoryOnDevice(numElemsToAllocate);
+                               synchronized(evictionLock) {
+                                       GPUContext.allocatedPointers.add(this);
+                               }
                        }
-                       if(isInput)
-                               copyFromHostToDevice();
                }
                numLocks.addAndGet(1);
        }
@@ -73,16 +72,6 @@ public class JCudaObject extends GPUObject {
                        throw new DMLRuntimeException("Expected device data to 
be allocated");
        }
        
-//     private void checkDimensions() throws DMLRuntimeException {
-//             if(LibMatrixCUDA.isInSparseFormat(mat))
-//                     throw new DMLRuntimeException("Sparse format not 
implemented");
-//             else {
-//                     if(mat.getNumRows()*mat.getNumColumns() != numElems) {
-//                             throw new DMLRuntimeException("The jcudaPointer 
and MatrixBlock is not in synched");
-//                     }
-//             }
-//     }
-       
        @Override
        public void acquireDenseDeviceModify(int numElemsToAllocate) throws 
DMLRuntimeException {
                prepare(false, numElemsToAllocate); 
@@ -135,15 +124,20 @@ public class JCudaObject extends GPUObject {
                                throw new DMLRuntimeException("Sparse format 
not implemented");
                        else if(numElemToAllocate == -1) {
                                // Called for dense input
-                               cudaMalloc(jcudaPointer,  
mat.getNumRows()*mat.getNumColumns()*Sizeof.DOUBLE);
+                               numBytes = 
mat.getNumRows()*mat.getNumColumns()*Sizeof.DOUBLE;
+                               cudaMalloc(jcudaPointer, numBytes);
+                               
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
                        }
                        else {
                                // Called for dense output
-                               cudaMalloc(jcudaPointer,  
numElemToAllocate*Sizeof.DOUBLE);
+                               numBytes = numElemToAllocate*Sizeof.DOUBLE;
+                               cudaMalloc(jcudaPointer,  numBytes);
+                               
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
                        }
                        
                        
Statistics.cudaAllocTime.addAndGet(System.nanoTime()-start);
                        Statistics.cudaAllocCount.addAndGet(1);
+
                }
                isAllocated = true;
        }
@@ -153,37 +147,53 @@ public class JCudaObject extends GPUObject {
                if(jcudaPointer != null) {
                        long start = System.nanoTime();
                        cudaFree(jcudaPointer);
+                       
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(numBytes);
                        
Statistics.cudaDeAllocTime.addAndGet(System.nanoTime()-start);
                        Statistics.cudaDeAllocCount.addAndGet(1);
+                       
                }
                jcudaPointer = null;
                isAllocated = false;
                numLocks.set(0);
        }
        
+       void ensureFreeSpaceForDenseBlock(int numElem) throws 
DMLRuntimeException {
+               long GPUSize = (Sizeof.DOUBLE) * numElem;
+               if(GPUSize >= getAvailableMemory()) {
+                       evict(GPUSize);
+               }
+       }
+       
        @Override
        void copyFromHostToDevice() 
                throws DMLRuntimeException 
        {
-               if( jcudaPointer == null )
-                       throw new DMLRuntimeException("Cannot copy from host to 
device without allocating");
-               if( LibMatrixCUDA.isInSparseFormat(mat) )
-                       throw new DMLRuntimeException("Sparse format not 
implemented");
-               
                printCaller();
                long start = System.nanoTime();
                
                MatrixBlock tmp = mat.acquireRead();
-               double[] data = tmp.getDenseBlock();
-               
-               if( data == null && tmp.getSparseBlock() != null )
-                       throw new DMLRuntimeException("Incorrect sparsity 
calculation");
-               else if( data==null && tmp.getNonZeros() != 0 )
-                       throw new DMLRuntimeException("MatrixBlock is not 
allocated");
-               else if( tmp.getNonZeros() == 0 )
-                       data = new double[tmp.getNumRows()*tmp.getNumColumns()];
-               
-               cudaMemcpy(jcudaPointer, Pointer.to(data), 
mat.getNumRows()*mat.getNumColumns() * Sizeof.DOUBLE, cudaMemcpyHostToDevice);
+               if(tmp.isInSparseFormat()) {
+                       throw new DMLRuntimeException("Sparse matrix is not 
implemented");
+                       // tmp.sparseToDense();
+               }
+               else {
+                       double[] data = tmp.getDenseBlock();
+                       
+                       if( data == null && tmp.getSparseBlock() != null )
+                               throw new DMLRuntimeException("Incorrect 
sparsity calculation");
+                       else if( data==null && tmp.getNonZeros() != 0 )
+                               throw new DMLRuntimeException("MatrixBlock is 
not allocated");
+                       else if( tmp.getNonZeros() == 0 )
+                               data = new 
double[tmp.getNumRows()*tmp.getNumColumns()];
+                       
+                       // Copy dense block
+                       ensureFreeSpaceForDenseBlock(data.length);
+                       allocateMemoryOnDevice(data.length);
+                       synchronized(evictionLock) {
+                               GPUContext.allocatedPointers.add(this);
+                       }
+                       cudaMemcpy(jcudaPointer, Pointer.to(data), 
mat.getNumRows()*mat.getNumColumns() * Sizeof.DOUBLE, cudaMemcpyHostToDevice);
+               }
                
                mat.release();
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index ad5110f..45f68dd 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -31,12 +31,13 @@ import static 
jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize;
 import static 
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize;
-import static jcuda.jcudnn.JCudnn.cudnnSetConvolutionNdDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetFilterNdDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
 import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
 import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+import jcuda.jcudnn.cudnnConvolutionFwdPreference;
 import static jcuda.runtime.JCuda.cudaFree;
 import jcuda.Pointer;
 import jcuda.jcublas.JCublas;
@@ -56,6 +57,8 @@ public class LibMatrixCUDA {
        public static cudnnHandle cudnnHandle;
        public static cublasHandle cublasHandle;
        
+       private static int CONVOLUTION_PREFERENCE = 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+       
        public static void conv2d(MatrixObject image, MatrixObject filter, 
MatrixObject outputBlock, int N, int C, int H, int W,
                        int K, int R, int S, int pad_h, int pad_w, int 
stride_h, int stride_w, int P, int Q)
                        throws DMLRuntimeException {
@@ -85,16 +88,38 @@ public class LibMatrixCUDA {
                        int strides [] = { stride_h, stride_w };
                        convDesc = allocateConvolutionDescriptor(padding, 
strides);
                        
-                       // TODO: Select the best algorithm depending on the 
data and supported CUDA
-                       int algo = 
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+                       // Select the best algorithm depending on the data and 
supported CUDA
                        
-                       long sizeInBytesArray[] = { 0 };
-            workSpace = new Pointer();
-            cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle, 
-                    srcTensorDesc, filterDesc, convDesc, dstTensorDesc, 
-                    algo, sizeInBytesArray);
-            
-                       alpha = pointerTo(1.0); // TODO
+                       int algo = -1; 
+                       workSpace = new Pointer();
+                       
+                       if(CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE) {
+                               algo = 
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+                       }
+                       else if(CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_PREFER_FASTEST) {
+                               int [] algos = {
+                               
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+                               
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+                               
jcuda.jcudnn.cudnnConvolutionFwdAlgo.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
+                   };
+                               // TODO: Look into FFt, Winograd, etc
+                               // Also ensure that GPU has enough memory to 
allocate memory
+                               long sizeInBytesArray[] = { 0 };
+                   algo = 
jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardAlgorithm(cudnnHandle, 
srcTensorDesc, filterDesc, convDesc, dstTensorDesc,
+                               CONVOLUTION_PREFERENCE, sizeInBytesArray[0], 
algos);
+                   cudnnGetConvolutionForwardWorkspaceSize(cudnnHandle, 
srcTensorDesc, filterDesc, convDesc, dstTensorDesc, algo, sizeInBytesArray);
+                   if(sizeInBytesArray[0] != 0)
+                       jcuda.runtime.JCuda.cudaMalloc(workSpace, 
sizeInBytesArray[0]);
+                   sizeInBytes = sizeInBytesArray[0];
+                       }
+                       else if(CONVOLUTION_PREFERENCE == 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT) {
+                               throw new 
DMLRuntimeException("CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT is not 
implemented");
+                       }
+                       else {
+                               throw new DMLRuntimeException("Unsupported 
preference criteria for convolution");
+                       }
+                       
+                       alpha = pointerTo(1.0);
                        beta = pointerTo(0.0f);
                        int status = cudnnConvolutionForward(cudnnHandle, 
alpha, 
                                        srcTensorDesc, imagePointer, 
@@ -123,14 +148,12 @@ public class LibMatrixCUDA {
                        if(workSpace != null && sizeInBytes != 0)
                                cudaFree(workSpace);
                }
-       }
+       }       
        
        private static cudnnConvolutionDescriptor 
allocateConvolutionDescriptor(int padding [], int strides []) {
                cudnnConvolutionDescriptor convDesc = new 
cudnnConvolutionDescriptor();
                cudnnCreateConvolutionDescriptor(convDesc);
-               int upscale[] = { 1, 1 };
-               cudnnSetConvolutionNdDescriptor(convDesc, 2, padding, strides, 
upscale, 
-                               CUDNN_CROSS_CORRELATION, CUDNN_DATA_DOUBLE);
+               cudnnSetConvolution2dDescriptor(convDesc, padding[0], 
padding[1], strides[0], strides[1], 1, 1, CUDNN_CROSS_CORRELATION);             
  
                return convDesc;
        }
        
@@ -148,12 +171,10 @@ public class LibMatrixCUDA {
        private static cudnnFilterDescriptor allocateFilterDescriptor(int K, 
int C, int R, int S) {
                cudnnFilterDescriptor filterDesc = new cudnnFilterDescriptor();
                cudnnCreateFilterDescriptor(filterDesc);
-               int filterDim[] = { K, C, R, S };
-               cudnnSetFilterNdDescriptor(filterDesc, CUDNN_DATA_DOUBLE, 4, 
filterDim);
+               cudnnSetFilter4dDescriptor(filterDesc, CUDNN_DATA_DOUBLE, K, C, 
R, S);
                return filterDesc;
        }
-
-
+       
        public static void conv2d_backward_filter(MatrixObject image, 
MatrixObject dout,
                        MatrixObject outputBlock, int N, int C, int H, int W, 
int K, int R,
                        int S, int pad_h, int pad_w, int stride_h, int 
stride_w, int P,

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java 
b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
index b2da5e2..c6ca53a 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ConvolutionUtils.java
@@ -94,9 +94,15 @@ public class ConvolutionUtils {
                                        for(int i = 1; i < 
x_col.getInput().size(); i++) {
                                                
inputs.add(x_col.getInput().get(i));
                                        }
-                                       ConvolutionOp fusedHop = new 
ConvolutionOp("tmp_directconv2dBackwardFilter" + image.getName(), 
image.getDataType(), image.getValueType(), 
ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, inputs);
-                                       setPositions(currentHop, fusedHop);
-                                       return 
fusedHop.constructConvolutionLops(et, inputs);
+                                       
+                                       // K, C * R * S
+                                       long K = 
ConvolutionOp.extractValue(inputs.get(10));
+                                       long C = 
ConvolutionOp.extractValue(inputs.get(7));
+                                       long R = 
ConvolutionOp.extractValue(inputs.get(12));
+                                       long S = 
ConvolutionOp.extractValue(inputs.get(13));
+                                       long rlen = K;
+                                       long clen = 
ConvolutionOp.getExtractedVal(C, R, S);
+                                       return 
ConvolutionOp.constructFusedConvolutionLops(et, inputs, 
ConvOp.DIRECT_CONV2D_BACKWARD_FILTER, (ConvolutionOp) x_col, rlen, clen);
                                }
                        }
                }
@@ -122,9 +128,26 @@ public class ConvolutionUtils {
                                        for(int i = 1; i < 
x_col.getInput().size(); i++) {
                                                
inputs.add(x_col.getInput().get(i));
                                        }
-                                       ConvolutionOp fusedHop = new 
ConvolutionOp("tmp_directconv2d" + image.getName(), image.getDataType(), 
image.getValueType(), ConvOp.DIRECT_CONV2D, inputs);
-                                       setPositions(currentHop, fusedHop);
-                                       return 
fusedHop.constructConvolutionLops(et, inputs);
+                                       
+                                       // N, K * P * Q
+                                       long N = 
ConvolutionOp.extractValue(inputs.get(6));
+                                       long H = 
ConvolutionOp.extractValue(inputs.get(8));
+                                       long W = 
ConvolutionOp.extractValue(inputs.get(9));
+                                       long K = 
ConvolutionOp.extractValue(inputs.get(10));
+                                       long R = 
ConvolutionOp.extractValue(inputs.get(12));
+                                       long S = 
ConvolutionOp.extractValue(inputs.get(13));
+                                       long stride_h = 
ConvolutionOp.extractValue(inputs.get(2));
+                                       long stride_w = 
ConvolutionOp.extractValue(inputs.get(3));
+                                       long pad_h = 
ConvolutionOp.extractValue(inputs.get(4));
+                                       long pad_w = 
ConvolutionOp.extractValue(inputs.get(5));
+                                       long P = -1; long Q = -1;
+                                       if(H > 0 && R > 0 && stride_h > 0 && 
pad_h > 0)
+                                               P = ConvolutionUtils.getP(H, R, 
stride_h, pad_h);
+                                       if(W > 0 && S > 0 && stride_w > 0 && 
pad_w > 0)
+                                               Q = ConvolutionUtils.getQ(W, S, 
stride_w, pad_w);
+                                       long rlen = N;
+                                       long clen = 
ConvolutionOp.getExtractedVal(K, P, Q);
+                                       return 
ConvolutionOp.constructFusedConvolutionLops(et, inputs, ConvOp.DIRECT_CONV2D, 
(ConvolutionOp) x_col, rlen, clen);
                                }
                        }
                }
@@ -152,9 +175,17 @@ public class ConvolutionUtils {
                                                for(int i = 1; i < 
rotate180.getInput().size(); i++) {
                                                        
inputs.add(rotate180.getInput().get(i));
                                                }
-                                               ConvolutionOp fusedHop = new 
ConvolutionOp("tmp_directconv2dBackwardData" + filter.getName(), 
filter.getDataType(), filter.getValueType(), 
ConvOp.DIRECT_CONV2D_BACKWARD_DATA, inputs);
-                                               setPositions(currentHop, 
fusedHop);
-                                               return 
fusedHop.constructConvolutionLops(et, inputs);
+                                               
+                                               // N, C * H * W
+                                               long N = 
ConvolutionOp.extractValue(inputs.get(6));
+                                               long C = 
ConvolutionOp.extractValue(inputs.get(7));
+                                               long H = 
ConvolutionOp.extractValue(inputs.get(8));
+                                               long W = 
ConvolutionOp.extractValue(inputs.get(9));
+                                               long rlen = N;
+                                               long clen = 
ConvolutionOp.getExtractedVal(C, H, W);
+                                               return 
ConvolutionOp.constructFusedConvolutionLops(et, inputs, 
ConvOp.DIRECT_CONV2D_BACKWARD_DATA, (ConvolutionOp) rotate180, rlen, clen);
+                                               
+                                               
                                        }
                                }
                        }
@@ -163,8 +194,5 @@ public class ConvolutionUtils {
                return null;
        }
        
-       private static void setPositions(Hop currentHop, Hop fusedHop) {
-               fusedHop.setAllPositions(currentHop.getBeginLine(), 
currentHop.getBeginColumn(), currentHop.getEndLine(), 
currentHop.getEndColumn());
-       }
        
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java 
b/src/main/java/org/apache/sysml/utils/Explain.java
index 3a24786..8ee2822 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -63,6 +63,7 @@ import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.instructions.MRJobInstruction;
 import org.apache.sysml.runtime.instructions.cp.CPInstruction;
+import org.apache.sysml.runtime.instructions.gpu.GPUInstruction;
 import org.apache.sysml.runtime.instructions.spark.CSVReblockSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.ReblockSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.SPInstruction;
@@ -934,7 +935,7 @@ public class Explain
                String tmp = null;
                if( inst instanceof MRJobInstruction )
                        tmp = explainMRJobInstruction((MRJobInstruction)inst, 
level+1);
-               else if ( inst instanceof SPInstruction || inst instanceof 
CPInstruction)
+               else if ( inst instanceof SPInstruction || inst instanceof 
CPInstruction || inst instanceof GPUInstruction)
                        tmp = inst.toString();
                
                if( REPLACE_SPECIAL_CHARACTERS ){

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/c9b6f02e/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 1ecc62d..3af233e 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -114,6 +114,7 @@ public class Statistics
        public static AtomicLong cudaDeAllocCount = new AtomicLong(0);
        public static AtomicLong cudaToDevCount = new AtomicLong(0);
        public static AtomicLong cudaFromDevCount = new AtomicLong(0);
+       public static AtomicLong cudaEvictionCount = new AtomicLong(0);
        
        public static void incrementAllocationTime(long allocationTime, boolean 
isSparse) {
                if(isSparse)
@@ -378,6 +379,18 @@ public class Statistics
                
                denseBlockAllocationTime.set(0);
                sparseBlockAllocationTime.set(0);
+               
+               cudaInitTime = 0;
+               cudaLibrariesInitTime = 0;
+               cudaAllocTime.set(0);
+               cudaDeAllocTime.set(0);
+               cudaToDevTime.set(0);
+               cudaFromDevTime.set(0);
+               cudaAllocCount.set(0);
+               cudaDeAllocCount.set(0);
+               cudaToDevCount.set(0);
+               cudaFromDevCount.set(0);
+               cudaEvictionCount.set(0);
        }
        
        /**
@@ -632,11 +645,12 @@ public class Statistics
                                        + String.format("%.3f", 
cudaDeAllocTime.get()*1e-9) + "/"
                                        + String.format("%.3f", 
cudaToDevTime.get()*1e-9) + "/"
                                        + String.format("%.3f", 
cudaFromDevTime.get()*1e-9)  + " sec.\n");
-                       sb.append("GPU mem tx count 
(alloc/dealloc/toDev/fromDev):\t" 
+                       sb.append("GPU mem tx count 
(alloc/dealloc/toDev/fromDev/evict):\t" 
                                        + cudaAllocCount.get() + "/"
                                        + cudaDeAllocCount.get() + "/"
                                        + cudaToDevCount.get() + "/"
-                                       + cudaFromDevCount.get()  + ".\n");
+                                       + cudaFromDevCount.get() + "/"
+                                       + cudaEvictionCount.get() + ".\n");
                }
                
                //show extended caching/compilation statistics

Reply via email to