Repository: systemml
Updated Branches:
  refs/heads/master 3acd94186 -> e624d149f


[SYSTEMML-540] Improve the performance of CP Convolution operators

- Support sparse bias_multiply.
- Allow the JIT to optimize loops in the maxpooling operations.
- Call examSparsity after each operator so that downstream sparse-enabled operations can exploit sparsity.
- Add script generation logic to Caffe2DML.

Closes #661.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e624d149
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e624d149
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e624d149

Branch: refs/heads/master
Commit: e624d149f7826fb0cd98bb3f32ada423acfaac66
Parents: 3acd941
Author: Niketan Pansare <[email protected]>
Authored: Wed Sep 13 11:08:20 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Wed Sep 13 12:08:20 2017 -0700

----------------------------------------------------------------------
 .../nn/test/compare_backends/test_conv2d.dml    |   3 +-
 scripts/nn/test/compare_backends/test_conv2d.sh |  22 +-
 .../compare_backends/test_conv2d_bwd_data.sh    |  22 +-
 .../compare_backends/test_conv2d_bwd_filter.sh  |  22 +-
 .../nn/test/compare_backends/test_maxpool.sh    |  22 +-
 .../org/apache/sysml/hops/ConvolutionOp.java    |  13 +-
 .../cp/ConvolutionCPInstruction.java            |   4 +-
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  68 +++--
 .../matrix/data/LibMatrixDNNConv2dHelper.java   |   6 +-
 .../runtime/matrix/data/LibMatrixDNNHelper.java |  27 +-
 .../data/LibMatrixDNNPoolingBackwardHelper.java |   5 +-
 .../matrix/data/LibMatrixDNNPoolingHelper.java  |  93 ++++---
 .../data/LibMatrixDNNRotate180Helper.java       |   3 +
 .../org/apache/sysml/api/dl/Caffe2DML.scala     | 258 ++++++++++---------
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  11 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   |   5 +
 src/test/scripts/functions/tensor/PoolTest.R    |   2 +-
 src/test/scripts/functions/tensor/PoolTest.dml  |   2 +-
 18 files changed, 304 insertions(+), 284 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d.dml 
b/scripts/nn/test/compare_backends/test_conv2d.dml
index b56a0ae..ea3bea2 100644
--- a/scripts/nn/test/compare_backends/test_conv2d.dml
+++ b/scripts/nn/test/compare_backends/test_conv2d.dml
@@ -19,7 +19,8 @@
 #
 #-------------------------------------------------------------
 
+fmt = ifdef($fmt, 'csv')
 X = read("input.mtx")
 w = read("filter.mtx")
 out = conv2d(X, w, input_shape=[$N,$C,$H,$W], filter_shape=[$F, $C, $Hf, $Wf], 
stride=[$stride,$stride], padding=[$pad,$pad])
-write(out, $out, format="csv")
+write(out, $out, format=fmt)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d.sh 
b/scripts/nn/test/compare_backends/test_conv2d.sh
index 4c578a6..205339c 100644
--- a/scripts/nn/test/compare_backends/test_conv2d.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh 
b/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
index da7b4f3..716560f 100644
--- a/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d_bwd_data.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh 
b/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
index f19fc73..99d3011 100644
--- a/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
+++ b/scripts/nn/test/compare_backends/test_conv2d_bwd_filter.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 # F = number of filters, Hf = filter height, Wf = filter width

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/scripts/nn/test/compare_backends/test_maxpool.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/test/compare_backends/test_maxpool.sh 
b/scripts/nn/test/compare_backends/test_maxpool.sh
index e8575e3..9d7da4a 100644
--- a/scripts/nn/test/compare_backends/test_maxpool.sh
+++ b/scripts/nn/test/compare_backends/test_maxpool.sh
@@ -20,27 +20,7 @@
 #
 #-------------------------------------------------------------
 
-jars='.'
-os_suffix='linux-x86_64'
-version='0.8.0'
-
-# Downloads the jcuda jars
-for lib in jcuda jcublas jcufft jcusparse jcusolver jcurand jnvgraph jcudnn
-do
-        file=$lib'-'$version'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-
-        file=$lib'-natives-'$version'-'$os_suffix'.jar'
-        if [ ! -f $file ]; then
-                
url='https://search.maven.org/remotecontent?filepath=org/jcuda/'$lib'-natives/'$version'/'$file
-                wget -O $file $url
-        fi
-        jars=$jars','$file
-done
+jars='systemml-*-extra.jar'
 
 # N = Number of images, C = number of channels, H = height, W = width
 N=5

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index 59ac29e..0ad9182 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -219,8 +219,17 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
        @Override
        protected double computeOutputMemEstimate( long dim1, long dim2, long 
nnz )
        {               
-               double sparsity = 1.0;
-               return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 
sparsity);
+               if(getOp() == ConvOp.BIAS_MULTIPLY) {
+                       // in non-GPU mode, the worst-case size of the bias_multiply output is the same as that of the input.
+                       if(DMLScript.USE_ACCELERATOR) 
+                               return 
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
+                       else
+                               return 
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 
getInput().get(0).getSparsity());
+               }
+               else {
+                       double sparsity = 1.0;
+                       return OptimizerUtils.estimateSizeExactSparsity(dim1, 
dim2, sparsity);
+               }
        }
        
        // ---------------------------------------------------------------

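In isolation, the new estimate amounts to choosing the sparsity argument per mode; a minimal sketch with hypothetical dimensions, using the same OptimizerUtils call the hop makes:

    import org.apache.sysml.hops.OptimizerUtils;

    public class BiasMultiplyEstimateSketch {
      public static void main(String[] args) {
        long rows = 64, cols = 50176;      // hypothetical N x K*P*Q output
        double inputSparsity = 0.1;        // hypothetical sparsity of the first input
        // GPU mode: assume a fully dense output (worst case)
        double gpuEstimate = OptimizerUtils.estimateSizeExactSparsity(rows, cols, 1.0);
        // CP mode: the bias_multiply output is at most as dense as its input
        double cpEstimate  = OptimizerUtils.estimateSizeExactSparsity(rows, cols, inputSparsity);
        System.out.println(gpuEstimate + " bytes (GPU) vs " + cpEstimate + " bytes (CP)");
      }
    }
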
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index e91029e..df72f24 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -288,8 +288,8 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                }
                else {
                        // As we always fill the output first with bias
-                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), false);
-                       outputBlock.allocateDenseBlock();
+                       outputBlock = new MatrixBlock(input.getNumRows(), 
input.getNumColumns(), input.isInSparseFormat());
+                       outputBlock.allocateDenseOrSparseBlock();
                        LibMatrixDNN.biasMultiply(input, bias, outputBlock, 
_numThreads);
                }
                

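The allocation pattern above, shown standalone: the output block now mirrors the input's format, so a sparse input no longer forces a dense allocation. A minimal sketch against the MatrixBlock API used in this commit:

    import org.apache.sysml.runtime.matrix.data.MatrixBlock;

    // Allocate an output block in the same format (dense or sparse) as the input.
    static MatrixBlock allocateLikeInput(MatrixBlock input) {
      MatrixBlock out = new MatrixBlock(input.getNumRows(), input.getNumColumns(),
          input.isInSparseFormat());
      out.allocateDenseOrSparseBlock();  // allocates whichever block type the flag selects
      return out;
    }
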
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 30b8b64..40192de 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -167,7 +167,8 @@ public class LibMatrixDNN {
                execute(LibMatrixDNNHelper.getConv2dWorkers(params), params);
                
                //post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        /**
@@ -188,7 +189,8 @@ public class LibMatrixDNN {
                
execute(LibMatrixDNNHelper.getConv2dBackwardDataWorkers(params), params);
                
                //post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        /**
@@ -209,7 +211,8 @@ public class LibMatrixDNN {
                
execute(LibMatrixDNNHelper.getConv2dBackwardFilterWorkers(params), params);
                
                //post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        
@@ -338,7 +341,8 @@ public class LibMatrixDNN {
                execute(LibMatrixDNNHelper.getMaxPoolingBackwardWorkers(params, 
performReluBackward), params);
                
                //post-processing: maintain nnz 
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        /**
@@ -391,7 +395,8 @@ public class LibMatrixDNN {
                execute(LibMatrixDNNHelper.getReluBackwardWorkers(params), 
params);
                
                // post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        /**
@@ -429,15 +434,17 @@ public class LibMatrixDNN {
                        double [] biasArr = bias.getDenseBlock();
                        for(int n = 0; n < N; n++) {
                                for(int k = 0; k < K; k++) {
+                                       double biasVal = biasArr[k];
                                        for(int pq = 0; pq < PQ; pq++, index++) 
{
-                                               outputArray[index] += 
biasArr[k];
+                                               outputArray[index] += biasVal;
                                        }
                                }
                        }
                }
                
                //post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        
@@ -469,22 +476,52 @@ public class LibMatrixDNN {
                
                if(!input.isEmptyBlock() && !bias.isEmptyBlock()) {
                        // Handles both dense and sparse inputs and copies it 
to dense output
-                       outputBlock.copy(input); 
-                       double [] outputArray = outputBlock.getDenseBlock();
-                       int index = 0;
+                       outputBlock.copy(input);
                        if(bias.isInSparseFormat())
                                bias.sparseToDense(); // Since bias is 
extremely small array
                        double [] biasArr = bias.getDenseBlock();
-                       for(int n = 0; n < N; n++) {
+                       if(!input.isInSparseFormat()) {
+                               double [] outputArray = 
outputBlock.getDenseBlock();
+                               int index = 0;
+                               for(int n = 0; n < N; n++) {
+                                       for(int k = 0; k < K; k++) {
+                                               double biasVal = biasArr[k];
+                                               for(int pq = 0; pq < PQ; pq++, 
index++) {
+                                                       outputArray[index] *= 
biasVal;
+                                               }
+                                       }
+                               }
+                       }
+                       else {
+                               // First delete those elements which will 
become zero 
                                for(int k = 0; k < K; k++) {
-                                       for(int pq = 0; pq < PQ; pq++, index++) 
{
-                                               outputArray[index] *= 
biasArr[k];
+                                       if(biasArr[k] == 0) {
+                                               for(int n = 0; n < N; n++) {
+                                                       
outputBlock.sparseBlock.deleteIndexRange(n, k*PQ, (k+1)*PQ);
+                                               }
+                                       }
+                               }
+                               // Then perform bias_multiply for non-zero bias 
entries
+                               for(int n = 0; n < N; n++) {
+                                       if( !outputBlock.sparseBlock.isEmpty(n) 
) {
+                                               int apos = 
outputBlock.sparseBlock.pos(n);
+                                               int alen = 
outputBlock.sparseBlock.size(n);
+                                               int[] aix = 
outputBlock.sparseBlock.indexes(n);
+                                               double[] avals = 
outputBlock.sparseBlock.values(n);
+                                               
+                                               for(int j=apos; j<apos+alen; 
j++) {
+                                                       // aix[j] indexes the flattened K*PQ columns, so the channel is aix[j]/PQ
+                                                       int k = aix[j] / PQ;
+                                                       if(biasArr[k] != 0)
+                                                               avals[j] *= 
biasArr[k];
+                                               }
                                        }
                                }
                        }
                        
                        //post-processing: maintain nnz
-                       params.output.recomputeNonZeros();
+                       params.output.recomputeNonZeros(); 
+                       params.output.examSparsity();
                }
                else {
                        params.output.setNonZeros(0);
@@ -504,7 +541,8 @@ public class LibMatrixDNN {
                execute(LibMatrixDNNHelper.getMaxPoolingWorkers(params), 
params);
                
                // post-processing: maintain nnz
-               outputBlock.recomputeNonZeros();
+               outputBlock.recomputeNonZeros(); 
+               outputBlock.examSparsity();
        }
        
        /**

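Two patterns recur in the hunks above. First, every operator now finishes with recomputeNonZeros() followed by examSparsity(), which lets the block flip to the cheaper dense or sparse representation before downstream operations see it. Second, the sparse bias_multiply decodes each non-zero's column index back into a channel: columns of the N x (K*PQ) input are laid out as k*PQ + pq, so the channel is the quotient by PQ. A minimal standalone sketch of that decoding (K, PQ, and the indexes are hypothetical):

    public class ChannelIndexSketch {
      public static void main(String[] args) {
        final int PQ = 9;                          // hypothetical P*Q per channel
        int[] colIndexes = {0, 10, 26, 35};        // non-zero columns of one sparse row
        for (int col : colIndexes) {
          int k  = col / PQ;                       // channel: selects biasArr[k]
          int pq = col % PQ;                       // position inside the channel's block
          System.out.println("col=" + col + " -> k=" + k + ", pq=" + pq);
        }
      }
    }
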
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
index 4c3a3c3..66b2ed1 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNConv2dHelper.java
@@ -96,9 +96,10 @@ public class LibMatrixDNNConv2dHelper {
                                                        int alen = 
src.sparseBlock.size(k);
                                                        int[] aix = 
src.sparseBlock.indexes(k);
                                                        double[] avals = 
src.sparseBlock.values(k);
+                                                       int destPosK = destPos + k*PQ;
                                                        for(int j = apos; j < 
apos+alen; j++) {
                                                                int pqIndex = 
aix[j];
-                                                               dest[destPos + 
k*PQ + pqIndex ] += avals[j];
+                                                               dest[destPosK + pqIndex] += avals[j];
                                                        }
                                                }
                                        }
@@ -174,9 +175,10 @@ public class LibMatrixDNNConv2dHelper {
                                                        int alen = 
src.sparseBlock.size(k);
                                                        int[] aix = 
src.sparseBlock.indexes(k);
                                                        double[] avals = 
src.sparseBlock.values(k);
+                                                       int destPosK = destPos + k*PQ;
                                                        for(int j = apos; j < 
apos+alen; j++) {
                                                                int pqIndex = 
aix[j];
-                                                               dest[destPos + 
k*PQ + pqIndex ] = avals[j];
+                                                               dest[destPosK + pqIndex] = avals[j];
                                                        }
                                                }
                                        }

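Both hunks are the same loop-invariant hoist: destPos + k*PQ does not depend on j, so it is computed once per sparse row instead of once per non-zero. As a standalone sketch (parameters mirror the variables in the hunks):

    // Scatter one sparse row into a dense destination, with the base offset hoisted.
    static void scatterRow(double[] dest, int destPos, int k, int PQ,
                           int apos, int alen, int[] aix, double[] avals) {
      int destPosK = destPos + k*PQ;           // invariant: hoisted out of the j-loop
      for (int j = apos; j < apos + alen; j++)
        dest[destPosK + aix[j]] += avals[j];   // one add per non-zero, no re-multiply
    }
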
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
index ab96a8e..0550a98 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNHelper.java
@@ -265,9 +265,12 @@ public class LibMatrixDNNHelper {
                        double [] outputArr = mb.getDenseBlock();
                        if(filter != null) {
                                for(int k = 0; k < _params.K; k++) {
+                                       int outOffset = k*RS;
+                                       int filterOffset = k*CRS + c*RS;
                                        for(int rs = 0; rs < RS; rs++) {
-                                               outputArr[k*RS + rs] = 
filter[k*CRS + c*RS + rs];
-                                               nnz += outputArr[k*RS + rs] != 
0 ? 1 : 0;
+                                               int outIndex = outOffset + rs;
+                                               outputArr[outIndex] = 
filter[filterOffset + rs];
+                                               nnz += outputArr[outIndex] != 0 
? 1 : 0;
                                        }
                                }
                        }
@@ -473,12 +476,16 @@ public class LibMatrixDNNHelper {
                }
                else {
                        if(!input.isEmptyBlock()) {
+                               int outOffset = 
outputN*params.C*params.H*params.W;
+                               int HW = params.H*params.W;
                                int [] tensorIndexes = new int[3];
                                for(int i = 0; i < input.getNumRows(); i++) {
                                        if( !input.sparseBlock.isEmpty(i) ) {
                                                computeTensorIndexes(i, 
tensorIndexes, params.P, params.Q);
                                                int p = tensorIndexes[1];
                                                int q = tensorIndexes[2];
+                                               int tmpP = p*params.stride_h - 
params.pad_h;
+                                               int tmpQ = q*params.stride_w - 
params.pad_w;
                                                if(tensorIndexes[0] != 0) 
                                                        throw new 
DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" 
+ p + " " + q + " " + tensorIndexes[0] + params.P + " " + params.Q + ">");
                                                
@@ -491,10 +498,10 @@ public class LibMatrixDNNHelper {
                                                        int c = 
tensorIndexes[0];
                                                        int r = 
tensorIndexes[1];
                                                        int s = 
tensorIndexes[2];
-                                                       int h = 
p*params.stride_h + r - params.pad_h;
-                                                       int w = 
q*params.stride_w + s - params.pad_w;
+                                                       int h = tmpP + r;
+                                                       int w = tmpQ + s;
                                                        if(h >= 0 && h < 
params.H && w >= 0 && w < params.W) {
-                                                               int outIndex = 
outputN*params.C*params.H*params.W + c*params.H*params.W + h*params.W + w;
+                                                               int outIndex = 
outOffset + c*HW + h*params.W + w;
                                                                
outputArray[outIndex] += avals[j];
                                                        }
                                                }
@@ -508,6 +515,10 @@ public class LibMatrixDNNHelper {
        // Or converts input: NPQ X CRS matrix and writes to N X CHW 
        private static void doCol2IMDenseInput(int inputN, int outputN, double 
[] inputArray, double [] outputArray, ConvolutionParameters params) throws 
DMLRuntimeException {
                final int outputNOffset = outputN*params.C*params.H*params.W;
+               final int HW = params.H*params.W;
+               final int inputNPQ = inputN*params.P*params.Q;
+               final int CRS = params.C*params.R*params.S;
+               final int RS = params.R*params.S;
                for (int p = 0; p < params.P; p++) {
                        // h = p*params.stride_h + r - params.pad_h
                        //   = r + hOffset
@@ -522,10 +533,10 @@ public class LibMatrixDNNHelper {
                                final int wOffset = q*params.stride_w - 
params.pad_w;
                                final int sStart = Math.max(0, - wOffset);
                                final int sEnd = Math.min(params.S, params.W - 
wOffset);
-                               final int tempOffset = 
(inputN*params.P*params.Q + p*params.Q + q)*params.C*params.R*params.S;
+                               final int tempOffset = (inputNPQ + p*params.Q + 
q)*CRS;
                                for (int c = 0; c < params.C; c++) {
-                                       final int outOffset = outputNOffset + 
c*params.H*params.W;
-                                       final int inputOffset = tempOffset + 
c*params.R*params.S;
+                                       final int outOffset = outputNOffset + 
c*HW;
+                                       final int inputOffset = tempOffset + 
c*RS;
                                        for (int r = rStart; r < rEnd; r++) {
                                                for (int s = sStart; s < sEnd; 
s++) {
                                                        int inputIndex = 
inputOffset + r*params.S + s;

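All of these offsets derive from the row-major NCHW linearization used throughout LibMatrixDNN: element (n, c, h, w) of an N x CHW matrix lives at n*CHW + c*HW + h*W + w, and computeTensorIndexes inverts that mapping. A minimal sketch of both directions, with hypothetical dimensions:

    public class NchwIndexSketch {
      static int flatten(int n, int c, int h, int w, int C, int H, int W) {
        return ((n*C + c)*H + h)*W + w;        // == n*C*H*W + c*H*W + h*W + w
      }
      static int[] unflatten(int index, int C, int H, int W) {
        int w = index % W; index /= W;
        int h = index % H; index /= H;
        int c = index % C; index /= C;
        return new int[]{index, c, h, w};      // {n, c, h, w}
      }
      public static void main(String[] args) {
        int C = 3, H = 4, W = 5;
        int idx = flatten(1, 2, 3, 4, C, H, W);
        int[] t = unflatten(idx, C, H, W);
        System.out.println(idx + " -> n=" + t[0] + " c=" + t[1] + " h=" + t[2] + " w=" + t[3]);
      }
    }
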
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
index b400105..5b04e59 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingBackwardHelper.java
@@ -146,11 +146,12 @@ public class LibMatrixDNNPoolingBackwardHelper {
                public Long call() throws Exception {
                        for(int n = _rl; n < _ru; n++)  {
                                for (int c = 0; c < C; c++) {
+                                       final int doutOffset = n*CPQ + c*PQ;
+                                       final int inputOffset = n*CHW + c*HW;
                                        for (int p = 0; p < P; p++) {
                                                for (int q = 0; q < Q; q++) {
-                                                       double inVal = 
doutArray[n*CPQ + c*PQ +  p * Q + q];
+                                                       double inVal = 
doutArray[doutOffset +  p * Q + q];
                                                        if(inVal != 0) {
-                                                               final int 
inputOffset = n*CHW + c*HW;
                                                                int maxIndex = 
LibMatrixDNNHelper.getMaxIndexSparse(p, q, inputOffset, n, c, _params.input1, 
_params, performReluBackward);
                                                                if(maxIndex != 
-1)
                                                                        
outputArray[maxIndex] += inVal;

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
index c6aaee2..19c3f71 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
@@ -55,11 +55,14 @@ public class LibMatrixDNNPoolingHelper {
                                        final int inOffset1 = inOffset + c*HW;
                                        for (int p = 0; p < P; p++) {
                                                for (int q = 0; q < Q; q++, 
out_index++) {
+                                                       double tmp = 
outputArray[out_index];
                                                        for (int h = 
_params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
-                                                               for (int w = 
_params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
-                                                                       
outputArray[out_index] = Math.max(outputArray[out_index], inputArray[inOffset1 
+  h*W + w]);
+                                                               int inputIndex 
= inOffset1 +  h*W + _params.start_indexes_w[q];
+                                                               for (int w = 
_params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++, inputIndex++) {
+                                                                       tmp = 
Math.max(tmp, inputArray[inputIndex]);
                                                                }
                                                        }
+                                                       outputArray[out_index] 
= tmp;
                                                }
                                        }
                                }
@@ -75,67 +78,63 @@ public class LibMatrixDNNPoolingHelper {
        {
                public int _rl; public int _ru; 
                private final ConvolutionParameters _params;
-               int HW;
+               final int HW;
                double [] outputArray;
-               int C; int P; int Q; int W;
+               final int C; final int P; final int Q; final int W; final int 
H; final int CPQ; final int PQ;
                public SparseMaxPooling(int rl, int ru, ConvolutionParameters 
params) {
                        _rl = rl; _ru = ru;
                        _params = params;
                        outputArray = params.output.getDenseBlock();
-                       C = params.C; P = params.P; Q = params.Q; W = params.W;
+                       C = params.C; P = params.P; Q = params.Q; H = params.H; 
+                       W = params.W;
                        HW = _params.H*_params.W;
-               }
-               
-               boolean isNthRowEmpty = false;
-               int apos; int alen; int[] aix; double[] avals;
-               private void getNthSparseRow(int n) {
-                       if( !_params.input1.sparseBlock.isEmpty(n) ) {
-                               apos = _params.input1.sparseBlock.pos(n);
-                               alen = _params.input1.sparseBlock.size(n);
-                               aix = _params.input1.sparseBlock.indexes(n);
-                               avals = _params.input1.sparseBlock.values(n);
-                               isNthRowEmpty = false;
-                       }
-                       else
-                               isNthRowEmpty = true;
-               }
-               int fromIndex = -1; // as per C
-               int toIndex = -1; // as per C
-               private int setSearchIndex(int from, int searchVal) {
-                       for(int j = from; j < apos+alen; j++) {
-                               if(aix[j] > searchVal)
-                                       return Math.max(from, j-1);
-                       }
-                       return apos+alen;
-               }
-               private double getValue(int col) {
-                       if( !isNthRowEmpty ) {
-                               int index = Arrays.binarySearch(aix, fromIndex, 
toIndex, col);
-                               return index > 0 ? avals[index] : 0;
-                       }
-                       return 0;
+                       CPQ = C*P*Q;
+                       PQ = P*Q;
                }
                
                @Override
                public Long call() throws Exception {
-                       final int CPQ = C*P*Q;
                        for(int n = _rl; n < _ru; n++)  {
-                               getNthSparseRow(n);
-                               int out_index = n*CPQ;
-                               for (int c = 0; c < C; c++) {
-                                       // This allows for binary search in 
getValue to be more efficient
-                                       fromIndex = setSearchIndex(apos, c*HW);
-                                       toIndex = Math.min(apos+alen, 
setSearchIndex(fromIndex, (c+1)*HW));
-                                       for (int p = 0; p < P; p++) {
-                                               for (int q = 0; q < Q; q++, 
out_index++) {
-                                                       for (int h = 
_params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
-                                                               for (int w = 
_params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++) {
-                                                                       
outputArray[out_index] = Math.max(outputArray[out_index], getValue(c*HW +  h*W 
+ w));
+                               if( !_params.input1.sparseBlock.isEmpty(n) ) {
+                                       final int apos = 
_params.input1.sparseBlock.pos(n);
+                                       final int alen = 
_params.input1.sparseBlock.size(n);
+                                       final int [] aix = 
_params.input1.sparseBlock.indexes(n);
+                                       final double [] avals = 
_params.input1.sparseBlock.values(n);
+                                       int chw = 0; int index = apos;
+                                       for (int c = 0; c < C; c++) {
+                                               final int outOffset = n*CPQ + 
c*PQ;
+                                               for(int h = 0; h < H; h++) {
+                                                       for(int w = 0; w < W; 
w++, chw++) {
+                                                               // Take into 
account zero values as well
+                                                               double nchwVal 
= 0;
+                                                               if(aix[index] 
== chw) {
+                                                                       nchwVal 
= avals[index++];
+                                                                       // 
Ensure that we satisfy the condition index < apos+alen
+                                                                       
if(index >= apos+alen) index--;
+                                                               }
+                                                               // Perform maxpooling without binary search :)
+                                                               // Tradeoff compared to dense maxpooling:
+                                                               // in dense maxpooling, the iteration space is C*P*Q*H*W where the H and W loops are
+                                                               // restricted by _params.start_indexes_h[p] etc. and are eligible for JIT optimizations;
+                                                               // in sparse maxpooling, the iteration space is C*H*W*P*Q without HW restrictions.
+                                                               for (int p = 0; 
p < P; p++) {
+                                                                       if(h >= 
_params.start_indexes_h[p] && h < _params.end_indexes_h[p]) {
+                                                                               
final int outOffsetWithp = outOffset + p*Q;
+                                                                               
for (int q = 0; q < Q; q++) {
+                                                                               
        if(w >= _params.start_indexes_w[q] && w < _params.end_indexes_w[q]) {
+                                                                               
                outputArray[outOffsetWithp + q] = 
Math.max(outputArray[outOffsetWithp + q], nchwVal);
+                                                                               
        }
+                                                                               
}
+                                                                       }
                                                                }
                                                        }
                                                }
                                        }
                                }
+                               else {
+                                       // Empty input image
+                                       Arrays.fill(outputArray, n*CPQ, 
(n+1)*CPQ, 0);
+                               }
                        }
                        return 0L;
                }

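The dense-side change in the first hunk is the JIT enabler from the commit message: accumulating the running maximum in a local variable removes the loop-carried load/store on outputArray[out_index], so the inner loops become easier to unroll. A sketch of the pattern in isolation (hStart/hEnd/wStart/wEnd stand in for _params.start_indexes_h[p], _params.end_indexes_h[p], and their w equivalents):

    // Compute the max over one pooling window with a register-friendly accumulator.
    static double poolWindowMax(double[] in, int base, int W,
                                int hStart, int hEnd, int wStart, int wEnd, double init) {
      double tmp = init;                       // running max stays in a local
      for (int h = hStart; h < hEnd; h++) {
        int idx = base + h*W + wStart;         // hoisted row offset
        for (int w = wStart; w < wEnd; w++, idx++)
          tmp = Math.max(tmp, in[idx]);
      }
      return tmp;                              // caller writes back exactly once
    }
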
http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
index c003756..6bc7caf 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNRotate180Helper.java
@@ -66,6 +66,9 @@ public class LibMatrixDNNRotate180Helper {
        
        /**
         * Performing rotate180 when input is sparse (general case)
+        * 
+        * Why do we allocate the output of rotate180 in dense format?
+        * Because the number of rows of the output (i.e. NPQ) is much larger than the number of columns (i.e. K).
         */
        static class SparseRotate180Worker implements Rotate180Worker {
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala 
b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 000fe32..a62fae2 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -131,6 +131,29 @@ object Caffe2DML  {
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
     envFlagNesterovUDF != null && envFlagNesterovUDF.toBoolean
   }
+  
+  def main(args: Array[String]): Unit = {
+       // Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH
+       if(args.length < 6) throwUsageError
+       val outputDMLFile = args(1)
+       val solverFile = args(2)
+       val inputChannels = args(3)
+       val inputHeight = args(4)
+       val inputWidth = args(5)
+       val caffeObj = new Caffe2DML(new SparkContext(), solverFile, 
inputChannels, inputHeight, inputWidth)
+       if(args(0).equals("train_script")) {
+               
Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, 
outputDMLFile)
+       }
+       else if(args(0).equals("predict_script")) {
+               Utils.writeToFile(new 
Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, 
outputDMLFile)
+       }
+       else {
+               throwUsageError
+       }
+  }
+  def throwUsageError():Unit = {
+       throw new RuntimeException("Incorrect usage: [train_script | predict_script] OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH"); 
+  }
 }
 
 class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
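Since Scala objects compile with static forwarders, the new entry point can also be driven from plain Java. A hypothetical invocation (file names and input shape are placeholders, and a Spark environment is assumed because Caffe2DML.main creates a SparkContext):

    import org.apache.sysml.api.dl.Caffe2DML;

    public class GenerateScriptSketch {
      public static void main(String[] args) {
        Caffe2DML.main(new String[] {
            "train_script",        // or "predict_script"
            "lenet_train.dml",     // $OUTPUT_DML_FILE (hypothetical)
            "lenet_solver.proto",  // $SOLVER_FILE (hypothetical)
            "1", "28", "28"        // $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH
        });
      }
    }
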
@@ -147,7 +170,10 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
   def this(sc: SparkContext, solver1:Caffe.SolverParameter, 
numChannels:String, height:String, width:String) {
     this(sc, solver1, Utils.parseSolver(solver1), new 
CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, 
width), 
         new LearningRatePolicy(solver1), numChannels, height, width)
-  } 
+  }
+  def this(sc: SparkContext, solverPath:String, numChannels:String, 
height:String, width:String) {
+    this(sc, Utils.readCaffeSolver(solverPath), numChannels, height, width)
+  }
   val uid:String = "caffe_classifier_" + (new Random).nextLong
   override def copy(extra: org.apache.spark.ml.param.ParamMap): 
Estimator[Caffe2DMLModel] = {
     val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, 
numChannels, height, width)
@@ -232,105 +258,30 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
          val shouldValidate = solverParam.getTestInterval > 0 && 
solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
          trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
          
-         // Set iteration-related variables such as max_epochs, 
num_iters_per_epoch, lr, etc.
-         setIterationVariables
+         // Set iteration-related variables such as num_iters_per_epoch, lr, 
etc.
+         ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, 
Caffe2DML.batchSize)
+         assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+         assign(tabDMLScript, "max_iter", ifdef("$max_iter", 
solverParam.getMaxIter.toString))
+         assign(tabDMLScript, "e", "0")
+         
          val lossLayers = getLossLayers(net)
          // 
----------------------------------------------------------------------------
          // Main logic
-         forBlock("e", "1", "max_epochs") {
-           getTrainAlgo.toLowerCase match {
-             case "minibatch" => 
-               forBlock("i", "1", "num_iters_per_epoch") {
-                 getTrainingBatch(tabDMLScript)
-                 tabDMLScript.append("iter = iter + 1\n")
-                 // -------------------------------------------------------
-                 // Perform forward, backward and update on minibatch
-                 forward; backward; update
-                 // -------------------------------------------------------
-                 displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-               }
-             case "batch" => {
-          tabDMLScript.append("iter = iter + 1\n")
-          // -------------------------------------------------------
-          // Perform forward, backward and update on entire dataset
-          forward; backward; update
-          // -------------------------------------------------------
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
-             }
-             case "allreduce_parallel_batches" => {
-               // This setting uses the batch size provided by the user
-          if(!inputs.containsKey("$parallel_batches")) {
-            throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
-          }
-          // The user specifies the number of parallel_batches
-          // This ensures that the user of generated script remembers to 
provide the commandline parameter $parallel_batches
-          assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
-          assign(tabDMLScript, "group_batch_size", "parallel_batches*" + 
Caffe2DML.batchSize)
-          assign(tabDMLScript, "groups", "as.integer(ceil(" + 
Caffe2DML.numImages + "/group_batch_size))")
-          // Grab groups of mini-batches
-          forBlock("g", "1", "groups") {
-            tabDMLScript.append("iter = iter + 1\n")
-            // Get next group of mini-batches
-            assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " 
+ Caffe2DML.numImages + " + 1")
-            assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + 
", group_beg + group_batch_size - 1)")
-            assign(tabDMLScript, "X_group_batch", Caffe2DML.X + 
"[group_beg:group_end,]")
-            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + 
"[group_beg:group_end,]")
-            initializeGradients("parallel_batches")
-            assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-            parForBlock("j", "1", "parallel_batches") {
-              // Get a mini-batch in this group
-              assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + 
") %% nrow(X_group_batch) + 1")
-              assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + 
Caffe2DML.batchSize + " - 1)")
-              assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
-              assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
-              forward; backward
-              flattenGradients
-            }
-            aggregateAggGradients    
-                 update
-                 // -------------------------------------------------------
-                 assign(tabDMLScript, "Xb", "X_group_batch")
-            assign(tabDMLScript, "yb", "y_group_batch")
-            displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-          }
-             }
-             case "allreduce" => {
-               // This is distributed synchronous gradient descent
-               forBlock("i", "1", "num_iters_per_epoch") {
-                 tabDMLScript.append("iter = iter + 1\n")
-                 // -------------------------------------------------------
-            // Perform forward, backward and update on minibatch in parallel
-                 assign(tabDMLScript, "beg", "((i-1) * " + Caffe2DML.batchSize 
+ ") %% " + Caffe2DML.numImages + " + 1")
-                 assign(tabDMLScript, "end", " min(beg +  " + 
Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
-                 assign(tabDMLScript, "X_group_batch", Caffe2DML.X + 
"[beg:end,]")
-            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
-            assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-                 tabDMLScript.append("local_batch_size = 
nrow(y_group_batch)\n")
-                 val localBatchSize = "local_batch_size"
-                 initializeGradients(localBatchSize)
-                 parForBlock("j", "1", localBatchSize) {
-                   assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-                   assign(tabDMLScript, "yb", "y_group_batch[j,]")
-                   forward; backward
-              flattenGradients
-                 }
-                 aggregateAggGradients    
-                 update
-                 // -------------------------------------------------------
-                 assign(tabDMLScript, "Xb", "X_group_batch")
-            assign(tabDMLScript, "yb", "y_group_batch")
-            displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-               }
-             }
-             case _ => throw new DMLRuntimeException("Unsupported train algo:" 
+ getTrainAlgo)
-           }
-           // After every epoch, update the learning rate
-           tabDMLScript.append("# Learning rate\n")
-           lrPolicy.updateLearningRate(tabDMLScript)
+         forBlock("iter", "1", "max_iter") {
+               performTrainingIter(lossLayers, shouldValidate)
+               if(getTrainAlgo.toLowerCase.equals("batch")) {
+                       assign(tabDMLScript, "e", "iter")
+                       tabDMLScript.append("# Learning rate\n")
+                       lrPolicy.updateLearningRate(tabDMLScript)
+               }
+               else {
+                       ifBlock("iter %% num_iters_per_epoch == 0") {
+                               // After every epoch, update the learning rate
+                               assign(tabDMLScript, "e", "e + 1")
+                               tabDMLScript.append("# Learning rate\n")
+                               lrPolicy.updateLearningRate(tabDMLScript)
+                       }
+               }
          }
          // 
----------------------------------------------------------------------------
          
@@ -350,6 +301,90 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
        }
        // 
================================================================================================
   
+  private def performTrainingIter(lossLayers:List[IsLossLayer], 
shouldValidate:Boolean):Unit = {
+       getTrainAlgo.toLowerCase match {
+      case "minibatch" => 
+          getTrainingBatch(tabDMLScript)
+          // -------------------------------------------------------
+          // Perform forward, backward and update on minibatch
+          forward; backward; update
+          // -------------------------------------------------------
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+      case "batch" => {
+             // -------------------------------------------------------
+             // Perform forward, backward and update on entire dataset
+             forward; backward; update
+             // -------------------------------------------------------
+             displayLoss(lossLayers(0), shouldValidate)
+             performSnapshot
+      }
+      case "allreduce_parallel_batches" => {
+         // This setting uses the batch size provided by the user
+             if(!inputs.containsKey("$parallel_batches")) {
+               throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
+             }
+             // The user specifies the number of parallel_batches
+             // This ensures that the user of generated script remembers to 
provide the commandline parameter $parallel_batches
+             assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
+             assign(tabDMLScript, "group_batch_size", "parallel_batches*" + 
Caffe2DML.batchSize)
+             assign(tabDMLScript, "groups", "as.integer(ceil(" + 
Caffe2DML.numImages + "/group_batch_size))")
+             // Grab groups of mini-batches
+             forBlock("g", "1", "groups") {
+               // Get next group of mini-batches
+               assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) 
%% " + Caffe2DML.numImages + " + 1")
+               assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages 
+ ", group_beg + group_batch_size - 1)")
+               assign(tabDMLScript, "X_group_batch", Caffe2DML.X + 
"[group_beg:group_end,]")
+               assign(tabDMLScript, "y_group_batch", Caffe2DML.y + 
"[group_beg:group_end,]")
+               initializeGradients("parallel_batches")
+               assign(tabDMLScript, "X_group_batch_size", 
nrow("X_group_batch"))
+               parForBlock("j", "1", "parallel_batches") {
+                 // Get a mini-batch in this group
+                 assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize 
+ ") %% nrow(X_group_batch) + 1")
+                 assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " 
+ Caffe2DML.batchSize + " - 1)")
+                 assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+                 assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
+                 forward; backward
+                 flattenGradients
+               }
+               aggregateAggGradients    
+               update
+               // -------------------------------------------------------
+               assign(tabDMLScript, "Xb", "X_group_batch")
+               assign(tabDMLScript, "yb", "y_group_batch")
+               displayLoss(lossLayers(0), shouldValidate)
+               performSnapshot
+             }
+      }
+      case "allreduce" => {
+         // This is distributed synchronous gradient descent
+         // -------------------------------------------------------
+         // Perform forward, backward and update on minibatch in parallel
+         assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") 
%% " + Caffe2DML.numImages + " + 1")
+         assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " 
- 1, " + Caffe2DML.numImages + ")")
+         assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+         assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+         assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+          val localBatchSize = "local_batch_size"
+          initializeGradients(localBatchSize)
+          parForBlock("j", "1", localBatchSize) {
+            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+            forward; backward
+          flattenGradients
+          }
+          aggregateAggGradients    
+          update
+          // -------------------------------------------------------
+          assign(tabDMLScript, "Xb", "X_group_batch")
+          assign(tabDMLScript, "yb", "y_group_batch")
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+      }
+      case _ => throw new DMLRuntimeException("Unsupported train algo:" + 
getTrainAlgo)
+    }
+  }
   // 
-------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and 
Caffe2DML.numImages
@@ -499,10 +534,12 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
   
   private def performSnapshot():Unit = {
     if(solverParam.getSnapshot > 0) {
-      ifBlock("iter %% snapshot == 0") {
+      ifBlock("iter %% " + solverParam.getSnapshot + " == 0") {
         tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix 
+ "\" + \"/iter_\" + iter + \"/\"\n")
-        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l 
=> tabDMLScript.append(write(l.weight, "snapshot_dir + \"" + l.param.getName + 
"_weight.mtx\"", "binary")))
-                 net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != 
null).map(l => tabDMLScript.append(write(l.bias, "snapshot_dir + \"" + 
l.param.getName + "_bias.mtx\"", "binary")))
+        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l 
=> tabDMLScript.append(
+               "write(" + l.weight + ", snapshot_dir + \"" + l.param.getName + 
"_weight.mtx\", format=\"binary\")\n"))
+               net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != 
null).map(l => tabDMLScript.append(
+                       "write(" + l.bias + ", snapshot_dir + \"" + 
l.param.getName + "_bias.mtx\", format=\"binary\")\n"))
       }
        }
   }
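
For reference, a sketch of the DML statement the revised performSnapshot now
builds directly as a string (instead of going through the write(...) helper).
The layer and variable names below are hypothetical placeholders; only the
string shape matches the "+" lines above:

    val weightVar = "conv1_weight" // hypothetical DML variable for a layer's weights
    val layerName = "conv1"        // hypothetical layer name from the network
    val emitted = "write(" + weightVar + ", snapshot_dir + \"" + layerName +
      "_weight.mtx\", format=\"binary\")\n"
    print(emitted) // write(conv1_weight, snapshot_dir + "conv1_weight.mtx", format="binary")
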
@@ -547,19 +584,6 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
           matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
     })
   }
-  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
-  def setIterationVariables():Unit = {
-    getTrainAlgo.toLowerCase match {
-           case "batch" => 
-             assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
-           case _ => {
-             ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
-             ceilDivide(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString, "num_iters_per_epoch")
-           }
-         }
-         assign(tabDMLScript, "iter", "0")
-         assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-  }
   // -------------------------------------------------------------------------------------------
 }
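
The setIterationVariables helper removed above converted the solver's max_iter
into whole epochs by ceiling division. A small sketch of that arithmetic with
purely illustrative sizes (the concrete values are assumptions, not from the
commit):

    val numImages = 1000 // assumed dataset size
    val batchSize = 64   // assumed minibatch size
    val maxIter   = 2000 // assumed solver max_iter
    // ceilDivide(a, b) = ceil(a / b), done here with integer arithmetic
    val numItersPerEpoch = (numImages + batchSize - 1) / batchSize             // ceil(1000/64) = 16
    val maxEpochs        = (maxIter + numItersPerEpoch - 1) / numItersPerEpoch // ceil(2000/16) = 125
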
 
@@ -617,7 +641,7 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
          estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
-        forBlock("i", "1", "num_iters") {
+        forBlock("iter", "1", "num_iters") {
           getTestBatch(tabDMLScript)
           net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
           assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
@@ -656,10 +680,10 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       case "allreduce" => {
        // This setting does not use the batch size for scoring and allows the parfor optimizer to select the best plan
         // by minimizing the memory requirement (i.e. batch size = 1)
-        parForBlock("i", "1", Caffe2DML.numImages) {
-          assign(tabDMLScript, "Xb", "X_full[i,]")
+        parForBlock("iter", "1", Caffe2DML.numImages) {
+          assign(tabDMLScript, "Xb", "X_full[iter,]")
           net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
-          assign(tabDMLScript, "Prob[i,]", lossLayers(0).out)
+          assign(tabDMLScript, "Prob[iter,]", lossLayers(0).out)
         }
       }
       case _ => throw new DMLRuntimeException("Unsupported test algo:" + 
estimator.getTestAlgo)

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 456b032..6b06c26 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -182,10 +182,10 @@ trait NextBatchGenerator extends TabbedDMLGenerator {
          dmlScript.append("\n")
        }
        def getTestBatch(tabDMLScript:StringBuilder):Unit = {
-    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "i")
+    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "iter")
   } 
   def getTrainingBatch(tabDMLScript:StringBuilder):Unit = {
-    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "i")
+    assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "iter")
   }
       def getTrainingBatch(tabDMLScript:StringBuilder, X:String, y:String, numImages:String):Unit = {
          assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
@@ -298,6 +298,13 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
          numTabs -= 1
          tabDMLScript.append("}\n")
        }
+       def whileBlock(cond:String)(op: => Unit) {
+         tabDMLScript.append("while(" + cond + ") {\n")
+         numTabs += 1
+         op
+         numTabs -= 1
+         tabDMLScript.append("}\n")
+       }
       def forBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
         tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
          numTabs += 1
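
The new whileBlock combinator follows the same emit-header / indent / run-body /
dedent / close pattern as the existing ifBlock and forBlock. A self-contained
usage sketch under simplified assumptions (a bare StringBuilder stands in for
the trait's tabDMLScript, and emit for its tab-aware append):

    object WhileBlockSketch {
      val out = new StringBuilder
      var numTabs = 0
      def emit(line: String): Unit = out.append("\t" * numTabs).append(line).append("\n")
      def whileBlock(cond: String)(op: => Unit): Unit = {
        emit("while(" + cond + ") {")
        numTabs += 1; op; numTabs -= 1 // body thunk runs one tab level deeper
        emit("}")
      }
      def main(args: Array[String]): Unit = {
        whileBlock("iter < max_iter") { emit("iter = iter + 1") }
        print(out) // emits: while(iter < max_iter) { ... } with a tab-indented body
      }
    }
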

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 0c00d3c..2684261 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -277,6 +277,11 @@ object Utils {
        
        // --------------------------------------------------------------
        // File IO utility functions
+       def writeToFile(content:String, filePath:String): Unit = {
+               val pw = new java.io.PrintWriter(new File(filePath))
+               pw.write(content)
+               pw.close
+       }
        def getInputStreamReader(filePath:String ):InputStreamReader = {
                //read solver script from file
                if(filePath == null)
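
A usage sketch for the new Utils.writeToFile helper. The try/finally variant
below is an illustrative hardening, not part of the commit: it closes the
writer even when write() throws. The file path is hypothetical:

    import java.io.{File, PrintWriter}

    def writeToFileSafe(content: String, filePath: String): Unit = {
      val pw = new PrintWriter(new File(filePath))
      try pw.write(content) finally pw.close() // close even on failure
    }

    writeToFileSafe("print('hello')\n", "/tmp/generated.dml") // hypothetical path
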

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/test/scripts/functions/tensor/PoolTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.R b/src/test/scripts/functions/tensor/PoolTest.R
index a34e0b0..aef0384 100644
--- a/src/test/scripts/functions/tensor/PoolTest.R
+++ b/src/test/scripts/functions/tensor/PoolTest.R
@@ -32,7 +32,7 @@ pad=as.integer(args[7])
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, numChannels*imgSize*imgSize, byrow=TRUE)
 if(as.logical(args[9])) {
-       zero_mask = (x - mean(x)) > 0 
+       zero_mask = (x - 1.5*mean(x)) > 0 
        x = x * zero_mask
 } else {
        x = x - mean(x)
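
Raising the mask threshold from mean(x) to 1.5*mean(x) in these pooling tests
makes the input markedly sparser: for x = 1..n the mean is (n+1)/2, so
thresholding at the mean zeroes roughly half the entries while 1.5*mean zeroes
roughly three quarters, giving the sparsity-sensitive operator paths a
genuinely sparse input. A quick Scala check of that fraction (size illustrative):

    val n = 1000
    val xs = (1 to n).map(_.toDouble)
    val mean = xs.sum / n                                 // (n+1)/2 = 500.5
    val zeroFrac = xs.count(_ <= 1.5 * mean).toDouble / n // entries the mask zeroes
    println(f"zero fraction at 1.5*mean: $zeroFrac%.2f")  // ~0.75
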

http://git-wip-us.apache.org/repos/asf/systemml/blob/e624d149/src/test/scripts/functions/tensor/PoolTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.dml b/src/test/scripts/functions/tensor/PoolTest.dml
index cc8132f..5246a2d 100644
--- a/src/test/scripts/functions/tensor/PoolTest.dml
+++ b/src/test/scripts/functions/tensor/PoolTest.dml
@@ -30,7 +30,7 @@ poolMode=$8
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, cols=numChannels*imgSize*imgSize)
 if($10) {
-       zero_mask = (x - mean(x)) > 0 
+       zero_mask = (x - 1.5*mean(x)) > 0 
        x = x * zero_mask
 }
 else {
