[SYSTEMML-701][SYSTEMML-702] Sparse matrix multiplication for GPU

Closes #196.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/58a95460
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/58a95460
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/58a95460

Branch: refs/heads/master
Commit: 58a954609c840f7101d0cbd700ed160d63ef1b6b
Parents: 7c12992
Author: Nakul Jindal <[email protected]>
Authored: Mon Aug 22 16:46:52 2016 -0700
Committer: Deron Eriksson <[email protected]>
Committed: Mon Aug 22 16:46:52 2016 -0700

----------------------------------------------------------------------
 .../controlprogram/caching/MatrixObject.java    |  10 +
 .../context/ExecutionContext.java               |  42 +-
 .../gpu/AggregateBinaryGPUInstruction.java      |  33 +-
 .../gpu/ConvolutionGPUInstruction.java          |  10 +-
 .../instructions/gpu/MMTSJGPUInstruction.java   |   2 +-
 .../instructions/gpu/context/GPUContext.java    |  23 +-
 .../instructions/gpu/context/GPUObject.java     |  48 +-
 .../instructions/gpu/context/JCudaContext.java  |  14 +
 .../instructions/gpu/context/JCudaObject.java   | 568 +++++++++++++++++--
 .../runtime/matrix/data/LibMatrixCUDA.java      | 510 +++++++++++++++--
 .../sysml/runtime/matrix/data/MatrixBlock.java  |  17 +-
 .../runtime/matrix/data/SparseBlockCOO.java     |  27 +
 .../runtime/matrix/data/SparseBlockCSR.java     |  82 ++-
 .../runtime/matrix/data/SparseBlockMCSR.java    |   8 +
 .../java/org/apache/sysml/utils/Statistics.java |   8 +-
 15 files changed, 1214 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
index b18b9ee..fc5df81 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java
@@ -33,6 +33,8 @@ import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import 
org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData;
@@ -660,4 +662,12 @@ public class MatrixObject extends 
CacheableData<MatrixBlock>
                long newnnz = SparkExecutionContext.writeRDDtoHDFS(rdd, fname, 
oinfo);  
                ((MatrixDimensionsMetaData) 
_metaData).getMatrixCharacteristics().setNonZeros(newnnz);
        }
+       
+       /**
+        * Allocates the {@link GPUObject} which will be used
+        * to track the pointer on the GPU
+        */
+       public void allocateGPUObject(){
+               setGPUObject(GPUContext.createGPUObject(this));
+       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 70a5b4f..1ad1525 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -46,6 +46,7 @@ import org.apache.sysml.runtime.instructions.cp.IntObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
 import org.apache.sysml.runtime.instructions.cp.StringObject;
 import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData;
 import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
@@ -258,21 +259,46 @@ public class ExecutionContext
                                
((MatrixFormatMetaData)oldMetaData).getInputInfo()));
        }
        
-       public MatrixObject getMatrixOutputForGPUInstruction(String varName, 
boolean isSparse) 
+       public MatrixObject getDenseMatrixOutputForGPUInstruction(String 
varName) 
                throws DMLRuntimeException 
        {       
-               if(isSparse) {
-                       throw new DMLRuntimeException("Sparse matrix block is 
not supported for GPU instruction");
-               }
+               MatrixObject mo = allocateGPUMatrixObject(varName);
+               mo.getGPUObject().acquireDeviceModifyDense();
+               mo.getMatrixCharacteristics().setNonZeros(-1);
+               return mo;
+       }
+
+       /**
+        * Allocates a sparse matrix in CSR format on the GPU.
+        * Assumes that mat.getNumRows() returns a valid number
+        * @param varName
+        * @param nnz   number of non zeroes
+        * @return
+        * @throws DMLRuntimeException
+        */
+       public MatrixObject getSparseMatrixOutputForGPUInstruction(String 
varName, long nnz)
+               throws DMLRuntimeException
+       {
+               MatrixObject mo = allocateGPUMatrixObject(varName);
+               mo.getMatrixCharacteristics().setNonZeros(nnz);
+               mo.getGPUObject().acquireDeviceModifySparse();
+               return mo;
+       }
+       
+       /**
+        * Allocates the {@link GPUObject} for a given LOPS Variable (eg. 
_mVar3)
+        * @param varName
+        * @return
+        * @throws DMLRuntimeException
+        */
+       public MatrixObject allocateGPUMatrixObject(String varName) throws 
DMLRuntimeException {
                MatrixObject mo = getMatrixObject(varName);
                if( mo.getGPUObject() == null ) {
-                       mo.setGPUObject(GPUContext.createGPUObject(mo));
+                       mo.allocateGPUObject();
                }
-               
mo.getGPUObject().acquireDenseDeviceModify((int)(mo.getNumRows()*mo.getNumColumns()));
-               mo.getMatrixCharacteristics().setNonZeros(-1);
                return mo;
        }
-       
+
        public MatrixObject getMatrixInputForGPUInstruction(String varName) 
                        throws DMLRuntimeException 
        {       

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
index 9c413d0..5ed9b51 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
@@ -84,33 +84,8 @@ public class AggregateBinaryGPUInstruction extends 
GPUInstruction
        
        @Override
        public void processInstruction(ExecutionContext ec) 
-               throws DMLRuntimeException
-       {       
-               // --------------------------------------
-               // This code will be removed when the JIRA SYSTEMML-702 is 
complete
-               // FIXME this code does not adhere to compiler memory budgets
-               if(     isSparse(ec, _input1.getName()) || isSparse(ec, 
_input2.getName())) {
-                       //get inputs
-                       MatrixBlock matBlock1 = 
ec.getMatrixInput(_input1.getName());
-               MatrixBlock matBlock2 = ec.getMatrixInput(_input2.getName());
-               
-               if(_isLeftTransposed) 
-                       matBlock1 = transpose(matBlock1);
-               if(_isRightTransposed) 
-                       matBlock2 = transpose(matBlock2);
-                       
-               //compute matrix multiplication
-               AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
-                       MatrixBlock soresBlock = (MatrixBlock) 
(matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), 
ab_op));
-                               
-                       //release inputs/outputs
-                       ec.releaseMatrixInput(_input1.getName());
-                       ec.releaseMatrixInput(_input2.getName());
-                       ec.setMatrixOutput(_output.getName(), soresBlock);
-                       return;
-               }
-               // --------------------------------------
-               
+               throws DMLRuntimeException 
+       {
                Statistics.incrementNoOfExecutedGPUInst();
                
                AggregateBinaryOperator op = (AggregateBinaryOperator) _optr;
@@ -127,8 +102,8 @@ public class AggregateBinaryGPUInstruction extends 
GPUInstruction
         int clen = (int) (_isRightTransposed ? m2.getNumRows() : 
m2.getNumColumns());
         
         ec.setMetaData(_output.getName(), rlen, clen);
-        MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
-        LibMatrixCUDA.matmult(m1, m2, out, _isLeftTransposed, 
_isRightTransposed);
+        //MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+        MatrixObject out = LibMatrixCUDA.matmult(ec, m1, m2, 
_output.getName(), _isLeftTransposed, _isRightTransposed);
         
                //release inputs/outputs
                ec.releaseMatrixInputForGPUInstruction(_input1.getName());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
index e489f1c..4626f71 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ConvolutionGPUInstruction.java
@@ -181,7 +181,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction
                                throw new DMLRuntimeException("Incorrect 
dimensions for filter in conv2d");
                        
                        ec.setMetaData(_output.getName(), N, K * P * Q);
-                       MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                       MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                        LibMatrixCUDA.conv2d(image, filter, out, N, C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
                }
@@ -197,7 +197,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction
                                                dout.getNumRows() + " != " +  N 
+ " || " + dout.getNumColumns() + " != " + K*P*Q);
                        
                        ec.setMetaData(_output.getName(), K, C * R * S);
-                       MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                       MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                        LibMatrixCUDA.conv2d_backward_filter(image, dout, out, 
N, C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
                        // TODO: For now always copy the device data to host
@@ -215,7 +215,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction
                                                dout.getNumRows() + " != " +  N 
+ " || " + dout.getNumColumns() + " != " + K*P*Q);
                        
                        ec.setMetaData(_output.getName(), N, C * H * W);
-                       MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                       MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                        LibMatrixCUDA.conv2d_backward_data(filter, dout, out, 
N, C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
                }
@@ -228,7 +228,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction
                                                image.getNumRows() + " != " +  
N + " || " + image.getNumColumns() + " != " + C*H*W);
                        
                        ec.setMetaData(_output.getName(), N, C * P * Q);
-                       MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                       MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                        LibMatrixCUDA.maxpooling(image, out, N, C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
                }
@@ -244,7 +244,7 @@ public class ConvolutionGPUInstruction extends 
GPUInstruction
                                                image.getNumRows() + " != " +  
N + " || " + image.getNumColumns() + " != " + K*P*Q);
                        
                        ec.setMetaData(_output.getName(), N, C * H * W);
-                       MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                       MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                        LibMatrixCUDA.maxpooling_backward(image, dout, out, N, 
C, H, W,
                                        K, R, S, pad_h, pad_w, stride_h, 
stride_w, P, Q);
                }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
index 4709085..9ecb93c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MMTSJGPUInstruction.java
@@ -105,7 +105,7 @@ public class MMTSJGPUInstruction extends GPUInstruction
 
                 //execute operations 
                 ec.setMetaData(_output.getName(), rlen, clen);
-                MatrixObject out = 
ec.getMatrixOutputForGPUInstruction(_output.getName(), false);
+                MatrixObject out = 
ec.getDenseMatrixOutputForGPUInstruction(_output.getName());
                 LibMatrixCUDA.matmultTSMM(mat, out, isLeftTransposed);
                 
                 ec.releaseMatrixInputForGPUInstruction(_input.getName());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
index 0127154..480bf35 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUContext.java
@@ -36,24 +36,21 @@ public abstract class GPUContext {
        
        public abstract long getAvailableMemory();
        
-       // Creation / Destruction of GPUContext and related handles
+       /**
+        * Creation / Destruction of GPUContext and related handles
+        * @return
+        */
        public static GPUContext createGPUContext() {
                if(currContext == null && DMLScript.USE_ACCELERATOR) {
-                       // TODO: Handle this thread and resolve concurrency 
related bugs if any
-                       new Thread(new Runnable() {
-                               @Override
-                               public void run() {
-                                       // Lazy GPU context creation
-                                       synchronized(isGPUContextCreated) {
-                                               currContext = new 
JCudaContext();
-                                               
OptimizerUtils.GPU_MEMORY_BUDGET = 
((JCudaContext)currContext).getAvailableMemory();
-                                               isGPUContextCreated = true;
-                                       }
-                               }
-                       }).start();
+                       synchronized(isGPUContextCreated) {
+                               currContext = new JCudaContext();
+                               OptimizerUtils.GPU_MEMORY_BUDGET = 
((JCudaContext)currContext).getAvailableMemory();
+                               isGPUContextCreated = true;
+                       }
                }
                return currContext;
        }
+       
        public static GPUObject createGPUObject(MatrixObject mo) {
                if(DMLScript.USE_ACCELERATOR) {
                        synchronized(isGPUContextCreated) {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
index 33f7099..2a1350f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java
@@ -26,6 +26,7 @@ import java.util.concurrent.atomic.AtomicLong;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.CacheException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.utils.Statistics;
 
 //FIXME merge JCudaObject into GPUObject to avoid unnecessary complexity
@@ -56,26 +57,61 @@ public abstract class GPUObject
        }
        
        public abstract void acquireDeviceRead() throws DMLRuntimeException;
-       public abstract void acquireDenseDeviceModify(int numElemsToAllocate) 
throws DMLRuntimeException;
+       /**
+        * To signal intent that a matrix block will be written to on the GPU
+        * @throws DMLRuntimeException
+        */
+       public abstract void acquireDeviceModifyDense() throws 
DMLRuntimeException;
+       /**
+        * To signal intent that a sparse matrix block will be written to on 
the GPU
+        * @throws DMLRuntimeException
+        */
+       public abstract void acquireDeviceModifySparse() throws 
DMLRuntimeException;
+       
+       /**
+        * If memory on GPU has been allocated from elsewhere, this method 
+        * updates the internal bookkeeping
+        * @param numBytes
+        */
+       public abstract void setDeviceModify(long numBytes);
+       
        public abstract void acquireHostRead() throws CacheException;
        public abstract void acquireHostModify() throws CacheException;
        public abstract void releaseInput() throws CacheException;
        public abstract void releaseOutput() throws CacheException;
        
        // package-level visibility as these methods are guarded by underlying 
GPUContext
+       /**
+        * Allocates memory on the GPU
+        * @param numElemToAllocate             number of elements in dense 
matrix, -1 for unknown or sparse matrix
+        * @throws DMLRuntimeException  
+        */
        abstract void allocateMemoryOnDevice(int numElemToAllocate) throws 
DMLRuntimeException;
        abstract void deallocateMemoryOnDevice() throws DMLRuntimeException;
        abstract long getSizeOnDevice() throws DMLRuntimeException;
+       
        abstract void copyFromHostToDevice() throws DMLRuntimeException;
-       abstract void copyFromDeviceToHost() throws DMLRuntimeException; // 
Called by export()
        
+       /**
+        * Copies a matrix block (dense or sparse) from GPU Memory to Host 
memory.
+        * A {@link MatrixBlock} instance is allocated, data from the GPU is 
copied in,
+        * the current one in Host memory is deallocated by calling {@link 
MatrixObject#acquireModify(MatrixBlock)}
+        * and overwritten with the newly allocated instance.
+        * TODO : re-examine this to avoid spurious allocations of memory for 
optimizations
+        * @throws DMLRuntimeException
+        */
+       abstract void copyFromDeviceToHost() throws DMLRuntimeException; // 
Called by export()
        
        /**
-        * It finds matrix toBeRemoved such that toBeRemoved.GPUSize is the 
smallest one whose size is greater than the eviction size
+        * Cycles through the sorted list of allocated {@link GPUObject} 
instances. Sorting is based on
+        * number of (read) locks that have been obtained on it (reverse 
order). It repeatedly frees up 
+        * blocks on which there are zero locks until the required size has 
been freed up.  
         * // TODO: update it with hybrid policy
-        * @return toBeRemoved
+        * @param GPUSize                               Desired size to be 
freed up on the GPU
+        * @throws DMLRuntimeException  If no blocks to free up or if not 
enough blocks with zero locks on them.         
+        * @return 
         */
-       protected void evict(final long GPUSize) throws DMLRuntimeException {
+       protected static void evict(final long GPUSize) throws 
DMLRuntimeException {
         if(GPUContext.allocatedPointers.size() == 0) {
                 throw new DMLRuntimeException("There is not enough memory on 
device for this matrix!");
         }
@@ -153,7 +189,7 @@ public abstract class GPUObject
        
        static Boolean evictionLock = new Boolean(true);
        
-       protected long getAvailableMemory() {
+       protected static long getAvailableMemory() {
                return GPUContext.currContext.getAvailableMemory();
        }
        

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java
index 708badc..2ef00d4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaContext.java
@@ -29,13 +29,20 @@ import org.apache.sysml.utils.Statistics;
 import jcuda.driver.JCudaDriver;
 import jcuda.jcublas.JCublas2;
 import jcuda.jcublas.cublasHandle;
+import jcuda.jcublas.cublasPointerMode;
 import jcuda.jcudnn.JCudnn;
 import jcuda.runtime.JCuda;
 import jcuda.jcudnn.cudnnHandle;
+import jcuda.jcusparse.JCusparse;
+import jcuda.jcusparse.cusparseHandle;
+
 import static jcuda.jcudnn.JCudnn.cudnnCreate;
 import static jcuda.jcublas.JCublas2.cublasCreate;
+import static jcuda.jcublas.JCublas2.cublasSetPointerMode;
 import static jcuda.jcublas.JCublas2.cublasDestroy;
 import static jcuda.jcudnn.JCudnn.cudnnDestroy;
+import static jcuda.jcusparse.JCusparse.cusparseDestroy;
+import static jcuda.jcusparse.JCusparse.cusparseCreate;
 import static jcuda.driver.JCudaDriver.cuInit;
 import static jcuda.driver.JCudaDriver.cuDeviceGetCount;
 import static jcuda.runtime.JCuda.cudaMemGetInfo;
@@ -66,6 +73,7 @@ public class JCudaContext extends GPUContext {
                JCuda.setExceptionsEnabled(true);
                JCudnn.setExceptionsEnabled(true);
                JCublas2.setExceptionsEnabled(true);
+               JCusparse.setExceptionsEnabled(true);
                JCudaDriver.setExceptionsEnabled(true);
                cuInit(0); // Initialize the driver
                // Obtain the number of devices
@@ -115,6 +123,11 @@ public class JCudaContext extends GPUContext {
                cudnnCreate(LibMatrixCUDA.cudnnHandle);
                LibMatrixCUDA.cublasHandle = new cublasHandle();
                cublasCreate(LibMatrixCUDA.cublasHandle);
+               // For cublas v2, cublasSetPointerMode tells Cublas whether to 
expect scalar arguments on device or on host
+               // This applies to arguments like "alpha" in Dgemm, and "y" in 
Ddot.
+               // cublasSetPointerMode(LibMatrixCUDA.cublasHandle, 
cublasPointerMode.CUBLAS_POINTER_MODE_DEVICE); 
+               LibMatrixCUDA.cusparseHandle = new cusparseHandle();
+               cusparseCreate(LibMatrixCUDA.cusparseHandle);
                Statistics.cudaLibrariesInitTime = System.nanoTime() - start;
                
                long free [] = { 0 };
@@ -136,6 +149,7 @@ public class JCudaContext extends GPUContext {
                        synchronized(isGPUContextCreated) {
                                cudnnDestroy(LibMatrixCUDA.cudnnHandle);
                                cublasDestroy(LibMatrixCUDA.cublasHandle);
+                               cusparseDestroy(LibMatrixCUDA.cusparseHandle);
                                currContext = null;
                                isGPUContextCreated = false;
                        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
index 7f0b26b..88ebf46 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/JCudaObject.java
@@ -18,45 +18,337 @@
  */
 package org.apache.sysml.runtime.instructions.gpu.context;
 
+import static jcuda.jcusparse.JCusparse.cusparseCreateMatDescr;
+import static jcuda.jcusparse.JCusparse.cusparseDcsr2dense;
+import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
+import static jcuda.jcusparse.JCusparse.cusparseDnnz;
+import static jcuda.jcusparse.JCusparse.cusparseSetMatIndexBase;
+import static jcuda.jcusparse.JCusparse.cusparseSetMatType;
+import static jcuda.jcusparse.JCusparse.cusparseSetPointerMode;
+import static jcuda.jcusparse.JCusparse.cusparseXcsrgemmNnz;
+import static jcuda.jcusparse.cusparseIndexBase.CUSPARSE_INDEX_BASE_ZERO;
+import static jcuda.jcusparse.cusparseMatrixType.CUSPARSE_MATRIX_TYPE_GENERAL;
 import static jcuda.runtime.JCuda.cudaFree;
 import static jcuda.runtime.JCuda.cudaMalloc;
 import static jcuda.runtime.JCuda.cudaMemcpy;
-import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
-import jcuda.Pointer;
-import jcuda.Sizeof;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.CacheException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.SparseBlock;
+import org.apache.sysml.runtime.matrix.data.SparseBlockCOO;
+import org.apache.sysml.runtime.matrix.data.SparseBlockCSR;
+import org.apache.sysml.runtime.matrix.data.SparseBlockMCSR;
 import org.apache.sysml.utils.Statistics;
 
+import jcuda.Pointer;
+import jcuda.Sizeof;
+import jcuda.jcublas.JCublas2;
+import jcuda.jcublas.cublasHandle;
+import jcuda.jcublas.cublasOperation;
+import jcuda.jcusparse.JCusparse;
+import jcuda.jcusparse.cusparseDirection;
+import jcuda.jcusparse.cusparseHandle;
+import jcuda.jcusparse.cusparseMatDescr;
+import jcuda.jcusparse.cusparsePointerMode;
+
 public class JCudaObject extends GPUObject {
        
-       public Pointer jcudaPointer = null;
+       /**
+        * Compressed Sparse Row (CSR) format for CUDA
+        * Generalized matrix multiply is implemented for CSR format in the 
cuSparse library
+        */
+       public static class CSRPointer {
+               
+               public static cusparseMatDescr matrixDescriptor;
+               
+               /**
+                * @return Singleton default matrix descriptor object 
+                *                      (set with CUSPARSE_MATRIX_TYPE_GENERAL, 
CUSPARSE_INDEX_BASE_ZERO)
+                */
+               public static cusparseMatDescr 
getDefaultCuSparseMatrixDescriptor(){
+                       if (matrixDescriptor == null){
+                               // Code from JCuda Samples - 
http://www.jcuda.org/samples/JCusparseSample.java
+                               matrixDescriptor = new cusparseMatDescr();
+                               cusparseCreateMatDescr(matrixDescriptor);
+                               cusparseSetMatType(matrixDescriptor, 
CUSPARSE_MATRIX_TYPE_GENERAL);
+                               cusparseSetMatIndexBase(matrixDescriptor, 
CUSPARSE_INDEX_BASE_ZERO);
+                       }
+                       return matrixDescriptor;
+               }
+               
+               private static final double ULTRA_SPARSITY_TURN_POINT = 0.0004;
+
+               /**
+                * Private constructor used by the factory methods {@link #allocateEmpty(long, long)} and
+                * {@link #allocateForMatrixMultiply(cusparseHandle, CSRPointer, int, CSRPointer, int, int, int, int)}.
+                * Creates empty {@link Pointer} objects for the three CSR arrays (no cudaMalloc is issued here)
+                * and sets {@link #descr} to the shared default matrix descriptor.
+                */
+               private CSRPointer() {
+                       val = new Pointer();
+                       rowPtr = new Pointer();
+                       colInd = new Pointer();
+                       allocateMatDescrPointer();
+               }
+               
+               public long nnz;                /** Number of non zeroes        
                                                                        */
+               public Pointer val;             /** double array of non zero 
values                                                     */
+               public Pointer rowPtr;  /** integer array of start of all rows 
and end of last row + 1  */
+               public Pointer colInd;  /** integer array of nnz values' column 
indices                                 */
+               public cusparseMatDescr descr;  /** descriptor of matrix, only 
CUSPARSE_MATRIX_TYPE_GENERAL supported   */
+               
+               /**
+                * Check for ultra sparsity: the fraction nnz / rows / cols is compared
+                * against {@link #ULTRA_SPARSITY_TURN_POINT}.
+                * @param rows  number of rows of this matrix
+                * @param cols  number of columns of this matrix
+                * @return true if the fraction of non-zero cells is strictly below the turn point
+                */
+               public boolean isUltraSparse(int rows, int cols) {
+                       return ((double)nnz/rows/cols) < ULTRA_SPARSITY_TURN_POINT;
+               }
+               
+               /**
+                * Points {@link #descr} at the shared default descriptor
+                * (CUSPARSE_MATRIX_TYPE_GENERAL), the default that works for DGEMM.
+                */
+               private void allocateMatDescrPointer() {
+                       descr = CSRPointer.getDefaultCuSparseMatrixDescriptor();
+               }
+               
+               /**
+                * Estimate the size of a CSR matrix in GPU memory.
+                * The estimate covers the value array, the row pointer array, the column index
+                * array and the matrix descriptor; the size of the pointers themselves is not added in.
+                * @param nnz2  number of non zeroes
+                * @param rows  number of rows
+                * @return estimated size in bytes
+                */
+               public static long estimateSize(long nnz2, long rows) {
+                       // From the CUSPARSE documentation, the cusparseMatDescr in native code is represented as:
+                       // typedef struct {
+                       //      cusparseMatrixType_t MatrixType;
+                       //      cusparseFillMode_t FillMode;
+                       //      cusparseDiagType_t DiagType;
+                       //      cusparseIndexBase_t IndexBase;
+                       // } cusparseMatDescr_t;
+                       // hence the descriptor is counted as four ints.
+                       final long descrBytes = 4L * Sizeof.INT;
+                       final long valBytes = nnz2 * Sizeof.DOUBLE;
+                       final long rowPtrBytes = (rows + 1) * Sizeof.INT;
+                       final long colIndBytes = nnz2 * Sizeof.INT;
+                       return valBytes + rowPtrBytes + colIndBytes + descrBytes;
+               }
+               
+               /**
+                * Factory method to allocate an empty CSR Sparse matrix on the GPU.
+                * Space for all three arrays is requested through ensureFreeSpace before the
+                * individual cudaMalloc calls; the allocation time and count are recorded in {@link Statistics}.
+                * @param nnz2  number of non-zeroes
+                * @param rows  number of rows
+                * @return a {@link CSRPointer} instance that encapsulates the CSR matrix on GPU
+                * @throws DMLRuntimeException
+                */
+               public static CSRPointer allocateEmpty(long nnz2, long rows) throws DMLRuntimeException {
+                       final long valBytes = Sizeof.DOUBLE * nnz2;
+                       final long rowPtrBytes = Sizeof.INT * (rows + 1);
+                       final long colIndBytes = Sizeof.INT * nnz2;
+
+                       CSRPointer csr = new CSRPointer();
+                       csr.nnz = nnz2;
+                       ensureFreeSpace(valBytes + rowPtrBytes + colIndBytes);
+
+                       long allocStart = System.nanoTime();
+                       cudaMalloc(csr.val, valBytes);
+                       cudaMalloc(csr.rowPtr, rowPtrBytes);
+                       cudaMalloc(csr.colInd, colIndBytes);
+                       Statistics.cudaAllocTime.addAndGet(System.nanoTime()-allocStart);
+                       Statistics.cudaAllocCount.addAndGet(3);
+                       return csr;
+               }
+               
+               /**
+                * Static method to copy a CSR sparse matrix from Host to Device.
+                * The destination must already hold device allocations for all three arrays;
+                * its nnz field is overwritten with the given value.
+                * @param dest  [input] destination location (on GPU)
+                * @param rows  number of rows
+                * @param nnz   number of non-zeroes
+                * @param rowPtr        integer array of row pointers
+                * @param colInd        integer array of column indices
+                * @param values        double array of non zero values
+                */
+               public static void copyToDevice(CSRPointer dest, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values) {
+                       long copyStart = System.nanoTime();
+                       dest.nnz = nnz;
+                       cudaMemcpy(dest.rowPtr, Pointer.to(rowPtr), (rows + 1) * Sizeof.INT, cudaMemcpyHostToDevice);
+                       cudaMemcpy(dest.colInd, Pointer.to(colInd), nnz * Sizeof.INT, cudaMemcpyHostToDevice);
+                       cudaMemcpy(dest.val, Pointer.to(values), nnz * Sizeof.DOUBLE, cudaMemcpyHostToDevice);
+                       Statistics.cudaToDevTime.addAndGet(System.nanoTime()-copyStart);
+                       Statistics.cudaToDevCount.addAndGet(3);
+               }
+               
+               /**
+                * Static method to copy a CSR sparse matrix from Device to host.
+                * The three output arrays are filled in place; nothing is returned.
+                * @param src   [input] source location (on GPU)
+                * @param rows  [input] number of rows
+                * @param nnz   [input] number of non-zeroes
+                * @param rowPtr        [output] pre-allocated integer array of row pointers of size (rows+1)
+                * @param colInd        [output] pre-allocated integer array of column indices of size nnz
+                * @param values        [output] pre-allocated double array of values of size nnz
+                */
+               public static void copyToHost(CSRPointer src, int rows, long nnz, int[] rowPtr, int[] colInd, double[] values){
+                       long copyStart = System.nanoTime();
+                       cudaMemcpy(Pointer.to(rowPtr), src.rowPtr, (rows + 1) * Sizeof.INT, cudaMemcpyDeviceToHost);
+                       cudaMemcpy(Pointer.to(colInd), src.colInd, nnz * Sizeof.INT, cudaMemcpyDeviceToHost);
+                       cudaMemcpy(Pointer.to(values), src.val, nnz * Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
+                       Statistics.cudaFromDevTime.addAndGet(System.nanoTime()-copyStart);
+                       Statistics.cudaFromDevCount.addAndGet(3);
+               }
+               
+               /**
+                * Estimates the number of non-zero elements from the result of a sparse matrix multiplication C = A * B
+                * and returns the {@link CSRPointer} to C with the appropriate GPU memory.
+                * The non-zero counts of A and B are validated before anything is allocated, so that an
+                * unsupported input does not leave an allocated row pointer array behind on the device.
+                * @param handle        a valid {@link cusparseHandle}
+                * @param A                     Sparse Matrix A on GPU
+                * @param transA        operation code for A, forwarded unchanged to cusparseXcsrgemmNnz
+                *                                      (presumably a cusparseOperation constant - confirm against callers)
+                * @param B                     Sparse Matrix B on GPU
+                * @param transB        operation code for B, forwarded unchanged to cusparseXcsrgemmNnz
+                *                                      (presumably a cusparseOperation constant - confirm against callers)
+                * @param m                     Rows in A
+                * @param n                     Columns in B
+                * @param k                     Columns in A / Rows in B
+                * @return a {@link CSRPointer} for C whose rowPtr has been passed to cusparseXcsrgemmNnz and whose
+                *                      val and colInd arrays are allocated for C.nnz entries (they are not populated here)
+                * @throws DMLRuntimeException if A or B holds Integer.MAX_VALUE or more non-zeroes,
+                *                      or when propagated from ensureFreeSpace
+                */
+               public static CSRPointer allocateForMatrixMultiply(cusparseHandle handle, CSRPointer A, int transA, CSRPointer B, int transB, int m, int n, int k)
+                               throws DMLRuntimeException{
+                       // Following the code example at http://docs.nvidia.com/cuda/cusparse/#cusparse-lt-t-gt-csrgemm and at
+                       // https://github.com/jcuda/jcuda-matrix-utils/blob/master/JCudaMatrixUtils/src/test/java/org/jcuda/matrix/samples/JCusparseSampleDgemm.java
+
+                       // Validate first: cuSparse takes the nnz counts as int, and failing here
+                       // (before any cudaMalloc) avoids leaking device memory on the error path.
+                       if (A.nnz >= Integer.MAX_VALUE || B.nnz >= Integer.MAX_VALUE) {
+                               throw new DMLRuntimeException("Number of non zeroes is larger than supported by cuSparse");
+                       }
+
+                       CSRPointer C = new CSRPointer();
+                       cusparseSetPointerMode(handle, cusparsePointerMode.CUSPARSE_POINTER_MODE_HOST);
+
+                       JCudaObject.ensureFreeSpace(Sizeof.INT * (m+1));
+                       long t0 = System.nanoTime();
+                       cudaMalloc(C.rowPtr, Sizeof.INT * (m+1));
+                       Statistics.cudaAllocTime.addAndGet(System.nanoTime()-t0);
+                       Statistics.cudaAllocCount.addAndGet(1);
+
+                       // -1 is a sentinel: if it is still there after the call, the total was not written to the host
+                       int[] CnnzArray = { -1 };
+                       cusparseXcsrgemmNnz(handle, transA, transB, m, n, k,
+                                       A.descr, (int)A.nnz, A.rowPtr, A.colInd,
+                                       B.descr, (int)B.nnz, B.rowPtr, B.colInd,
+                                       C.descr, C.rowPtr, Pointer.to(CnnzArray));
+                       if (CnnzArray[0] != -1){
+                               C.nnz = CnnzArray[0];
+                       }
+                       else {
+                               // Fall back to deriving nnz from the row pointer array: last entry minus first entry (index base)
+                               int baseArray[] = { 0 };
+                               cudaMemcpy(Pointer.to(CnnzArray), C.rowPtr.withByteOffset(m * Sizeof.INT), 1 * Sizeof.INT, cudaMemcpyDeviceToHost);
+                               cudaMemcpy(Pointer.to(baseArray), C.rowPtr, 1 * Sizeof.INT, cudaMemcpyDeviceToHost);
+                               C.nnz = CnnzArray[0] - baseArray[0];
+                       }
+
+                       JCudaObject.ensureFreeSpace(Sizeof.DOUBLE * C.nnz);
+                       long t1 = System.nanoTime();
+                       cudaMalloc(C.val, Sizeof.DOUBLE * C.nnz);
+                       Statistics.cudaAllocTime.addAndGet(System.nanoTime()-t1);
+                       Statistics.cudaAllocCount.addAndGet(1);
+
+                       JCudaObject.ensureFreeSpace(Sizeof.INT * C.nnz);
+                       long t2 = System.nanoTime();
+                       cudaMalloc(C.colInd, Sizeof.INT * C.nnz);
+                       Statistics.cudaAllocTime.addAndGet(System.nanoTime()-t2);
+                       Statistics.cudaAllocCount.addAndGet(1);
+
+                       return C;
+               }
+               
+               /**
+                * Copies this CSR matrix on the GPU to a dense column-major matrix
+                * on the GPU. This is a temporary matrix for operations such as
+                * cusparseDcsrmv.
+                * Since the allocated matrix is temporary, bookkeeping is not updated.
+                * The caller is responsible for calling "free" on the returned Pointer object.
+                * @param cusparseHandle        a valid {@link cusparseHandle}
+                * @param cublasHandle          a valid {@link cublasHandle} (not used by the current implementation)
+                * @param rows          number of rows in this CSR matrix
+                * @param cols          number of columns in this CSR matrix
+                * @return                      A {@link Pointer} to the allocated dense matrix (in column-major format)
+                * @throws DMLRuntimeException
+                */
+               public Pointer toDenseMatrix(cusparseHandle cusparseHandle, cublasHandle cublasHandle, int rows, int cols) throws DMLRuntimeException {
+                       // Widen before multiplying: rows * cols * Sizeof.DOUBLE evaluated in int overflows for large matrices
+                       long size = (long) rows * cols * Sizeof.DOUBLE;
+                       Pointer A = JCudaObject.allocate(size);
+                       // The leading dimension passed is "rows", so the result is laid out column-major.
+                       // No transposition into row-major is performed here.
+                       cusparseDcsr2dense(cusparseHandle, rows, cols, descr, val, rowPtr, colInd, A, rows);
+                       return A;
+               }
+               
+               /**
+                * Releases the device memory behind this CSR matrix by calling cudaFree
+                * on the value, row pointer and column index {@link Pointer} instances, in that order.
+                */
+               public void deallocate() {
+                       for (Pointer devicePtr : new Pointer[] { val, rowPtr, colInd }) {
+                               cudaFree(devicePtr);
+                       }
+               }
+       };
+       
+       public Pointer jcudaDenseMatrixPtr = null;              /** Pointer to 
dense matrix */
+       public CSRPointer jcudaSparseMatrixPtr = null;  /** Pointer to sparse 
matrix */
+
        public long numBytes;
 
        JCudaObject(MatrixObject mat2) {
                super(mat2);
        }
        
-       private void prepare(boolean isInput, int numElemsToAllocate) throws 
DMLRuntimeException {
-               if(jcudaPointer != null) {
+       /**
+        * Allocates temporary space on the device.
+        * Only the allocation statistics are updated; the regular bookkeeping is not.
+        * The caller is responsible for freeing up after usage.
+        * @param size  number of bytes to allocate
+        * @return a {@link Pointer} to the newly allocated device memory
+        * @throws DMLRuntimeException
+        */
+       public static Pointer allocate(long size) throws DMLRuntimeException{
+               Pointer devicePtr = new Pointer();
+               ensureFreeSpace(size);
+               final long allocStart = System.nanoTime();
+               cudaMalloc(devicePtr, size);
+               final long elapsed = System.nanoTime() - allocStart;
+               Statistics.cudaAllocTime.getAndAdd(elapsed);
+               Statistics.cudaAllocCount.getAndAdd(1);
+               return devicePtr;
+       }
+       
+       /**
+        * Allocate necessary memory on the GPU for this {@link JCudaObject} 
instance.
+        * @param isInput if the block is input, isSparse argument is ignored
+        * @param isSparse if the block is sparse
+        * @throws DMLRuntimeException
+        */
+       private void prepare(boolean isInput, boolean isSparse) throws 
DMLRuntimeException {
+               if(jcudaDenseMatrixPtr != null || jcudaSparseMatrixPtr != null) 
{
                        // Already allocated on GPU and expected to be in sync
                }
                else {
                        if(isInput) {
-                               if(numElemsToAllocate != -1)
-                                       throw new DMLRuntimeException("Expected 
numElemsToAllocate to be -1 as it is inferred from the input");
-                               // Copy performs allocation
                                copyFromHostToDevice();
                        }
                        else {
                                // Don't copy just allocate
-                               
ensureFreeSpaceForDenseBlock(numElemsToAllocate);
-                               allocateMemoryOnDevice(numElemsToAllocate);
+                               if (isSparse){
+                                       long sparseSize = 
CSRPointer.estimateSize(mat.getNnz(), mat.getNumRows());
+                                       ensureFreeSpace(sparseSize);
+                                       allocateMemoryOnDevice(-1);
+                               } else {        // Dense block, size = numRows 
* numCols
+                                       int size = (int) (mat.getNumRows() * 
mat.getNumColumns());
+                                       ensureFreeSpace(Sizeof.DOUBLE * size);
+                                       allocateMemoryOnDevice(size);
+                               }
                                synchronized(evictionLock) {
                                        GPUContext.allocatedPointers.add(this);
                                }
@@ -67,14 +359,23 @@ public class JCudaObject extends GPUObject {
        
        @Override
        public void acquireDeviceRead() throws DMLRuntimeException {
-               prepare(true, -1);
+               prepare(true, false);
                if(!isAllocated) 
                        throw new DMLRuntimeException("Expected device data to 
be allocated");
        }
        
        @Override
-       public void acquireDenseDeviceModify(int numElemsToAllocate) throws 
DMLRuntimeException {
-               prepare(false, numElemsToAllocate); 
+       public void acquireDeviceModifyDense() throws DMLRuntimeException {
+               prepare(false, false); 
+               isDeviceCopyModified = true;
+               if(!isAllocated) 
+                       throw new DMLRuntimeException("Expected device data to 
be allocated");
+       }
+       
+       @Override
+       public void acquireDeviceModifySparse() throws DMLRuntimeException {
+               isInSparseFormat = true;
+               prepare(false, true);
                isDeviceCopyModified = true;
                if(!isAllocated) 
                        throw new DMLRuntimeException("Expected device data to 
be allocated");
@@ -156,21 +457,26 @@ public class JCudaObject extends GPUObject {
 
        @Override
        void allocateMemoryOnDevice(int numElemToAllocate) throws 
DMLRuntimeException {
-               if(jcudaPointer == null) {
+               if(jcudaDenseMatrixPtr == null && jcudaSparseMatrixPtr == null) 
{
                        long start = System.nanoTime();
-                       jcudaPointer = new Pointer();
-                       if(numElemToAllocate == -1 && 
LibMatrixCUDA.isInSparseFormat(mat))
-                               throw new DMLRuntimeException("Sparse format 
not implemented");
-                       else if(numElemToAllocate == -1) {
+                       if(numElemToAllocate == -1 && 
LibMatrixCUDA.isInSparseFormat(mat)) {
+                               jcudaSparseMatrixPtr = 
CSRPointer.allocateEmpty(mat.getNnz(), mat.getNumRows()); 
+                               numBytes = 
CSRPointer.estimateSize(mat.getNnz(), mat.getNumRows());
+                               
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
+                               isInSparseFormat = true;
+                               //throw new DMLRuntimeException("Sparse format 
not implemented");
+                       } else if(numElemToAllocate == -1) {
                                // Called for dense input
+                               jcudaDenseMatrixPtr = new Pointer();
                                numBytes = 
mat.getNumRows()*mat.getNumColumns()*Sizeof.DOUBLE;
-                               cudaMalloc(jcudaPointer, numBytes);
+                               cudaMalloc(jcudaDenseMatrixPtr, numBytes);
                                
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
                        }
                        else {
                                // Called for dense output
+                               jcudaDenseMatrixPtr = new Pointer();
                                numBytes = numElemToAllocate*Sizeof.DOUBLE;
-                               cudaMalloc(jcudaPointer,  numBytes);
+                               cudaMalloc(jcudaDenseMatrixPtr,  numBytes);
                                
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
                        }
                        
@@ -182,24 +488,42 @@ public class JCudaObject extends GPUObject {
        }
        
        @Override
+       public void setDeviceModify(long numBytes) {
+               this.numLocks.addAndGet(1);
+               this.numBytes = numBytes;
+               
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(-numBytes);
+       }
+
+       @Override
        void deallocateMemoryOnDevice() {
-               if(jcudaPointer != null) {
+               if(jcudaDenseMatrixPtr != null) {
                        long start = System.nanoTime();
-                       cudaFree(jcudaPointer);
+                       cudaFree(jcudaDenseMatrixPtr);
                        
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(numBytes);
                        
Statistics.cudaDeAllocTime.addAndGet(System.nanoTime()-start);
                        Statistics.cudaDeAllocCount.addAndGet(1);
-                       
                }
-               jcudaPointer = null;
+               if (jcudaSparseMatrixPtr != null) {
+                       long start = System.nanoTime();
+                       jcudaSparseMatrixPtr.deallocate();
+                       
JCudaContext.availableNumBytesWithoutUtilFactor.addAndGet(numBytes);
+                       
Statistics.cudaDeAllocTime.addAndGet(System.nanoTime()-start);
+                       Statistics.cudaDeAllocCount.addAndGet(1);
+               }
+               jcudaDenseMatrixPtr = null;
+               jcudaSparseMatrixPtr = null;
                isAllocated = false;
                numLocks.set(0);
        }
        
-       void ensureFreeSpaceForDenseBlock(int numElem) throws 
DMLRuntimeException {
-               long GPUSize = (Sizeof.DOUBLE) * numElem;
-               if(GPUSize >= getAvailableMemory()) {
-                       evict(GPUSize);
+       /** 
+        * Thin wrapper over {@link #evict(long)}: when the requested size is greater than
+        * or equal to the value returned by getAvailableMemory(), evict is called with that size;
+        * otherwise nothing happens.
+        * @param size  requested size in bytes
+        * @throws DMLRuntimeException
+        */
+       static void ensureFreeSpace(long size) throws DMLRuntimeException {
+               if(size >= getAvailableMemory()) {
+                       evict(size);
                }
        }
        
@@ -212,7 +536,44 @@ public class JCudaObject extends GPUObject {
                
                MatrixBlock tmp = mat.acquireRead();
                if(tmp.isInSparseFormat()) {
-                       throw new DMLRuntimeException("Sparse matrix is not 
implemented");
+                       
+                       int rowPtr[] = null;
+                       int colInd[] = null;
+                       double[] values = null;
+                                       
+                       SparseBlock block = tmp.getSparseBlock();
+                       // CSR is the preferred format for cuSparse GEMM
+                       // Converts MCSR and COO to CSR
+                       SparseBlockCSR csrBlock = null;
+                       if (block instanceof SparseBlockCSR){ 
+                               csrBlock = (SparseBlockCSR)block;
+                       } else if (block instanceof SparseBlockCOO) {
+                               // TODO - should we do this on the GPU using 
cusparse<t>coo2csr() ?
+                               long t0 = System.nanoTime();
+                               SparseBlockCOO cooBlock = (SparseBlockCOO)block;
+                               csrBlock = new 
SparseBlockCSR((int)mat.getNumRows(), cooBlock.rowIndexes(), 
cooBlock.indexes(), cooBlock.values());
+                               
Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0);
+                               
Statistics.cudaConversionCount.incrementAndGet();
+                       } else if (block instanceof SparseBlockMCSR) {
+                               long t0 = System.nanoTime();
+                               SparseBlockMCSR mcsrBlock = 
(SparseBlockMCSR)block;
+                               csrBlock = new 
SparseBlockCSR(mcsrBlock.getRows(), (int)mcsrBlock.size());
+                               
Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0);
+                               
Statistics.cudaConversionCount.incrementAndGet();
+                       } else {
+                               throw new DMLRuntimeException("Unsupported 
sparse matrix format for CUDA operations");
+                       }
+                       rowPtr = csrBlock.rowPointers();
+                       colInd = csrBlock.indexes();
+                       values = csrBlock.values();     
+                       ensureFreeSpace(CSRPointer.estimateSize(mat.getNnz(), 
mat.getNumRows()));
+                       allocateMemoryOnDevice(-1);
+                       synchronized(evictionLock) {
+                               GPUContext.allocatedPointers.add(this);
+                       }
+                       CSRPointer.copyToDevice(jcudaSparseMatrixPtr, 
tmp.getNumRows(), tmp.getNonZeros(), rowPtr, colInd, values);
+                       
+                       // throw new DMLRuntimeException("Sparse matrix is not 
implemented");
                        // tmp.sparseToDense();
                }
                else {
@@ -226,12 +587,12 @@ public class JCudaObject extends GPUObject {
                                data = new 
double[tmp.getNumRows()*tmp.getNumColumns()];
                        
                        // Copy dense block
-                       ensureFreeSpaceForDenseBlock(data.length);
+                       ensureFreeSpace(Sizeof.DOUBLE * data.length);
                        allocateMemoryOnDevice(data.length);
                        synchronized(evictionLock) {
                                GPUContext.allocatedPointers.add(this);
                        }
-                       cudaMemcpy(jcudaPointer, Pointer.to(data), 
mat.getNumRows()*mat.getNumColumns() * Sizeof.DOUBLE, cudaMemcpyHostToDevice);
+                       cudaMemcpy(jcudaDenseMatrixPtr, Pointer.to(data), 
mat.getNumRows()*mat.getNumColumns() * Sizeof.DOUBLE, cudaMemcpyHostToDevice);
                }
                
                mat.release();
@@ -242,28 +603,48 @@ public class JCudaObject extends GPUObject {
 
        @Override
        protected void copyFromDeviceToHost() throws DMLRuntimeException {
-               if(jcudaPointer != null) {
+               if (jcudaDenseMatrixPtr != null && jcudaSparseMatrixPtr != 
null){
+                       throw new DMLRuntimeException("Invalid state : JCuda 
dense/sparse pointer are both allocated");
+               }
+               if(jcudaDenseMatrixPtr != null) {
                        printCaller();
-                       if(LibMatrixCUDA.isInSparseFormat(mat))
-                               throw new DMLRuntimeException("Sparse format 
not implemented");
-                       else {
-                               long start = System.nanoTime();
-                               MatrixBlock tmp = new 
MatrixBlock((int)mat.getNumRows(), (int)mat.getNumColumns(), false);
-                               tmp.allocateDenseBlock();
-                               double [] data = tmp.getDenseBlock();
-                               
-                               cudaMemcpy(Pointer.to(data), jcudaPointer, 
data.length * Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
-
-                               tmp.recomputeNonZeros();
-                               mat.acquireModify(tmp);
-                               mat.release();
-                               
-                               
Statistics.cudaFromDevTime.addAndGet(System.nanoTime()-start);
-                               Statistics.cudaFromDevCount.addAndGet(1);
-                       }
+                       long start = System.nanoTime();
+                       MatrixBlock tmp = new 
MatrixBlock((int)mat.getNumRows(), (int)mat.getNumColumns(), false);
+                       tmp.allocateDenseBlock();
+                       double [] data = tmp.getDenseBlock();
+                       
+                       cudaMemcpy(Pointer.to(data), jcudaDenseMatrixPtr, 
data.length * Sizeof.DOUBLE, cudaMemcpyDeviceToHost);
+                       
+                       tmp.recomputeNonZeros();
+                       mat.acquireModify(tmp);
+                       mat.release();
+                       
+                       
Statistics.cudaFromDevTime.addAndGet(System.nanoTime()-start);
+                       Statistics.cudaFromDevCount.addAndGet(1);
+               }
+               else if (jcudaSparseMatrixPtr != null){
+                       printCaller();
+                       if(!LibMatrixCUDA.isInSparseFormat(mat))
+                               throw new DMLRuntimeException("Block not in 
sparse format on host yet the device sparse matrix pointer is not null");
+                       long start = System.nanoTime();
+                       
+                       int rows = (int) mat.getNumRows();
+                       int cols = (int) mat.getNumColumns();
+                       int nnz = (int) jcudaSparseMatrixPtr.nnz;
+                       int[] rowPtr = new int[rows + 1];
+                       int[] colInd = new int[nnz];
+                       double[] values = new double[nnz];
+                       CSRPointer.copyToHost(jcudaSparseMatrixPtr, rows, nnz, 
rowPtr, colInd, values);
+                       
+                       SparseBlockCSR sparseBlock = new SparseBlockCSR(rowPtr, 
colInd, values, nnz);
+                       MatrixBlock tmp = new MatrixBlock(rows, cols, nnz, 
sparseBlock);
+                       mat.acquireModify(tmp);
+                       mat.release();
+                       
Statistics.cudaFromDevTime.addAndGet(System.nanoTime()-start);
+                       Statistics.cudaFromDevCount.addAndGet(1);
                }
                else {
-                       throw new DMLRuntimeException("Cannot copy from device 
to host as JCuda pointer is not allocated");
+                       throw new DMLRuntimeException("Cannot copy from device 
to host as JCuda dense/sparse pointer is not allocated");
                }
                isDeviceCopyModified = false;
        }
@@ -271,14 +652,15 @@ public class JCudaObject extends GPUObject {
        @Override
        protected long getSizeOnDevice() throws DMLRuntimeException {
                long GPUSize = 0;
-               int rlen = (int) mat.getNumRows();
-               int clen = (int) mat.getNumColumns();
+               long rlen = mat.getNumRows();
+               long clen = mat.getNumColumns();
+               long nnz = mat.getNnz();
 
                if(LibMatrixCUDA.isInSparseFormat(mat)) {
-                       throw new DMLRuntimeException("Sparse format not 
implemented");
+                       GPUSize = CSRPointer.estimateSize(nnz, rlen);
                }
                else {
-                       GPUSize = (Sizeof.DOUBLE) * (long) (rlen * clen);
+                       GPUSize = (Sizeof.DOUBLE) * (rlen * clen);
                }
                return GPUSize;
        }
@@ -288,6 +670,10 @@ public class JCudaObject extends GPUObject {
                return str[str.length - 1] + "." + st.getMethodName();
        }
        
+       /**
+        * Convenience debugging method.
+        * Checks {@link JCudaContext#DEBUG} flag before printing to System.out
+        */
        private void printCaller() {
                if(JCudaContext.DEBUG) {
                        StackTraceElement[] st = 
Thread.currentThread().getStackTrace();
@@ -299,4 +685,74 @@ public class JCudaObject extends GPUObject {
                }
                        
        }
-}
\ No newline at end of file
+       
+       /**
+        * Convenience method to directly examine the Sparse matrix on GPU
+        * @return the CSR pointer field of this object, returned as-is (no
+        *         check is made here that it has been allocated)
+        */
+       public CSRPointer getSparseMatrixCudaPointer() {
+               return jcudaSparseMatrixPtr;
+       }
+       
+       /**
+        * Convenience method to directly set the sparse matrix on GPU
+        * Needed for operations like {@link 
JCusparse#cusparseDcsrgemm(cusparseHandle, int, int, int, int, int, 
cusparseMatDescr, int, Pointer, Pointer, Pointer, cusparseMatDescr, int, 
Pointer, Pointer, Pointer, cusparseMatDescr, Pointer, Pointer, Pointer)}
+        * Besides storing the pointer, this marks the object as allocated and
+        * as being in sparse format. The previously stored pointer is simply
+        * overwritten; nothing is freed by this method.
+        * @param jcudaSparseMatrixPtr  CSR matrix to attach to this object
+        */
+       public void setSparseMatrixCudaPointer(CSRPointer jcudaSparseMatrixPtr) 
{
+               this.jcudaSparseMatrixPtr = jcudaSparseMatrixPtr;
+               this.isAllocated = true;
+               this.isInSparseFormat = true;
+       }
+       
+       /**
+        * Convenience method to directly set the dense matrix pointer.
+        * Besides storing the pointer, this marks the object as allocated and
+        * as not being in sparse format. The previously stored pointer is
+        * simply overwritten; nothing is freed by this method.
+        * @param densePtr      dense matrix pointer to attach to this object
+        */
+       public void setDenseMatrixCudaPointer(Pointer densePtr){
+               this.jcudaDenseMatrixPtr = densePtr;
+               this.isAllocated = true;
+               this.isInSparseFormat = false;
+       }
+       
+       /**
+        * Convenience method to convert a dense matrix on the GPU to a CSR
+        * (sparse) matrix on the GPU.
+        * Since the allocated matrix is temporary, bookkeeping is not updated.
+        * Caller is responsible for deallocating memory on GPU.
+        * @param cusparseHandle        cuSPARSE handle used for the conversion
+        * @param rows          number of rows of the dense matrix
+        * @param cols          number of columns of the dense matrix
+        * @param densePtr      [in] dense matrix pointer on the GPU; it is
+        *                      passed to cuSPARSE with leading dimension
+        *                      {@code rows}. NOTE(review): the original
+        *                      comment said "row major" while cuSPARSE dense
+        *                      routines read column-major - confirm with
+        *                      callers.
+        * @return      newly allocated CSR matrix with the non-zeros of the
+        *              dense input
+        * @throws DMLRuntimeException if the total number of non-zeros was
+        *              not written back by cusparseDnnz
+        */
+       public static CSRPointer denseToSparse(cusparseHandle cusparseHandle, 
int rows, int cols, Pointer densePtr) throws DMLRuntimeException {              
  
+               cusparseMatDescr matDescr = 
CSRPointer.getDefaultCuSparseMatrixDescriptor();
+               Pointer nnzPerRowPtr = new Pointer();
+               Pointer nnzTotalDevHostPtr = new Pointer();
+               
+               // Temporaries: one int per row plus one int for the total
+               ensureFreeSpace((rows + 1) * Sizeof.INT);
+               
+               long t1 = System.nanoTime();
+               // CUSPARSE_DIRECTION_ROW produces one count per row, so the
+               // buffer has to hold rows (not cols) ints
+               cudaMalloc(nnzPerRowPtr, rows * Sizeof.INT);
+               cudaMalloc(nnzTotalDevHostPtr, Sizeof.INT);
+               Statistics.cudaAllocTime.addAndGet(System.nanoTime() - t1);
+               Statistics.cudaAllocCount.addAndGet(2);         
+               
+               // Output is in dense vector format, convert it to CSR
+               cusparseDnnz(cusparseHandle, 
cusparseDirection.CUSPARSE_DIRECTION_ROW, rows, cols, matDescr, densePtr, rows, 
nnzPerRowPtr, nnzTotalDevHostPtr);
+       
+               // Sentinel: stays -1 if the total was never copied back
+               int[] nnzC = {-1};
+               
+               long t2 = System.nanoTime();
+               cudaMemcpy(Pointer.to(nnzC), nnzTotalDevHostPtr, Sizeof.INT, 
cudaMemcpyDeviceToHost);
+               Statistics.cudaFromDevTime.addAndGet(System.nanoTime() - t2);
+               // Exactly one device-to-host copy was issued above
+               Statistics.cudaFromDevCount.addAndGet(1);               
+               
+               if (nnzC[0] == -1){
+                       // Release the temporaries before bailing out
+                       cudaFree(nnzPerRowPtr);
+                       cudaFree(nnzTotalDevHostPtr);
+                       throw new DMLRuntimeException("cusparseDnnz did not 
calculate the correct number of nnz for the dense matrix on the GPU");
+               }
+               
+               CSRPointer C = CSRPointer.allocateEmpty(nnzC[0], rows);         
+               cusparseDdense2csr(cusparseHandle, rows, cols, matDescr, 
densePtr, rows, nnzPerRowPtr, C.val, C.rowPtr, C.colInd);
+               
+               cudaFree(nnzPerRowPtr);
+               cudaFree(nnzTotalDevHostPtr);
+               
+               return C;
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
index 6a25b49..07cfa0c 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java
@@ -22,50 +22,77 @@ package org.apache.sysml.runtime.matrix.data;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardData;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionBackwardFilter;
 import static jcuda.jcudnn.JCudnn.cudnnConvolutionForward;
-import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
-import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
 import static jcuda.jcudnn.JCudnn.cudnnCreateConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreateFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnCreatePoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnCreateTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyConvolutionDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyFilterDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnDestroyPoolingDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnDestroyTensorDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardDataWorkspaceSize;
 import static 
jcuda.jcudnn.JCudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize;
 import static jcuda.jcudnn.JCudnn.cudnnGetConvolutionForwardWorkspaceSize;
+import static jcuda.jcudnn.JCudnn.cudnnPoolingBackward;
+import static jcuda.jcudnn.JCudnn.cudnnPoolingForward;
 import static jcuda.jcudnn.JCudnn.cudnnSetConvolution2dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetFilter4dDescriptor;
-import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.JCudnn.cudnnSetPooling2dDescriptor;
+import static jcuda.jcudnn.JCudnn.cudnnSetTensor4dDescriptor;
 import static jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION;
 import static jcuda.jcudnn.cudnnDataType.CUDNN_DATA_DOUBLE;
-import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
 import static jcuda.jcudnn.cudnnPoolingMode.CUDNN_POOLING_MAX;
-import jcuda.jcudnn.cudnnConvolutionFwdPreference;
+import static jcuda.jcudnn.cudnnTensorFormat.CUDNN_TENSOR_NCHW;
+import static jcuda.jcusparse.JCusparse.cusparseDcsrgemm;
+import static jcuda.jcusparse.JCusparse.cusparseDcsrmv;
+import static jcuda.jcusparse.JCusparse.cusparseDdense2csr;
+import static jcuda.jcusparse.JCusparse.cusparseDnnz;
+import static 
jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_NON_TRANSPOSE;
+import static jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_TRANSPOSE;
+import static jcuda.runtime.JCuda.cudaFree;
 import static jcuda.runtime.JCuda.cudaMalloc;
+
+import static 
jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_NON_TRANSPOSE;
+import static jcuda.jcusparse.cusparseOperation.CUSPARSE_OPERATION_TRANSPOSE;
 import static jcuda.runtime.JCuda.cudaFree;
+import static jcuda.runtime.JCuda.cudaMemcpy;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyDeviceToHost;
+import static jcuda.runtime.cudaMemcpyKind.cudaMemcpyHostToDevice;
+
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.gpu.context.JCudaObject;
+import 
org.apache.sysml.runtime.instructions.gpu.context.JCudaObject.CSRPointer;
+import org.apache.sysml.utils.Statistics;
+
 import jcuda.Pointer;
 import jcuda.Sizeof;
-import jcuda.jcublas.JCublas;
+import jcuda.jcublas.JCublas2;
+import jcuda.jcublas.cublasFillMode;
 import jcuda.jcublas.cublasHandle;
+import jcuda.jcublas.cublasOperation;
 import jcuda.jcudnn.cudnnConvolutionDescriptor;
+import jcuda.jcudnn.cudnnConvolutionFwdPreference;
 import jcuda.jcudnn.cudnnFilterDescriptor;
 import jcuda.jcudnn.cudnnHandle;
 import jcuda.jcudnn.cudnnPoolingDescriptor;
 import jcuda.jcudnn.cudnnTensorDescriptor;
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.instructions.gpu.context.JCudaObject;
+import jcuda.jcusparse.cusparseHandle;
+import java.util.Arrays;
 
 //FIXME move could to respective instructions, this is not a block library
 public class LibMatrixCUDA {
        
        public static cudnnHandle cudnnHandle;
        public static cublasHandle cublasHandle;
-       
+       public static cusparseHandle cusparseHandle;
+
+    private static final Log LOG = 
LogFactory.getLog(LibMatrixCUDA.class.getName());
+
        private static int CONVOLUTION_PREFERENCE = 
cudnnConvolutionFwdPreference.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
        
        public static void conv2d(MatrixObject image, MatrixObject filter, 
MatrixObject outputBlock, int N, int C, int H, int W,
@@ -89,9 +116,9 @@ public class LibMatrixCUDA {
                        // (Pointer) gpuCtx.prepare(image, true, true);
                        // (Pointer) gpuCtx.prepare(filter, true, true);
                        
-                       Pointer imagePointer = 
((JCudaObject)image.getGPUObject()).jcudaPointer; 
-                       Pointer filterPointer = 
((JCudaObject)filter.getGPUObject()).jcudaPointer; 
-                       Pointer dstPointer = 
((JCudaObject)outputBlock.getGPUObject()).jcudaPointer; 
+                       Pointer imagePointer = 
((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer filterPointer = 
((JCudaObject)filter.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer dstPointer = 
((JCudaObject)outputBlock.getGPUObject()).jcudaDenseMatrixPtr; 
                        
                        int padding [] = { pad_h, pad_w }; 
                        int strides [] = { stride_h, stride_w };
@@ -221,9 +248,9 @@ public class LibMatrixCUDA {
                        dwDesc = allocateFilterDescriptor(K, C, R, S);
                        
                        // Allocate data
-                       Pointer imagePointer = 
((JCudaObject)image.getGPUObject()).jcudaPointer; 
-                       Pointer doutPointer = 
((JCudaObject)dout.getGPUObject()).jcudaPointer; 
-                       Pointer dwPointer = 
((JCudaObject)outputBlock.getGPUObject()).jcudaPointer; 
+                       Pointer imagePointer = 
((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer doutPointer = 
((JCudaObject)dout.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer dwPointer = 
((JCudaObject)outputBlock.getGPUObject()).jcudaDenseMatrixPtr; 
                        
                        alpha = pointerTo(1.0); // TODO
                        beta = pointerTo(0.0f);
@@ -281,7 +308,7 @@ public class LibMatrixCUDA {
        
            // Since CuBLAS expects inputs in column-major format,
            // reverse the order of matrix-multiplication and take care of 
dimension mismatch.      
-           char transa = isLeftTransposed ? 'N' : 'T';
+           int transa = isLeftTransposed ? cublasOperation.CUBLAS_OP_N : 
cublasOperation.CUBLAS_OP_T;
            // Note: the dimensions are swapped
            int m = (int) (isLeftTransposed ? left.getNumColumns() : 
left.getNumRows());
            int k = (int) (isLeftTransposed ? left.getNumRows() : 
left.getNumColumns());
@@ -289,8 +316,8 @@ public class LibMatrixCUDA {
            if(m == -1)
                    throw new DMLRuntimeException("Incorrect dimensions");
        
-           double alpha = 1.0d;
-           double beta = 0.0d;
+           double[] alpha = {1.0d};
+           double[] beta = {0.0d};
        
            int lda = (int) (isLeftTransposed ? m : k);
            int ldc = m;
@@ -300,57 +327,416 @@ public class LibMatrixCUDA {
            if(!output.getGPUObject().isAllocated)
                    throw new DMLRuntimeException("Output is not allocated:" + 
output.getGPUObject().isAllocated);
        
-           Pointer A = ((JCudaObject)left.getGPUObject()).jcudaPointer;
-           Pointer C = ((JCudaObject)output.getGPUObject()).jcudaPointer;
+           Pointer A = ((JCudaObject)left.getGPUObject()).jcudaDenseMatrixPtr;
+           Pointer C = 
((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;
            
            //TODO: Fix it if there is a cuBLAS API to do flipping
-           JCublas.cublasDsyrk('U',transa, m, k, alpha, A, lda, beta, C, ldc);
-           JCublas.cublasDsyrk('L',transa, m, k, alpha, A, lda, beta, C, ldc);
+           
+           JCublas2.cublasDsyrk(cublasHandle, 
cublasFillMode.CUBLAS_FILL_MODE_UPPER,transa, m, k, Pointer.to(alpha), A, lda, 
Pointer.to(beta), C, ldc);
+           JCublas2.cublasDsyrk(cublasHandle, 
cublasFillMode.CUBLAS_FILL_MODE_LOWER,transa, m, k, Pointer.to(alpha), A, lda, 
Pointer.to(beta), C, ldc);
        }
        
-       public static void matmult(MatrixObject left1, MatrixObject right1, 
MatrixObject output, 
+       /**
+        * Matrix multiply on GPU
+        * Examines sparsity and shapes and routes call to appropriate method
+        * from cuBLAS or cuSparse
+        * C = op(A) x op(B)
+        * @param ec                                    Current {@link 
ExecutionContext} instance
+        * @param left1                                 Matrix A
+        * @param right1                                Matrix B
+        * @param outputName                    Name of the output matrix C (in 
code generated after LOP layer)
+        * @param isLeftTransposed1             op for A, transposed or not
+        * @param isRightTransposed1    op for B, transposed or not
+        * @return      output of matrix multiply
+        * @throws DMLRuntimeException if one of the inputs is not allocated on
+        *              the GPU (the routed-to helpers may throw it as well)
+        */
+       public static MatrixObject matmult(ExecutionContext ec, MatrixObject 
left1, MatrixObject right1, String outputName,
                        boolean isLeftTransposed1, boolean isRightTransposed1) 
throws DMLRuntimeException {
-               if(isInSparseFormat(left1) || isInSparseFormat(right1)) {
-                       throw new DMLRuntimeException("Sparse GPU matrix 
multiplication is not implemented");
+               
+               if(!left1.getGPUObject().isAllocated() || 
!right1.getGPUObject().isAllocated())
+                       throw new DMLRuntimeException("One of input is not 
allocated:" + left1.getGPUObject().isAllocated() + " " + 
right1.getGPUObject().isAllocated());
+               
+               boolean bothDense = !left1.getGPUObject().isInSparseFormat() && 
!right1.getGPUObject().isInSparseFormat();
+               boolean bothSparse = left1.getGPUObject().isInSparseFormat() && 
right1.getGPUObject().isInSparseFormat();
+               
+               MatrixObject output = ec.getMatrixObject(outputName);
+
+               if (bothDense) {                // Dense C = Dense A * Dense B
+                       // For both dense, do cuBLAS
+                       ec.getDenseMatrixOutputForGPUInstruction(outputName);   
// Allocated the dense output matrix
+                       denseDenseMatmult(output, left1, right1, 
isLeftTransposed1, isRightTransposed1);
+               }
+               else if (bothSparse){   // Sparse C = Sparse A * Sparse B
+                       ec.allocateGPUMatrixObject(outputName);
+                       bothSparseMatmult(output, left1, right1, 
isLeftTransposed1, isRightTransposed1);
+               }
+               else {  // Either of A or B is sparse, Sparse C = Sparse/Dense 
A * Dense/Sparse B
+                               // Convert the dense to sparse and use the 
cusparseDcsrgemm routine
+                       ec.allocateGPUMatrixObject(outputName);
+                       eitherSparseMatmult(output, left1, right1, 
isLeftTransposed1, isRightTransposed1);
                }
                
-               // Since CuBLAS expects inputs in column-major format,
-               // reverse the order of matrix-multiplication and take care of 
dimension mismatch.
-               MatrixObject left = right1; 
-               MatrixObject right = left1;
+               return output;
+       }
+       
+       /**
+        * One of the matrices is sparse, the other dense
+        * C = op(A) x op(B)
+        * Derives m, n, k from the operand shapes and the transpose flags and
+        * dispatches on whether the left operand is in sparse format.
+        * @param output                                allocated output object 
for C on host to which GPU output will be attached
+        * @param left                                  Matrix A on host
+        * @param right                                 Matrix B on host
+        * @param isLeftTransposed              op for A, transposed or not
+        * @param isRightTransposed             op for B, transposed or not
+        * @throws DMLRuntimeException if the inner dimensions of op(A) and
+        *              op(B) differ, or if m, n or k is -1
+        */
+       protected static void eitherSparseMatmult(MatrixObject output, 
MatrixObject left, MatrixObject right,
+                       boolean isLeftTransposed, boolean isRightTransposed) 
throws DMLRuntimeException {
+               
+               int transA = isLeftTransposed ? CUSPARSE_OPERATION_TRANSPOSE : 
CUSPARSE_OPERATION_NON_TRANSPOSE;
+               int transB = isRightTransposed ? CUSPARSE_OPERATION_TRANSPOSE : 
CUSPARSE_OPERATION_NON_TRANSPOSE;
+               
+               int m = (int) (isLeftTransposed ? left.getNumColumns() : 
left.getNumRows()) ;
+               int n = (int) (isRightTransposed ? right.getNumRows() : 
right.getNumColumns());
+               int k = (int) (isLeftTransposed ? left.getNumRows() :  
left.getNumColumns());
+               int k1 = (int) (isRightTransposed ? right.getNumColumns() : 
right.getNumRows());
+               if(k != k1) 
+                       throw new DMLRuntimeException("Dimension mismatch: " + 
k + " != " + k1);
+               
+               if(m == -1 || n == -1 || k == -1)
+                       throw new DMLRuntimeException("Incorrect dimensions");
+               
+               
+               if (left.getGPUObject().isInSparseFormat()) {   
+                       // Left sparse, right dense
+                       sparseDenseMatmult(output, left, right, 
isLeftTransposed, isRightTransposed, transA, transB, m, n, k);
+               } else {
+                       // Left dense, right sparse
+                       denseSparseMatmult(output, right, left, 
isLeftTransposed, isRightTransposed, transA, transB, m, n, k);
+               }
+       }
+       
+       /**
+        * C = op(A) * op(B) where A is dense and B is sparse
+        * If B is ultrasparse, A is converted to a sparse matrix and {@link 
#sparseSparseMatmult(MatrixObject, int, int, int, int, int, CSRPointer, 
CSRPointer)} is invoked
+        * otherwise B is converted to a dense matrix and {@link 
#denseDenseMatmult(Pointer, int, int, int, int, boolean, boolean, Pointer, 
Pointer)} is invoked.
+        * Note the parameter order: the sparse right operand comes before the
+        * dense left operand.
+        * @param output        output object for C
+        * @param right         Matrix B (sparse on the GPU)
+        * @param left          Matrix A (dense on the GPU)
+        * @param isLeftTransposed      op for A, transposed or not
+        * @param isRightTransposed     op for B, transposed or not
+        * @param transA        cuSPARSE operation constant for A
+        * @param transB        cuSPARSE operation constant for B
+        * @param m             number of rows in op(A)
+        * @param n             number of cols in op(B)
+        * @param k             number of cols in op(A) or rows in op(B)
+        * @throws DMLRuntimeException
+        */
+       protected static void denseSparseMatmult(MatrixObject output, 
MatrixObject right, MatrixObject left,
+                       boolean isLeftTransposed, boolean isRightTransposed, 
int transA, int transB, int m, int n, int k)
+                       throws DMLRuntimeException {
+               // right sparse, left dense
+               CSRPointer B = 
((JCudaObject)right.getGPUObject()).jcudaSparseMatrixPtr;
+               Pointer ADense = 
((JCudaObject)left.getGPUObject()).jcudaDenseMatrixPtr;
+               if (B.isUltraSparse(k, n)){
+                       LOG.debug(" GPU Dense-Sparse Matrix Multiplication 
(Converted to Sparse-Sparse)");
+                       // Convert left to CSR and do cuSparse matmul
+                       // A is described by its own stored shape; transA is
+                       // applied later by sparseSparseMatmult
+                       long t0 = System.nanoTime();
+                       CSRPointer A = 
JCudaObject.denseToSparse(cusparseHandle, (int)left.getNumRows(), 
(int)left.getNumColumns(), ADense);
+                       
Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0);
+                       Statistics.cudaConversionCount.addAndGet(1);
+                       sparseSparseMatmult(output, transA, transB, m, n, k, A, 
B);
+                       A.deallocate();
+               } else {
+                       LOG.debug(" GPU Dense-Sparse Matrix Multiplication 
(Converted to Dense-Dense)");
+                       // Convert right to dense and do a cuBlas matmul
+                       // BDenseTransposed is a column major matrix
+                       // Note the arguments to denseDenseMatmult to 
accommodate for this.
+                       Pointer BDenseTransposed = 
B.toDenseMatrix(cusparseHandle, cublasHandle, (int)right.getNumRows(), 
(int)right.getNumColumns());
+                       output.getGPUObject().acquireDeviceModifyDense();       
// To allocate the dense matrix
+                       Pointer C = 
((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;           
+                       denseDenseMatmult(C, 
+                                       (int) left.getNumRows(), (int) 
left.getNumColumns(),
+                                       (int) right.getNumColumns(), (int) 
right.getNumRows(), 
+                                       isLeftTransposed, !isRightTransposed,
+                                       ADense, BDenseTransposed);
+                       cudaFree(BDenseTransposed);
+               }
+       }
+
+       /**
+        * C = op(A) * op(B) where A is sparse and B is dense
+        * If n is 1, the call goes to the sparse matrix - dense vector multiply.
+        * If A is ultrasparse, B is converted to a sparse matrix and {@link 
#sparseSparseMatmult(MatrixObject, int, int, int, int, int, CSRPointer, 
CSRPointer)} is invoked
+        * otherwise A is converted to a dense matrix and {@link 
#denseDenseMatmult(Pointer, int, int, int, int, boolean, boolean, Pointer, 
Pointer)} is invoked.
+        * @param output        output object for C
+        * @param left          Matrix A (sparse on the GPU)
+        * @param right         Matrix B (dense on the GPU)
+        * @param isLeftTransposed      op for A, transposed or not
+        * @param isRightTransposed     op for B, transposed or not
+        * @param transA        cuSPARSE operation constant for A
+        * @param transB        cuSPARSE operation constant for B
+        * @param m             number of rows in op(A)
+        * @param n             number of cols in op(B)
+        * @param k             number of cols in op(A) or rows in op(B)
+        * @throws DMLRuntimeException
+        */
+       protected static void sparseDenseMatmult(MatrixObject output, 
MatrixObject left, MatrixObject right,
+                       boolean isLeftTransposed, boolean isRightTransposed, 
int transA, int transB, int m, int n, int k)
+                       throws DMLRuntimeException {
+               CSRPointer A = 
((JCudaObject)left.getGPUObject()).jcudaSparseMatrixPtr;
+               Pointer BDense = 
((JCudaObject)right.getGPUObject()).jcudaDenseMatrixPtr;
+               
+               if (n == 1){    
+                       // Sparse Matrix - Dense Vector multiply
+                       LOG.debug(" GPU Sparse Matrix - Dense Vector Mutliply");
+                       sparseMatrixDenseVectorMult(output, A, BDense, transA, 
(int)left.getNumRows(), (int)left.getNumColumns());
+                       
+               } else {
+                       // Sparse Matrix Dense Matrix multiply
+                       if (A.isUltraSparse(m, k)){     
+                               LOG.debug(" GPU Sparse-Dense Matrix 
Multiplication (Converted to Sparse-Sparse)");
+                               // Convert right to CSR and do cuSparse matmul
+                               long t0 = System.nanoTime();
+                               CSRPointer B = 
JCudaObject.denseToSparse(cusparseHandle, (int)right.getNumRows(), 
(int)right.getNumColumns(), BDense);
+                               
Statistics.cudaConversionTime.addAndGet(System.nanoTime() - t0);
+                               Statistics.cudaConversionCount.addAndGet(1);
+                               sparseSparseMatmult(output, transA, transB, m, 
n, k, A, B);
+                               B.deallocate();
+                       } else {                                        
+                               LOG.debug(" GPU Sparse-Dense Matrix 
Multiplication (Converted to Dense-Dense)");
+                               // Convert left to dense and do a cuBlas matmul
+                               // ADenseTransposed is a column major matrix
+                               // Note the arguments to denseDenseMatmult to 
accommodate for this.
+                               Pointer ADenseTransposed = 
A.toDenseMatrix(cusparseHandle, cublasHandle, (int)left.getNumRows(), 
(int)left.getNumColumns());
+                               
output.getGPUObject().acquireDeviceModifyDense();       // To allocate the 
dense matrix
+                               Pointer C = 
((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;           
+                               denseDenseMatmult(C, 
+                                               (int) left.getNumColumns(), 
(int) left.getNumRows(),
+                                               (int) right.getNumRows(), (int) 
right.getNumColumns(), 
+                                               !isLeftTransposed, 
isRightTransposed,
+                                               ADenseTransposed, BDense);
+                               cudaFree(ADenseTransposed);
+                       }
+               }
+       }
+
+       /**
+        * C = op(A) x B
+        * A is a sparse matrix, B is a dense vector
+        * The result buffer holds m doubles, or k doubles when A is transposed.
+        * It is attached to the output as its dense pointer and the output is
+        * marked as modified on the device.
+        * NOTE(review): the byte size is computed in int arithmetic and cast
+        * back to int for the allocation - confirm large m/k cannot overflow.
+        * @param output        allocated output on the host, to which the GPU 
output C will be attached
+        * @param A                     sparse matrix A on the GPU
+        * @param B_dense       dense matrix/vector B on the GPU
+        * @param transA        op for A, transposed or not
+        * @param m                     number of rows in A (not op(A))
+        * @param k                     number of cols in A or number of rows 
in B (not op(A) or op(B))
+        * @throws DMLRuntimeException
+        */
+       protected static void sparseMatrixDenseVectorMult(MatrixObject output, 
CSRPointer A, Pointer B_dense, int transA,
+                       int m, int k) throws DMLRuntimeException {
+               long size = m * Sizeof.DOUBLE;
+               if (transA == CUSPARSE_OPERATION_TRANSPOSE){
+                       size = k * Sizeof.DOUBLE;
+               }
+               Pointer C_dense = JCudaObject.allocate((int)size);
+               double[] alpha = { 1 };
+               double[] beta = { 0 };
+               cusparseDcsrmv(cusparseHandle, transA, m, k, (int)A.nnz, 
Pointer.to(alpha), A.descr, A.val, A.rowPtr, A.colInd, B_dense, 
Pointer.to(beta), C_dense);
+               
+               
((JCudaObject)(output.getGPUObject())).setDenseMatrixCudaPointer(C_dense);
+               output.getGPUObject().setDeviceModify(size);
+       }
+
+       /**
+        * Sparse C = Sparse op(A) * Sparse op(B)
+        * Reroutes call to sparse matrix-vector mult if needed, i.e. when B is
+        * not transposed and has exactly one column.
+        * @param output        output object for C
+        * @param left          Matrix A (sparse on the GPU)
+        * @param right         Matrix B (sparse on the GPU)
+        * @param isLeftTransposed      op for A, transposed or not
+        * @param isRightTransposed     op for B, transposed or not
+        * @throws DMLRuntimeException if the inner dimensions of op(A) and
+        *              op(B) differ, or if m, n or k is -1
+        */
+       protected static void bothSparseMatmult(MatrixObject output, 
MatrixObject left, MatrixObject right,
+                       boolean isLeftTransposed, boolean isRightTransposed) 
throws DMLRuntimeException {
+               
+               int transA = isLeftTransposed ? CUSPARSE_OPERATION_TRANSPOSE : 
CUSPARSE_OPERATION_NON_TRANSPOSE;
+               int transB = isRightTransposed ? CUSPARSE_OPERATION_TRANSPOSE : 
CUSPARSE_OPERATION_NON_TRANSPOSE;
+               
+               int m = (int) (isLeftTransposed ? left.getNumColumns() : 
left.getNumRows()) ;
+               int n = (int) (isRightTransposed ? right.getNumRows() : 
right.getNumColumns());
+               int k = (int) (isLeftTransposed ? left.getNumRows() :  
left.getNumColumns());
+               int k1 = (int) (isRightTransposed ? right.getNumColumns() : 
right.getNumRows());
+               if(k != k1) 
+                       throw new DMLRuntimeException("Dimension mismatch: " + 
k + " != " + k1);
+               
+               if(m == -1 || n == -1 || k == -1)
+                       throw new DMLRuntimeException("Incorrect dimensions");
+                       
+               CSRPointer A = 
((JCudaObject)left.getGPUObject()).jcudaSparseMatrixPtr;
+               CSRPointer B = 
((JCudaObject)right.getGPUObject()).jcudaSparseMatrixPtr;
+               
+               // TODO if (m == 1) {   // Vector-matrix multiplication
+               
+               if (!isRightTransposed && right.getNumColumns() == 1){  // 
Matrix-Vector multiplication
+                       sparseMatrixVectorMult(output, transA, 
(int)left.getNumRows(), (int)left.getNumColumns(), (int)right.getNumRows(), A, 
B);
+               } else {                                                        
                                        // Matrix-Matrix multiplication
+                       sparseSparseMatmult(output, transA, transB, m, n, k, A, 
B);
+               }
+       }
+
+       /**
+        * Does a sparse matrix-vector multiply.
+        * C = op(A) x B, A is a sparse matrix, B is a sparse vector with 
numCols = 1.
+        * B is first converted to a dense vector; that temporary is freed once
+        * the multiply has been issued.
+        * @param output        allocated output object C to which the GPU 
output matrix will be attached
+        * @param transA        if A is to be transposed or not (the op in 
op(A))
+        * @param m                     number of rows in A (not op(A))
+        * @param n                     number of cols in A (not op(A))
+        * @param k                     number of rows in B, (cols in B is 
assumed to be 1)             
+        * @param A                     left sparse matrix on GPU
+        * @param B                     right sparse vector on GPU
+        * @throws DMLRuntimeException
+        */
+       protected static void sparseMatrixVectorMult(MatrixObject output, int 
transA, int m, int n, int k,
+                       CSRPointer A, CSRPointer B) throws DMLRuntimeException {
+               LOG.debug(" GPU Sparse Matrix Sparse Vector Multiply (Converted 
to Sparse Matrix Dense Vector Multiply)");
+               Pointer BDenseVector = B.toDenseMatrix(cusparseHandle, 
cublasHandle, k, 1);
+               try {
+                       // The callee wants the stored shape of A (rows, cols).
+                       // n equals k when A is not transposed, but when A is
+                       // transposed k is the row count of A, so n is needed.
+                       sparseMatrixDenseVectorMult(output, A, BDenseVector, 
transA, m, n);
+               } finally {
+                       // The dense copy of B is a temporary; the output gets
+                       // its own buffer, so release it like the other
+                       // toDenseMatrix call sites do
+                       cudaFree(BDenseVector);
+               }
+       }
+
+       /**
+        * Does a sparse-sparse Matrix multiply
+        * C = op(A) x op(B), A, B are sparse matrices
+        * C is obtained from CSRPointer.allocateForMatrixMultiply, attached to
+        * the output and marked as modified on the device before
+        * cusparseDcsrgemm is called with C's val/rowPtr/colInd pointers.
+        * @param output        allocated output object on host to which the 
GPU output matrix will be attached
+        * @param transA        op for A - to be transposed or not
+        * @param transB        op for B
+        * @param m                     number of rows in op(A)
+        * @param n                     number of cols in op(B)
+        * @param k                     number of cols in op(A) or rows in op(B)
+        * @param A                     left sparse matrix on GPU
+        * @param B                     right sparse matrix on GPU
+        * @throws DMLRuntimeException
+        */
+       protected static void sparseSparseMatmult(MatrixObject output, int 
transA, int transB, int m, int n, int k,
+                       CSRPointer A, CSRPointer B) throws DMLRuntimeException {
+               LOG.debug(" GPU Sparse-Sparse Matrix Multiply ");
+
+               CSRPointer C = 
CSRPointer.allocateForMatrixMultiply(cusparseHandle, A, transA, B, transB, m, 
n, k);
+               
((JCudaObject)output.getGPUObject()).setSparseMatrixCudaPointer(C);
+               long sizeOfC = CSRPointer.estimateSize(C.nnz, 
output.getNumRows());
+               output.getGPUObject().setDeviceModify(sizeOfC);
+               
+               cusparseDcsrgemm(cusparseHandle, transA, transB, m, n, k,
+                               A.descr, (int)A.nnz, A.val, A.rowPtr, A.colInd,
+                               B.descr, (int)B.nnz, B.val, B.rowPtr, B.colInd,
+                               C.descr, C.val, C.rowPtr, C.colInd);
+       }
+
+       /**
+        * Dense-dense matrix multiply
+        * C = op(A) * op(B), A and B are dense matrices.
+        * This overload only unpacks its arguments: it reads the dense device
+        * pointers (jcudaDenseMatrixPtr) of left1, right1 and output, takes the
+        * row and column counts of both inputs (narrowed to int), and delegates
+        * to the Pointer-based denseDenseMatmult overload.
+        * @param output                output object C on host with GPU data
+        *                              allocated
+        * @param left1                 left matrix A on host (in row-major
+        *                              order)
+        * @param right1                right matrix B on host (in row-major
+        *                              order)
+        * @param isLeftTransposed1     op for A, transposed or not
+        * @param isRightTransposed1    op for B, transposed or not
+        * @throws DMLRuntimeException  propagated from the Pointer-based
+        *                              overload this method delegates to
+        */
+       protected static void denseDenseMatmult(MatrixObject output, 
MatrixObject left1, MatrixObject right1,
+                       boolean isLeftTransposed1, boolean isRightTransposed1) 
throws DMLRuntimeException {
+               
+               Pointer leftPtr = 
((JCudaObject)left1.getGPUObject()).jcudaDenseMatrixPtr;
+               Pointer rightPtr = 
((JCudaObject)right1.getGPUObject()).jcudaDenseMatrixPtr;
+               
+               int leftRows = (int) left1.getNumRows();
+               int leftCols = (int) left1.getNumColumns();
+               int rightRows = (int) right1.getNumRows();
+               int rightCols = (int) right1.getNumColumns();
+               Pointer C = 
((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr;           
+               denseDenseMatmult(C, leftRows, leftCols, rightRows, rightCols, 
isLeftTransposed1, isRightTransposed1,
+                               leftPtr, rightPtr);
+       }
+
+       /**
+        * Dense-dense matrix multiply
+        * C = op(A) * op(B), A and B are dense matrices
+        * On the host, the matrices are in row-major format; cuBLAS expects 
them in column-major format.
+        * What we have as input is t(A) and t(B), t(X) = transpose of X.
+        * We do t(B) %*% t(A) to get t(C); 
+        * If we were to calculate t(t(C)), we would get the resultant matrix C, 
but this would be in column-major format.
+        * What we really want is t(C). This we already have as the result of 
t(B) %*% t(A).
+        * @param output                        output allocated on GPU in 
column major format
+        * @param leftRows1                     number of rows in A
+        * @param leftCols1                     number of cols in A
+        * @param rightRows1            number of rows in B
+        * @param rightCols1            number of cols in B
+        * @param isLeftTransposed1             op for A, transposed or not
+        * @param isRightTransposed1    op for B, transposed or not
+        * @param leftPtr                       A allocated on the GPU in 
row-major format
+        * @param rightPtr                      B allocated on the GPU in 
row-major format
+        * @throws DMLRuntimeException
+        */
+       public static void denseDenseMatmult(Pointer output, int leftRows1, int 
leftCols1, int rightRows1,
+                       int rightCols1, boolean isLeftTransposed1, boolean 
isRightTransposed1, Pointer leftPtr, Pointer rightPtr)
+                       throws DMLRuntimeException {
+               
+               Pointer A = rightPtr;
+               Pointer B = leftPtr;
+               
+               int leftRows = rightCols1;
+               int leftCols = rightRows1;
+               int rightRows = leftCols1;
+               int rightCols = leftRows1;
+               
                boolean isLeftTransposed = isRightTransposed1; 
                boolean isRightTransposed = isLeftTransposed1; 
                
-               char transa = isLeftTransposed ? 'T' : 'N';
-               char transb = isRightTransposed ? 'T' : 'N';
                // Note: the dimensions are swapped
-               int m = (int) (isLeftTransposed ? left.getNumRows() : 
left.getNumColumns()) ;
-               int n = (int) (isRightTransposed ? right.getNumColumns() : 
right.getNumRows());
-               int k = (int) (isLeftTransposed ?  left.getNumColumns() : 
left.getNumRows());
-               int k1 = (int) (isRightTransposed ?  right.getNumRows() : 
right.getNumColumns());
+               int m = (int) (isLeftTransposed ? leftCols : leftRows) ;
+               int n = (int) (isRightTransposed ? rightRows : rightCols);
+               int k = (int) (isLeftTransposed ?  leftRows : leftCols);
+               int k1 = (int) (isRightTransposed ?  rightCols : rightRows);
                if(k != k1) 
                        throw new DMLRuntimeException("Dimension mismatch: " + 
k + " != " + k1);
                
                if(m == -1 || n == -1 || k == -1)
                        throw new DMLRuntimeException("Incorrect dimensions");
                
-               double alpha = 1;
-               double beta = 0;
+               double[] one = { 1 };
+               double[] zero = { 0 };
                
-               int lda = isLeftTransposed ?  k : m;
-               int ldb = isRightTransposed ? n : k;
+               //int lda = leftRows;
+               //int ldb = leftCols;
+        int lda = isLeftTransposed ?  k : m;
+        int ldb = isRightTransposed ? n : k;
                int ldc = m;
                
-               if(!left.getGPUObject().isAllocated() || 
!right.getGPUObject().isAllocated())
-                       throw new DMLRuntimeException("One of input is not 
allocated:" + left.getGPUObject().isAllocated() + " " + 
right.getGPUObject().isAllocated());
-               if(!output.getGPUObject().isAllocated())
-                       throw new DMLRuntimeException("Output is not 
allocated:" + output.getGPUObject().isAllocated());
-               
-               Pointer A = ((JCudaObject)left.getGPUObject()).jcudaPointer;
-               Pointer B = ((JCudaObject)right.getGPUObject()).jcudaPointer;
-               Pointer C = ((JCudaObject)output.getGPUObject()).jcudaPointer;
-               
-               JCublas.cublasDgemm(transa, transb, m, n, k, alpha, A, lda, B, 
ldb, beta, C, ldc);
+               int transa = isLeftTransposed ? cublasOperation.CUBLAS_OP_T : 
cublasOperation.CUBLAS_OP_N;
+               int transb = isRightTransposed ? cublasOperation.CUBLAS_OP_T : 
cublasOperation.CUBLAS_OP_N;
+
+               Pointer C = output;
+               if (m == 1 && n == 1){ 
+                       // Vector product
+                       LOG.debug(" GPU Dense-dense Vector Product");
+                       double[] result = {0};
+                       JCublas2.cublasDdot(cublasHandle, k, A, 1, B, 1, 
Pointer.to(result));
+                       // By default in CuBlas V2, cublas pointer mode is set 
to CUBLAS_POINTER_MODE_HOST.
+                       // This means that scalar values passed are on host (as 
opposed to on device).
+                       // The result is copied from the host back to the 
device so that the rest of 
+                       // infrastructure can treat it uniformly.
+                       cudaMemcpy(C, Pointer.to(result), 1 * Sizeof.DOUBLE, 
cudaMemcpyHostToDevice);
+               } else if (m == 1) {
+                       // Vector-matrix multiply
+                       LOG.debug(" GPU Dense Vector-Matrix Multiply");
+                       transb = isRightTransposed ? 
cublasOperation.CUBLAS_OP_N : cublasOperation.CUBLAS_OP_T;
+                       JCublas2.cublasDgemv(cublasHandle, transb, rightRows, 
rightCols, Pointer.to(one), B, ldb, A, 1, Pointer.to(zero), C, 1);
+               } else if (n == 1){
+                       // Matrix-vector multiply
+                       LOG.debug(" GPU Dense Matrix-Vector Multiply");
+                       JCublas2.cublasDgemv(cublasHandle, transa, leftRows, 
leftCols, Pointer.to(one), A, lda, B, 1, Pointer.to(zero), C, 1);
+               } else {
+                       LOG.debug(" GPU Dense-Dense Matrix Multiply ");
+                       JCublas2.cublasDgemm(cublasHandle, transa, transb, m, 
n, k, Pointer.to(one), A, lda, B, ldb, Pointer.to(zero), C, ldc);
+               }
        }
 
        public static void conv2d_backward_data(MatrixObject filter, 
MatrixObject dout,
@@ -373,10 +759,10 @@ public class LibMatrixCUDA {
                        dxDesc = allocateTensorDescriptor(N, C, H, W);
                        
                        // Allocate data
-                       Pointer w = 
((JCudaObject)filter.getGPUObject()).jcudaPointer; 
-                       Pointer dy = 
((JCudaObject)dout.getGPUObject()).jcudaPointer; 
-                       Pointer dx = 
((JCudaObject)output.getGPUObject()).jcudaPointer; 
-                       
+                       Pointer w = 
((JCudaObject)filter.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer dy = 
((JCudaObject)dout.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer dx = 
((JCudaObject)output.getGPUObject()).jcudaDenseMatrixPtr; 
+
                        alpha = pointerTo(1.0); // TODO
                        beta = pointerTo(0.0f);
                        
@@ -453,8 +839,8 @@ public class LibMatrixCUDA {
                        poolingDesc = allocatePoolingDescriptor(R, S, pad_h, 
pad_w, stride_h, stride_w);
                        
                        // Allocate data
-                       Pointer x = 
((JCudaObject)image.getGPUObject()).jcudaPointer; 
-                       Pointer y = 
((JCudaObject)outputBlock.getGPUObject()).jcudaPointer; 
+                       Pointer x = 
((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr; 
+                       Pointer y = 
((JCudaObject)outputBlock.getGPUObject()).jcudaDenseMatrixPtr; 
                        
                        alpha = pointerTo(1.0);
                        beta = pointerTo(0.0f);
@@ -527,9 +913,9 @@ public class LibMatrixCUDA {
                        cudaMalloc(y, numBytes);
                        
                        // Allocate data
-                       Pointer x = 
((JCudaObject)image.getGPUObject()).jcudaPointer; 
-                       Pointer dx = 
((JCudaObject)outputBlock.getGPUObject()).jcudaPointer;
-                       Pointer dy = 
((JCudaObject)dout.getGPUObject()).jcudaPointer;
+                       Pointer x = 
((JCudaObject)image.getGPUObject()).jcudaDenseMatrixPtr;
+                       Pointer dx = 
((JCudaObject)outputBlock.getGPUObject()).jcudaDenseMatrixPtr;
+                       Pointer dy = 
((JCudaObject)dout.getGPUObject()).jcudaDenseMatrixPtr;
                        
                        alpha = pointerTo(1.0);
                        beta = pointerTo(0.0f);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index ccd28af..9d233fc 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -95,7 +95,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
        public static final double SPARSITY_TURN_POINT = 0.4;
        //sparsity threshold for ultra-sparse matrix operations (40nnz in a 
1kx1k block)
        public static final double ULTRA_SPARSITY_TURN_POINT = 0.00004; 
-       //default sparse block type: modified compressed sparse rows 
+       //default sparse block type: modified compressed sparse rows; CSR would 
be best suited for use with cuSparse 
        public static final SparseBlock.Type DEFAULT_SPARSEBLOCK = 
SparseBlock.Type.MCSR;
        //default sparse block type for update in place: compressed sparse rows 
to prevent serialization
        public static final SparseBlock.Type DEFAULT_INPLACE_SPARSEBLOCK = 
SparseBlock.Type.CSR;
@@ -153,6 +153,21 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                copy(that);
        }
        
+       /**
+        * Constructs a sparse {@link MatrixBlock} with a given instance of a
+        * {@link SparseBlock}. The block is marked as sparse and the given
+        * sparse block is stored by reference (it is not copied here).
+        * @param rows  number of rows
+        * @param cols  number of columns
+        * @param nnz   number of non zeroes
+        * @param sparseBlock   sparse block instance to use as the underlying
+        *                      data of this matrix block
+        */
+       public MatrixBlock(int rows, int cols, long nnz, SparseBlock 
sparseBlock) {
+               this.rlen = rows;
+               this.clen = cols;
+               this.nonZeros = nnz;
+               this.sparse = true;
+               this.sparseBlock = sparseBlock;
+       }
+       
        public MatrixBlock(MatrixBlock that, SparseBlock.Type stype, boolean 
deep) {
                this(that.rlen, that.clen, that.sparse);
                

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
index 19407e1..670da9b 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
@@ -648,4 +648,31 @@ public class SparseBlockCOO extends SparseBlock
                        throw new RuntimeException("SparseBlockCOOIterator is 
unsupported!");                   
                }               
        }
+       
+       /**
+        * Get raw access to underlying array of row indices
+        * For use in GPU code
+        * @return the internal row index array itself (not a copy)
+        */
+       public int[] rowIndexes() {
+               return _rindexes;
+       }
+       
+       /** 
+        * Get raw access to underlying array of column indices
+        * For use in GPU code
+        * @return the internal column index array itself (not a copy)
+        */
+       public int[] indexes() {
+               return _cindexes;
+       }
+       
+       /**
+        * Get raw access to underlying array of values
+        * For use in GPU code
+        * @return the internal value array itself (not a copy)
+        */
+       public double[] values() {
+               return _values;
+       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
index b6bcb5f..ab4e665 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
@@ -60,6 +60,13 @@ public class SparseBlockCSR extends SparseBlock
                _size = 0;
        }
        
+       public SparseBlockCSR(int[] rowPtr, int[] colInd, double[] values, int 
nnz){ // wraps already-built CSR arrays; the arrays are stored by reference, not copied
+               _ptr = rowPtr; // row pointers
+               _indexes = colInd; // column indices
+               _values = values; // non-zero values
+               _size = nnz; // number of non-zeroes as given by the caller
+       }
+       
        /**
         * Copy constructor sparse block abstraction. 
         */
@@ -105,6 +112,8 @@ public class SparseBlockCSR extends SparseBlock
        
        /**
         * Copy constructor old sparse row representation. 
+        * @param rows
+        * @param nnz number of non-zeroes
         */
        public SparseBlockCSR(SparseRow[] rows, int nnz)
        {
@@ -119,16 +128,52 @@ public class SparseBlockCSR extends SparseBlock
                        int alen = rows[i].size();
                        int[] aix = rows[i].indexes();
                        double[] avals = rows[i].values();
-                       for( int j=0; j<alen; j++ ) {
-                               _indexes[pos] = aix[j];
-                               _values[pos] = avals[j];
-                               pos++;
-                       }
+                       System.arraycopy(aix, 0, _indexes, pos, alen);
+                       System.arraycopy(avals, 0, _values, pos, alen);
+                       pos += alen;
                        _ptr[i+1]=pos;  
                }
        }
        
        /**
+        * Copy constructor for COO representation.
+        * The column indices and values are copied as-is, so the COO entries
+        * are expected to be ordered by row index; the row pointers are then
+        * derived by counting the entries of every row and accumulating the
+        * counts. Rows without entries get identical start and end pointers.
+        * @param rows          number of rows of the matrix
+        * @param rowInd        row indices
+        * @param colInd        column indices
+        * @param values        non zero values
+        */
+       public SparseBlockCSR(int rows, int[] rowInd, int[] colInd, double[] 
values){
+               int nnz = values.length;
+               _ptr = new int[rows+1];
+               _indexes = Arrays.copyOf(colInd, colInd.length);
+               _values = Arrays.copyOf(values, values.length);
+               _size = nnz;
+               
+               // Count the entries of each row r into _ptr[r+1]
+               // Input Example -> rowInd = [0,0,1,1,2,2,2,4,4,5], rows = 6
+               // Output Example -> _ptr = [0|2|2|3|0|2|1]
+               for (int i=0; i<nnz; i++){
+                       _ptr[rowInd[i]+1]++;
+               }
+               
+               // Accumulate the counts into row start offsets
+               // Output example -> _ptr = [0|2|4|7|7|9|nnz]
+               // (row 3 is empty, hence _ptr[3] == _ptr[4])
+               for (int i=0; i<rows; i++){
+                       _ptr[i+1] += _ptr[i];
+               }
+       }
+       
+       /**
         * Get the estimated in-memory size of the sparse block in CSR 
         * with the given dimensions w/o accounting for overallocation. 
         * 
@@ -646,4 +691,31 @@ public class SparseBlockCSR extends SparseBlock
                for( int i=rl; i<rlen+1; i++ )
                        _ptr[i]-=cnt;
        }
+       
+       /**
+        * Get raw access to underlying array of row pointers
+        * For use in GPU code
+        * @return the internal row pointer array itself (not a copy)
+        */
+       public int[] rowPointers() {
+               return _ptr;
+       }
+       
+       /** 
+        * Get raw access to underlying array of column indices
+        * For use in GPU code
+        * @return the internal column index array itself (not a copy)
+        */
+       public int[] indexes() {
+               return _indexes;
+       }
+       
+       /**
+        * Get raw access to underlying array of values
+        * For use in GPU code
+        * @return the internal value array itself (not a copy)
+        */
+       public double[] values() {
+               return _values;
+       }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/58a95460/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockMCSR.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockMCSR.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockMCSR.java
index 1c9f13d..56ef982 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockMCSR.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockMCSR.java
@@ -319,4 +319,12 @@ public class SparseBlockMCSR extends SparseBlock
                
                return sb.toString();
        }
+       
+       /**
+        * Helper function for MCSR -> {COO, CSR}
+        * Gives direct access to the internal row array; no copy is made.
+        * @return the underlying array of {@link SparseRow}
+        */
+       public SparseRow[] getRows() {
+               return _rows;
+       }
 }


Reply via email to