Repository: systemml
Updated Branches:
  refs/heads/master 98ee9b7d8 -> 33559144c


[SYSTEMML-1961] Performance dense max pooling (init, cases, nnz)

This patch makes a number of performance improvements to the existing
maxpooling builtin functions:

1) New special cases for (a) stride=1, pad=0, and (b) stride=1, pad=0,
P=1, Q=1, W=1, without materialization of start/end index arrays and
with significantly simplified loop structures.

2) Thread-local output initialization and nnz maintenance to reduce
the serial fraction and improve temporal locality for multi-threaded
operations.

For a special-case scenario with a 1K x 200*2048 input (~3.3GB) and
stride=1, pad=0, P=1, Q=1, W=1, this patch improved performance from
253ms to 131ms, which now runs at peak memory bandwidth.
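
To illustrate the combined effect of 1) and 2), here is a simplified,
self-contained sketch of the quick path for case (b) with per-thread
output handling and nnz counting. Class and method names are
hypothetical and this is not the patch code itself; the authoritative
changes are in the diff below.

  // Hypothetical, simplified sketch of the dense max-pooling quick path
  // for stride=1, pad=0, P=1, Q=1, W=1 (illustration only).
  public class MaxPoolQuickPathSketch {
    // Processes rows [rl, ru) of the dense input; each output cell is
    // the max over a contiguous run of length min(R, H) in one channel.
    public static long poolRange(double[] in, double[] out,
        int rl, int ru, int C, int H, int R) {
      final int CHW = C * H;            // W = 1, so C*H*W = C*H
      final int lenh = Math.min(R, H);  // effective pooling window height
      long nnz = 0;                     // thread-local nnz maintenance
      for (int i = rl, oix = rl * C; i < ru; i++, oix += C) { // P=Q=1 => C outputs per row
        for (int c = 0, off = i * CHW; c < C; c++, off += H) {
          double max = -Double.MAX_VALUE;   // per-cell init, no global Arrays.fill
          for (int j = off; j < off + lenh; j++)
            max = Math.max(max, in[j]);
          out[oix + c] = max;
          nnz += (max != 0) ? 1 : 0;
        }
      }
      return nnz; // aggregated over all workers and set on the output block
    }
  }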


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/17c5d5aa
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/17c5d5aa
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/17c5d5aa

Branch: refs/heads/master
Commit: 17c5d5aae9669049164f59fccf3820b1c822c83f
Parents: 98ee9b7
Author: Matthias Boehm <[email protected]>
Authored: Sat Oct 14 23:38:51 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Oct 14 23:38:51 2017 -0700

----------------------------------------------------------------------
 .../cp/ConvolutionCPInstruction.java            |   3 -
 .../matrix/data/ConvolutionParameters.java      | 113 +++++++++++--------
 .../sysml/runtime/matrix/data/LibMatrixDNN.java |  33 +++---
 .../matrix/data/LibMatrixDNNPoolingHelper.java  |  88 ++++++++++-----
 .../integration/functions/tensor/PoolTest.java  | 110 ++++++++----------
 src/test/scripts/functions/tensor/PoolTest.R    |   7 +-
 src/test/scripts/functions/tensor/PoolTest.dml  |   3 +
 7 files changed, 193 insertions(+), 164 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
index f6ff998..2c7b972 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ConvolutionCPInstruction.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.runtime.instructions.cp;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -354,8 +353,6 @@ public class ConvolutionCPInstruction extends 
UnaryCPInstruction {
                        }
                        else {
                                outputBlock = new MatrixBlock(N, C*P*Q, 
false).allocateBlock();
-                               if(instOpcode.equalsIgnoreCase("maxpooling"))
-                                       
Arrays.fill(outputBlock.getDenseBlock(), -Double.MAX_VALUE);
                                LibMatrixDNN.maxpooling(matBlock, outputBlock, 
params);
                        }
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
index c66b5c6..d64a261 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/ConvolutionParameters.java
@@ -29,11 +29,13 @@ import org.apache.sysml.runtime.util.ConvolutionUtils;
  * This class is container that stores parameters required for executing 
following operations:
  * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling, 
maxpooling_backward 
  */
-public class ConvolutionParameters implements Serializable {
+public class ConvolutionParameters implements Serializable 
+{
        private static final long serialVersionUID = -212362627205772829L;
-       public int N; public int C; public int H; public int W;
-       public int K; public int R; public int S; public int stride_h; public 
int stride_w; public int pad_h; public int pad_w;
-       public int P; public int Q; public int numThreads;
+       
+       public int N, C, H, W, K, R, S, P, Q;
+       public int stride_h, stride_w, pad_h, pad_w;
+       public int numThreads;
        
        // Optional variables used by ConvolutionCPInstruction
        public boolean enableNative = false;
@@ -43,52 +45,9 @@ public class ConvolutionParameters implements Serializable {
        public MatrixBlock bias;
        public int [] start_indexes_h, end_indexes_h, start_indexes_w, 
end_indexes_w; 
        
-       private static int convertToInt(long val) throws DMLRuntimeException {
-               if( val > Integer.MAX_VALUE )
-                       throw new DMLRuntimeException("The value for 
ConvolutionParameters is too large:" + val);
-               return (int) val;
-       }
-       
-       public boolean compare(ConvolutionParameters that) {
-               if(this.N == that.N && this.C == that.C && this.H == that.H && 
this.W == that.W
-                               && this.K == that.K && this.R == that.R && 
this.S == that.S && this.stride_h == that.stride_h
-                                && this.stride_w == that.stride_w  && 
this.pad_h == that.pad_h
-                                 && this.pad_w == that.pad_w   && 
this.numThreads == that.numThreads) {
-                       return true;
-               }
-               return false;
-       }
-       
-       @Override
-       public String toString() {
-               return "(NCHW=[" + N + " " + C + " " + H + " " + W + "], 
KCRS=[" + K + " " + R + " " + S + "], stride=[" + stride_h + "," + stride_w  + 
-                               "], pad=[" + pad_h + "," + pad_w + "])";  
-       }
-       
-       public void setIfUnknown(Hop N, Hop C, Hop H, Hop W,
-                       Hop K, Hop R, Hop S, Hop stride_h, Hop stride_w, Hop 
pad_h, Hop pad_w, int numThreads) throws DMLRuntimeException {
-               if(this.N < 0) this.N = 
convertToInt(Hop.computeSizeInformation(N));
-               if(this.C < 0) this.C = 
convertToInt(Hop.computeSizeInformation(C));
-               if(this.H < 0) this.H = 
convertToInt(Hop.computeSizeInformation(H));
-               if(this.W < 0) this.W = 
convertToInt(Hop.computeSizeInformation(W));
-               if(this.K < 0) this.K = 
convertToInt(Hop.computeSizeInformation(K));
-               if(this.R < 0) this.R = 
convertToInt(Hop.computeSizeInformation(R));
-               if(this.S < 0) this.S = 
convertToInt(Hop.computeSizeInformation(S));
-               if(this.stride_h < 0) this.stride_h = 
convertToInt(Hop.computeSizeInformation(stride_h));
-               if(this.stride_w < 0) this.stride_w = 
convertToInt(Hop.computeSizeInformation(stride_w));
-               if(this.pad_h < 0) this.pad_h = 
convertToInt(Hop.computeSizeInformation(pad_h));
-               if(this.pad_w < 0) this.pad_w = 
convertToInt(Hop.computeSizeInformation(pad_w));
-               if(this.P < 0 && this.H >= 0 && this.R >= 0 && this.stride_h >= 
0 && this.pad_h >= 0) {
-                       this.P = (int) ConvolutionUtils.getP(this.H, this.R, 
this.stride_h, this.pad_h);
-               }
-               if(this.Q < 0 && this.W >= 0 && this.S >= 0 && this.stride_w >= 
0 && this.pad_w >= 0) {
-                       this.Q = (int) ConvolutionUtils.getQ(this.W, this.S, 
this.stride_w, this.pad_w);
-               }
-               this.numThreads = numThreads;
-       }
-       
        public ConvolutionParameters(long N, long C, long H, long W,
-                       long K, long R, long S, long stride_h, long stride_w, 
long pad_h, long pad_w, int numThreads) throws DMLRuntimeException {
+                       long K, long R, long S, long stride_h, long stride_w, 
+                       long pad_h, long pad_w, int numThreads) throws 
DMLRuntimeException {
                this.N = convertToInt(N);
                this.C = convertToInt(C);
                this.H = convertToInt(H);
@@ -139,7 +98,63 @@ public class ConvolutionParameters implements Serializable {
                this.numThreads = numThreads;
        }
        
+       private static int convertToInt(long val) throws DMLRuntimeException {
+               if( val > Integer.MAX_VALUE )
+                       throw new DMLRuntimeException("The value for 
ConvolutionParameters is too large:" + val);
+               return (int) val;
+       }
+       
+       public boolean compare(ConvolutionParameters that) {
+               if(this.N == that.N && this.C == that.C && this.H == that.H && 
this.W == that.W
+                               && this.K == that.K && this.R == that.R && 
this.S == that.S && this.stride_h == that.stride_h
+                                && this.stride_w == that.stride_w  && 
this.pad_h == that.pad_h
+                                 && this.pad_w == that.pad_w   && 
this.numThreads == that.numThreads) {
+                       return true;
+               }
+               return false;
+       }
+       
+       @Override
+       public String toString() {
+               return "(NCHW=[" + N + " " + C + " " + H + " " + W + "], 
KCRS=[" + K + " " + R + " " + S + "], stride=[" + stride_h + "," + stride_w  + 
+                               "], pad=[" + pad_h + "," + pad_w + "])";  
+       }
+       
+       public void setIfUnknown(Hop N, Hop C, Hop H, Hop W,
+                       Hop K, Hop R, Hop S, Hop stride_h, Hop stride_w, Hop 
pad_h, Hop pad_w, int numThreads) throws DMLRuntimeException {
+               if(this.N < 0) this.N = 
convertToInt(Hop.computeSizeInformation(N));
+               if(this.C < 0) this.C = 
convertToInt(Hop.computeSizeInformation(C));
+               if(this.H < 0) this.H = 
convertToInt(Hop.computeSizeInformation(H));
+               if(this.W < 0) this.W = 
convertToInt(Hop.computeSizeInformation(W));
+               if(this.K < 0) this.K = 
convertToInt(Hop.computeSizeInformation(K));
+               if(this.R < 0) this.R = 
convertToInt(Hop.computeSizeInformation(R));
+               if(this.S < 0) this.S = 
convertToInt(Hop.computeSizeInformation(S));
+               if(this.stride_h < 0) this.stride_h = 
convertToInt(Hop.computeSizeInformation(stride_h));
+               if(this.stride_w < 0) this.stride_w = 
convertToInt(Hop.computeSizeInformation(stride_w));
+               if(this.pad_h < 0) this.pad_h = 
convertToInt(Hop.computeSizeInformation(pad_h));
+               if(this.pad_w < 0) this.pad_w = 
convertToInt(Hop.computeSizeInformation(pad_w));
+               if(this.P < 0 && this.H >= 0 && this.R >= 0 && this.stride_h >= 
0 && this.pad_h >= 0) {
+                       this.P = (int) ConvolutionUtils.getP(this.H, this.R, 
this.stride_h, this.pad_h);
+               }
+               if(this.Q < 0 && this.W >= 0 && this.S >= 0 && this.stride_w >= 
0 && this.pad_w >= 0) {
+                       this.Q = (int) ConvolutionUtils.getQ(this.W, this.S, 
this.stride_w, this.pad_w);
+               }
+               this.numThreads = numThreads;
+       }
+       
        public boolean isOutputThreadSafe() {
                return output.isThreadSafe();
        }
+       
+       public boolean isStride1Pad0() {
+               return (stride_h==1 && stride_w==1
+                       && pad_h==0 && pad_w==0);
+       }
+       
+       public boolean isAllOnes(Integer...params) {
+               boolean ret = true;
+               for(int param : params)
+                       ret &= (param == 1);
+               return ret;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
index 3f67b1a..b967780 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java
@@ -356,19 +356,15 @@ public class LibMatrixDNN {
                params.end_indexes_h = new int[params.P];
                params.start_indexes_w = new int[params.Q];
                params.end_indexes_w = new int[params.Q];
-               for (int p = 0; p < params.P; p++) {
-                       int start_index_h = p * params.stride_h - params.pad_h;
-                       int end_index_h = start_index_h + params.R;
+               for( int p=0, ix=-params.pad_h; p < params.P; p++, 
ix+=params.stride_h ) {
                        // Note: We do not treat pad as zero
-                       params.start_indexes_h[p] = Math.max(start_index_h, 0);
-                       params.end_indexes_h[p] = Math.min(end_index_h, 
params.H);
+                       params.start_indexes_h[p] = Math.max(ix, 0);
+                       params.end_indexes_h[p] = Math.min(ix+params.R, 
params.H);
                }
-               for (int q = 0; q < params.Q; q++) {
-                       int start_index_w =  q * params.stride_w - params.pad_w;
-                       int end_index_w = start_index_w + params.S;
+               for( int q=0, ix=-params.pad_w; q < params.Q; q++, 
ix+=params.stride_w) {
                        // Note: We do not treat pad as zero
-                       params.start_indexes_w[q] = Math.max(start_index_w, 0);
-                       params.end_indexes_w[q] = Math.min(end_index_w, 
params.W);
+                       params.start_indexes_w[q] = Math.max(ix, 0);
+                       params.end_indexes_w[q] = Math.min(ix+params.S, 
params.W);
                }
        }
        
@@ -528,21 +524,24 @@ public class LibMatrixDNN {
                }
        }
        
-       public static void maxpooling(MatrixBlock input, MatrixBlock 
outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
+       public static void maxpooling(MatrixBlock input, MatrixBlock output, 
ConvolutionParameters params) throws DMLRuntimeException {
                params.input1 = input;
-               params.output = outputBlock;
+               params.output = output;
                
                if(input.getNumColumns() != params.C*params.H*params.W || 
input.getNumRows() != params.N) {
-                       throw new DMLRuntimeException("Incorrect input 
dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + 
" " + params.N + " " + params.C*params.H*params.W);
+                       throw new DMLRuntimeException("Incorrect input 
dimensions in maxpooling:" + input.getNumRows() + " " 
+                               + input.getNumColumns() + " " + params.N + " " 
+ params.C*params.H*params.W);
                }
                
-               fillIndexesArray(params);
+               //materialize indexes unless basic case with stride=1 and pad=0
+               if( !params.isStride1Pad0() || input.sparse )
+                       fillIndexesArray(params);
                
-               execute(LibMatrixDNNHelper.getMaxPoolingWorkers(params), 
params);
+               long nnz = 
execute(LibMatrixDNNHelper.getMaxPoolingWorkers(params), params);
                
                // post-processing: maintain nnz
-               outputBlock.recomputeNonZeros(); 
-               outputBlock.examSparsity();
+               output.setNonZeros(nnz);
+               output.examSparsity();
        }
        
        /**

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
index 19c3f71..52cdbcd 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPoolingHelper.java
@@ -31,43 +31,58 @@ public class LibMatrixDNNPoolingHelper {
         */
        public static class DenseMaxPooling implements Callable<Long> 
        {
-               public int _rl; public int _ru; 
+               private final int _rl, _ru; 
                private final ConvolutionParameters _params;
-               double [] inputArray; double [] outputArray;
-               int C; int P; int Q; int W;
+               
                public DenseMaxPooling(int rl, int ru, ConvolutionParameters 
params) {
                        _rl = rl; _ru = ru;
                        _params = params;
-                       inputArray = params.input1.getDenseBlock();
-                       outputArray = params.output.getDenseBlock();
-                       C = params.C; P = params.P; Q = params.Q; W = params.W;
                }
                
                @Override
                public Long call() throws Exception {
+                       final int C = _params.C, P = _params.P, Q = _params.Q;
+                       final int R = _params.R, S = _params.S, H = _params.H, 
W = _params.W;
                        final int HW = _params.H*_params.W;
                        final int CHW = _params.C*_params.H*_params.W;
                        final int CPQ = C*P*Q;
-                       for(int n = _rl; n < _ru; n++)  {
-                               final int inOffset = n*CHW;
-                               int out_index = n*CPQ;
-                               for (int c = 0; c < C; c++) {
-                                       final int inOffset1 = inOffset + c*HW;
-                                       for (int p = 0; p < P; p++) {
-                                               for (int q = 0; q < Q; q++, 
out_index++) {
-                                                       double tmp = 
outputArray[out_index];
-                                                       for (int h = 
_params.start_indexes_h[p]; h < _params.end_indexes_h[p]; h++) {
-                                                               int inputIndex 
= inOffset1 +  h*W + _params.start_indexes_w[q];
-                                                               for (int w = 
_params.start_indexes_w[q]; w < _params.end_indexes_w[q]; w++, inputIndex++) {
-                                                                       tmp = 
Math.max(tmp, inputArray[inputIndex]);
-                                                               }
-                                                       }
-                                                       outputArray[out_index] 
= tmp;
-                                               }
-                                       }
-                               }
+                       double[] in = _params.input1.getDenseBlock();
+                       double[] out = _params.output.getDenseBlock();
+                       
+                       //thread-local initialization of output block 
+                       if( !(_params.isStride1Pad0() && _params.isAllOnes(P, 
Q, W)) )
+                               Arrays.fill(out, _rl*CPQ, _ru*CPQ, 
-Double.MAX_VALUE);
+                       
+                       if( _params.isStride1Pad0() && _params.isAllOnes(P, Q, 
W) ) { 
+                               //quick-path w/o materialized index arrays and 
+                               //simplified inner loops for P = 1, Q = 1, W = 1
+                               int lenh = Math.min(R,H);
+                               for(int i = _rl, oix=_rl*C; i < _ru; i++, 
oix+=C)
+                                       for (int c = 0, off=i*CHW; c < C; c++, 
off+=H)
+                                               out[oix+c] = 
max(-Double.MAX_VALUE, in, off, lenh);
+                       }
+                       else if( _params.isStride1Pad0() ) {
+                               //quick-path w/o materialized index arrays 
+                               for(int i = _rl; i < _ru; i++)
+                                       for (int c = 0, off=i*CHW, oix=i*CPQ; c 
< C; c++, off+=HW)
+                                               for (int p = 0; p < P; p++, 
oix+=Q)
+                                                       for (int h = p; h < 
Math.min(p+R,H); h++)
+                                                               for (int q = 0, 
off2=off+h*W; q < Q; q++)
+                                                                       
out[oix+q] = max(out[oix+q], in, off2+q, Math.min(S,W-q));
+                       }
+                       else { //general case
+                               int[] hl = _params.start_indexes_h, hu = 
_params.end_indexes_h;
+                               int[] wl = _params.start_indexes_w, wu = 
_params.end_indexes_w;
+                               for(int i = _rl; i < _ru; i++)
+                                       for (int c = 0, off=i*CHW, oix=i*CPQ; c 
< C; c++, off+=HW)
+                                               for (int p = 0; p < P; p++, 
oix+=Q)
+                                                       for (int h = hl[p]; h < 
hu[p]; h++)
+                                                               for (int q = 0, 
off2=off+h*W; q < Q; q++)
+                                                                       
out[oix+q] = max(out[oix+q], in, off2+wl[q], wu[q]-wl[q]);
                        }
-                       return 0L;
+                       
+                       //thread-local recomputation of non-zeros
+                       return _params.output.recomputeNonZeros(_rl, _ru-1);
                }
        }
        
@@ -76,24 +91,26 @@ public class LibMatrixDNNPoolingHelper {
         */
        public static class SparseMaxPooling implements Callable<Long> 
        {
-               public int _rl; public int _ru; 
+               private final int _rl, _ru; 
                private final ConvolutionParameters _params;
-               final int HW;
-               double [] outputArray;
-               final int C; final int P; final int Q; final int W; final int 
H; final int CPQ; final int PQ;
+               private double [] outputArray;
+               private final int C, P, Q, W, H, CPQ, PQ;
+               
                public SparseMaxPooling(int rl, int ru, ConvolutionParameters 
params) {
                        _rl = rl; _ru = ru;
                        _params = params;
                        outputArray = params.output.getDenseBlock();
                        C = params.C; P = params.P; Q = params.Q; H = params.H; 
                        W = params.W;
-                       HW = _params.H*_params.W;
                        CPQ = C*P*Q;
                        PQ = P*Q;
                }
                
                @Override
                public Long call() throws Exception {
+                       //thread-local initialization of output block 
+                       Arrays.fill(outputArray, _rl *CPQ, _ru*CPQ, 
-Double.MAX_VALUE);
+                       
                        for(int n = _rl; n < _ru; n++)  {
                                if( !_params.input1.sparseBlock.isEmpty(n) ) {
                                        final int apos = 
_params.input1.sparseBlock.pos(n);
@@ -136,7 +153,16 @@ public class LibMatrixDNNPoolingHelper {
                                        Arrays.fill(outputArray, n*CPQ, 
(n+1)*CPQ, 0);
                                }
                        }
-                       return 0L;
+                       
+                       //thread-local recomputation of non-zeros
+                       return _params.output.recomputeNonZeros(_rl, _ru-1);
                }
        }
+       
+       private static double max(final double aval, double[] b, final int bi, 
final int len) {
+               double ret = aval;
+               for( int i = bi; i < bi+len; i++ )
+                       ret = Math.max(ret, b[i]);
+               return ret;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
index e1c84c5..1acb14c 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/tensor/PoolTest.java
@@ -51,146 +51,130 @@ public class PoolTest extends AutomatedTestBase
        }
        
        @Test
-       public void testMaxPool2DDense2() 
-       {
+       public void testMaxPool2DDense2() {
                int numImg = 2; int imgSize = 6; int numChannels = 1;  int 
stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", false);
        }
        
-       
        @Test
-       public void testMaxPool2DDense3() 
-       {
+       public void testMaxPool2DDense3() {
                int numImg = 3; int imgSize = 7; int numChannels = 2;  int 
stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", false);
        }
        
        @Test
-       public void testMaxPool2DDense4() 
-       {
+       public void testMaxPool2DDense4() {
                int numImg = 2; int imgSize = 4; int numChannels = 2;  int 
stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", false);
        }
        
        @Test
-       public void testMaxPool2DSparse1() 
-       {
+       public void testMaxPool2DDense5() {
+               int numImg = 2; int imgSize = 8; int numChannels = 4;  int 
stride = 1; int pad = 0; int poolSize1 = imgSize*imgSize; int poolSize2 = 1;
+               runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max2", false);
+       }
+       
+       @Test
+       public void testMaxPool2DSparse1() {
                int numImg = 1; int imgSize = 6; int numChannels = 1;  int 
stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true);
        }
        
        @Test
-       public void testMaxPool2DSparse2() 
-       {
+       public void testMaxPool2DSparse2() {
                int numImg = 2; int imgSize = 6; int numChannels = 1;  int 
stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true);
        }
        
-       
        @Test
-       public void testMaxPool2DSparse3() 
-       {
+       public void testMaxPool2DSparse3() {
                int numImg = 3; int imgSize = 7; int numChannels = 2;  int 
stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true);
        }
        
        @Test
-       public void testMaxPool2DSparse4() 
-       {
+       public void testMaxPool2DSparse4() {
                int numImg = 2; int imgSize = 4; int numChannels = 2;  int 
stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max", true);
        }
        
+       @Test
+       public void testMaxPool2DSparse5() {
+               int numImg = 2; int imgSize = 32; int numChannels = 4;  int 
stride = 1; int pad = 0; int poolSize1 = imgSize*imgSize; int poolSize2 = 1;
+               runPoolTest(ExecType.CP, imgSize, numImg, numChannels, stride, 
pad, poolSize1, poolSize2, "max2", true);
+       }
+       
        // ----------------------------------------
        
        @Test
-       public void testMaxPool2DDense1SP() 
-       {
+       public void testMaxPool2DDense1SP() {
                int numImg = 1; int imgSize = 50; int numChannels = 1;  int 
stride = 2; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
                runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, 
stride, pad, poolSize1, poolSize2, "max", false);
        }
        
        @Test
-       public void testMaxPool2DDense2SP() 
-       {
+       public void testMaxPool2DDense2SP() {
                int numImg = 2; int imgSize = 6; int numChannels = 1;  int 
stride = 1; int pad = 0; int poolSize1 = 2; int poolSize2 = 2;
                runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, 
stride, pad, poolSize1, poolSize2, "max", false);
        }
        
-       
        @Test
-       public void testMaxPool2DDense3SP() 
-       {
+       public void testMaxPool2DDense3SP() {
                int numImg = 3; int imgSize = 7; int numChannels = 2;  int 
stride = 2; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, 
stride, pad, poolSize1, poolSize2, "max", false);
        }
        
        @Test
-       public void testMaxPool2DDense4SP() 
-       {
+       public void testMaxPool2DDense4SP() {
                int numImg = 2; int imgSize = 4; int numChannels = 2;  int 
stride = 1; int pad = 0; int poolSize1 = 3; int poolSize2 = 3;
                runPoolTest(ExecType.SPARK, imgSize, numImg, numChannels, 
stride, pad, poolSize1, poolSize2, "max", false);
        }
        
-       /**
-        * 
-        * @param et
-        * @param sparse
-        */
        public void runPoolTest( ExecType et, int imgSize, int numImg, int 
numChannels, int stride, 
                        int pad, int poolSize1, int poolSize2, String poolMode, 
boolean sparse) 
        {
-               RUNTIME_PLATFORM oldRTP = rtplatform;
-                       
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
                
                try
                {
-                       String sparseVal = (""+sparse).toUpperCase();
-           TestConfiguration config = getTestConfiguration(TEST_NAME);
-           if(et == ExecType.SPARK) {
-               rtplatform = RUNTIME_PLATFORM.SPARK;
-           }
-           else {
-               rtplatform = (et==ExecType.MR)? RUNTIME_PLATFORM.HADOOP : 
RUNTIME_PLATFORM.SINGLE_NODE;
-           }
-                       if( rtplatform == RUNTIME_PLATFORM.SPARK )
-                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       String sparseVal = String.valueOf(sparse).toUpperCase();
                        
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
                        loadTestConfiguration(config);
-               
-                       /* This is for running the junit test the new way, 
i.e., construct the arguments directly */
+       
                        String RI_HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
-                       
-                       programArgs = new String[]{"-explain", "-args",  "" + 
imgSize, "" + numImg, 
-                                       "" + numChannels, "" + poolSize1, "" + 
poolSize2, 
-                                       "" + stride, "" + pad, poolMode, 
-                                       output("B"), sparseVal};
-                               
-                       boolean exceptionExpected = false;
-                       int expectedNumberOfJobs = -1;
-                       runTest(true, exceptionExpected, null, 
expectedNumberOfJobs);
+                       programArgs = new String[]{"-explain", "-args", 
String.valueOf(imgSize), 
+                               String.valueOf(numImg), 
String.valueOf(numChannels),
+                               String.valueOf(poolSize1), 
String.valueOf(poolSize2),
+                               String.valueOf(stride), String.valueOf(pad), 
poolMode,
+                               output("B"), sparseVal};
                        
                        fullRScriptName = RI_HOME + TEST_NAME + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
imgSize + " " + numImg + 
-                                       " " + numChannels + " " + poolSize1 + 
-                                       " " + poolSize2 + " " + stride + " " + 
pad + " " + expectedDir() + " " + sparseVal; 
+                               " " + numChannels + " " + poolSize1 + " " + 
poolSize2 + " " + stride + 
+                               " " + pad + " " + expectedDir() + " " + 
sparseVal + " " + poolMode; 
                        
-                       // Run comparison R script
+                       // run scripts
+                       runTest(true, false, null, -1);
                        runRScript(true);
-                       HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
                        
+                       //compare results
+                       HashMap<CellIndex, Double> bHM = readRMatrixFromFS("B");
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("B");
                        TestUtils.compareMatrices(dmlfile, bHM, epsilon, 
"B-DML", "NumPy");
-                       
                }
-               finally
-               {
-                       rtplatform = oldRTP;
+               finally {
+                       rtplatform = platformOld;
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
        }
-       
-       
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/test/scripts/functions/tensor/PoolTest.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.R 
b/src/test/scripts/functions/tensor/PoolTest.R
index aef0384..fdcba56 100644
--- a/src/test/scripts/functions/tensor/PoolTest.R
+++ b/src/test/scripts/functions/tensor/PoolTest.R
@@ -28,6 +28,7 @@ poolSize1=as.integer(args[4])
 poolSize2=as.integer(args[5])
 stride=as.integer(args[6])
 pad=as.integer(args[7])
+mode=args[10]
 
 # Assumption: NCHW image format
 x=matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, 
numChannels*imgSize*imgSize, byrow=TRUE)
@@ -98,6 +99,10 @@ max_pool <- function(X, N, C, Hin, Win, Hf, Wf,
   out
 }
 
-output = max_pool(x, numImg, numChannels, imgSize, imgSize, poolSize1, 
poolSize2, stride, stride)
+if( mode=="max" ) {
+  output = max_pool(x, numImg, numChannels, imgSize, imgSize, poolSize1, 
poolSize2, stride, stride)
+} else {
+  output = max_pool(x, numImg, numChannels, imgSize*imgSize, 1, poolSize1, 
poolSize2, stride, stride)
+}
 
 writeMM(as(output,"CsparseMatrix"), paste(args[8], "B", sep=""))
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/17c5d5aa/src/test/scripts/functions/tensor/PoolTest.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/tensor/PoolTest.dml 
b/src/test/scripts/functions/tensor/PoolTest.dml
index 5246a2d..5f20fae 100644
--- a/src/test/scripts/functions/tensor/PoolTest.dml
+++ b/src/test/scripts/functions/tensor/PoolTest.dml
@@ -39,6 +39,9 @@ else {
 if(poolMode == "max") {
        output = max_pool(x, stride=[stride, stride], padding=[pad, pad], 
input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, 
poolSize2])
 }
+else if(poolMode == "max2") {
+       output = max_pool(x, stride=[stride, stride], padding=[pad, pad], 
input_shape=[numImg, numChannels, imgSize*imgSize, 1], pool_size=[poolSize1, 
poolSize2])
+}
 #else {
        #output = avg_pool(x, stride=[stride, stride], padding=[pad, pad], 
input_shape=[numImg, numChannels, imgSize, imgSize], pool_size=[poolSize1, 
poolSize2])
 #}
