This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit efc78f18c3ecf2aa9cfd916aefc07909a0db0e9c
Author: baunsgaard <[email protected]>
AuthorDate: Fri Mar 26 18:33:31 2021 +0100

    [SYSTEMDS-2914] maxpooling_backward sparse Update
    
    This commit updates the maxpooling sparse output to use append again.
    Since the outputs are almost sorted, in practice small arrays are
    allocated and sorted, and then appended to the sparse row outputs.
    The sorting is limited to small arrays of 1-14 elements, but this
    size can grow depending on how many kernels can be applied
    horizontally on the input.
    
    Closes #1213
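    
    To illustrate the pattern, here is a minimal, self-contained Java
    sketch (hypothetical code, not the SystemDS API: the printout stands
    in for SparseRow.append, and the data is made up). Entries of one
    pooling row are collected into small parallel arrays, sorted by
    index, and then appended in increasing order:
    
        // minimal sketch of the collect-sort-append pattern
        public class CollectSortAppendSketch {
            public static void main(String[] args) {
                int[] ix   = { 5, 2, 9, 7 };         // almost-sorted indexes of one pooling row
                double[] v = { 1.0, 2.0, 3.0, 4.0 }; // values aligned with ix
                insertSort(ix, v, ix.length);        // tiny array, cheap to sort
                for (int x = 0; x < ix.length; x++)  // append in strictly increasing order
                    System.out.println(ix[x] + " -> " + v[x]);
            }
            // insertion sort on ix, keeping v aligned (mirrors the fallback in the patch)
            static void insertSort(int[] i, double[] v, int size) {
                for (int p = 1; p < size; p++) {
                    int k = i[p]; double t = v[p]; int j = p - 1;
                    for (; j >= 0 && i[j] > k; j--) { i[j + 1] = i[j]; v[j + 1] = v[j]; }
                    i[j + 1] = k; v[j + 1] = t;
                }
            }
        }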
---
 .github/workflows/functionsTests.yml               |   3 +-
 .../runtime/instructions/cp/DnnCPInstruction.java  |  12 -
 .../sysds/runtime/matrix/data/LibMatrixDNN.java    |   3 +
 .../runtime/matrix/data/LibMatrixDNNPooling.java   | 315 ++++++++++++++++-----
 .../applications/nn/NNMaxPool2dComponentTest.java  |   2 +-
 .../applications/nn/component/max_pool2d.dml       |   4 +-
 6 files changed, 259 insertions(+), 80 deletions(-)

diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml
index 5e7466c..70d2af1 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -45,7 +45,8 @@ jobs:
           "**.functions.codegenalg.partone.**",
           "**.functions.builtin.**",
           
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
-          
"**.functions.dnn.**,**.functions.misc.**,**.functions.mlcontext.**,**.functions.paramserv.**",
+          "**.functions.dnn.**,**.functions.paramserv.**",
+          "**.functions.misc.**,**.functions.mlcontext.**",
           "**.functions.nary.**,**.functions.quaternary.**",
           
"**.functions.parfor.**,**.functions.pipelines.**,**.functions.privacy.**,**.functions.unary.scalar.**,**.functions.updateinplace.**,**.functions.vect.**",
           
"**.functions.reorg.**,**.functions.rewrite.**,**.functions.ternary.**,**.functions.transform.**",
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
index f29b85e..a486672 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
@@ -548,12 +548,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        }
                        else {
                                outputBlock = new MatrixBlock(K, C*R*S, false).allocateBlock();
-                               if(params.enableNative ){
-                                       if(matBlock.isInSparseFormat())
-                                               matBlock.sparseToDense();
-                                       if(dout.isInSparseFormat())
-                                               dout.sparseToDense();
-                               }
                                if(params.enableNative && !matBlock.isInSparseFormat() && !dout.isInSparseFormat())
                                        LibMatrixNative.conv2dBackwardFilter(matBlock, dout, outputBlock, params);
                                else
@@ -568,12 +562,6 @@ public class DnnCPInstruction extends UnaryCPInstruction {
                        }
                        else {
                                outputBlock = new MatrixBlock(N, C * H * W, false).allocateBlock();
-                               if(params.enableNative ){
-                                       if(matBlock.isInSparseFormat())
-                                               matBlock.sparseToDense();
-                                       if(dout.isInSparseFormat())
-                                               dout.sparseToDense();
-                               }
                                if(params.enableNative && !isFilterSparse(matBlock) && !dout.isInSparseFormat())
                                        LibMatrixNative.conv2dBackwardData(matBlock, dout, outputBlock, params);
                                else
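
The removed blocks above had eagerly densified both inputs whenever native mode
was enabled; after this change the native kernels are invoked only when both
inputs are already dense, and sparse inputs fall through to the Java operators.
A minimal sketch of the resulting dispatch (hypothetical flags and messages;
the real code calls LibMatrixNative and LibMatrixDNN):

    // illustrative only: flags stand in for MatrixBlock.isInSparseFormat()
    public class BackwardDispatchSketch {
        public static void main(String[] args) {
            boolean enableNative = true, lhsSparse = true, doutSparse = false;
            // no forced sparse-to-dense conversion anymore:
            // the native path is taken only if both inputs are dense
            if (enableNative && !lhsSparse && !doutSparse)
                System.out.println("native conv2d backward"); // LibMatrixNative path
            else
                System.out.println("java conv2d backward");   // LibMatrixDNN path
        }
    }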
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
index d1bd2d3..598fef5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
@@ -178,6 +178,9 @@ public class LibMatrixDNN {
                        fillIndexesArray(params); 
                }
                else {
+                       if(!params.input2.isInSparseFormat())
+                               params.input1.sparseToDense();
+
                        if( !(params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) )
                                fillIndexesArray(params); //not needed for sparse-dense
                }
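
The added densification means that whenever dout (input2) is dense, input1 is
converted to dense as well, so the sparse-dense combination can no longer reach
the pooling backward operators; this matches the "NOT IN USE" notes added
below. A small illustrative sketch of the resulting cases (made-up flags, not
SystemDS code):

    public class PathSelectionSketch {
        public static void main(String[] args) {
            boolean[] tf = { true, false };
            for (boolean in1Sparse : tf)
                for (boolean in2Sparse : tf) {
                    // input1 stays sparse only if dout is sparse too
                    boolean lhsSparse = in1Sparse && in2Sparse;
                    System.out.println("input1 sparse=" + in1Sparse
                        + ", dout sparse=" + in2Sparse
                        + " -> lhs handled as " + (lhsSparse ? "sparse" : "dense"));
                }
        }
    }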
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
index 2196d1e..84170ac 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNPooling.java
@@ -37,11 +37,6 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixDNNHelper.CellIndex3;
 public class LibMatrixDNNPooling {
        
        protected static final Log LOG = LogFactory.getLog(LibMatrixDNNPooling.class.getName());
-
-       // *********************************** low-level runtime operator selection ***********************************************
-       // *********************************** based on runtime properties (sparsity, native, etc) ********************************
-       // These methods help reduce branch miss predictions and instruction-cache misses.
-       // Also, they simplify the design of LibMatrixDNN and help in code-maintenance.
        
        /**
         * Factory method that returns list of callable tasks for performing pooling operation
@@ -78,6 +73,7 @@ public class LibMatrixDNNPooling {
                int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
                int taskSize = (int)(Math.ceil((double)params.N / k / 2));
                if(poolType == PoolingType.MAX) {
+                       
                        boolean sparse1 = params.input1.isInSparseFormat();
                        boolean sparse2 = params.input2.isInSparseFormat();
                        for(int i = 0; i*taskSize < params.N; i++) {
@@ -357,22 +353,31 @@ public class LibMatrixDNNPooling {
                public Long call() throws Exception {
                        if(output.isInSparseFormat()){
                                SparseBlock out = output.getSparseBlock();
+                               final int[] i = new int[Q];
+                               final double[] v = new double[Q];  
                                for(int n = _rl; n < _ru; n++){
                                        // each row correspond to a single batch element.
                                        // here we allocate the sparse row.
                                        out.allocate(n, P*Q*C);
-                                       SparseRow elm = out.get(n);
+                                       final SparseRow elm = out.get(n);
+                                       final int nCHW = n*CHW;
+
+                                       // tmp arrays for sorting.
                                        for(int c = 0; c < C; c++){
                                                // each channel processed.
-                                               final int inputOffset = n*CHW + c*HW;
+                                               final int inputOffset = nCHW + c*HW;
                                                final int outputOffset = n*CPQ + c*PQ;
                                                for(int p = 0; p < P; p++){
+                                                       int pointer = 0;
                                                        for(int q = 0; q < Q; q++){
-                                                               int maxIndex =  getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+                                                               int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
                                                                if(maxIndex != -1){
-                                                                       add(elm, maxIndex - n*CHW, doutArray[outputOffset +  p * Q + q] );
+                                                                       i[pointer] = maxIndex - nCHW;
+                                                                       v[pointer] = doutArray[outputOffset +  p * Q + q];
+                                                                       pointer++;
                                                                }
                                                        }
+                                                       add(elm,i,v,pointer);
                                                }
                                        }
                                }
@@ -409,7 +414,7 @@ public class LibMatrixDNNPooling {
                MatrixBlock output; 
                boolean performReluBackward;
                double [] inputArray;  MatrixBlock dout;
-               int CHW; int P; int Q; int HW; int C;
+               final int CHW; final int P; final int Q; final int HW; final int C;
                public PoolingBackwardDenseSparse(int rl, int ru, DnnParameters params, boolean performReluBackward) {
                        _rl = rl; _ru = ru;
                        _params = params;
@@ -429,31 +434,50 @@ public class LibMatrixDNNPooling {
                @Override
                public Long call() throws Exception {
 
-                       CellIndex3 ix = new CellIndex3();
                        SparseBlock sblock = dout.sparseBlock;
                        if(output.isInSparseFormat()){
                                SparseBlock out = output.getSparseBlock();
+                               final int[] i = new int[Q];
+                               final double[] v = new double[Q];  
                                for(int n = _rl; n < _ru; n++){
                                        // each row correspond to a single batch element.
                                        // here we allocate the sparse row.
                                        if( sblock.isEmpty(n) ) continue;
+                                       
                                        out.allocate(n, P*Q*C);
-                                       SparseRow elm = out.get(n);
-                                       int apos = sblock.pos(n);
-                                       int alen = sblock.size(n);
-                                       int[] aix = sblock.indexes(n);
-                                       double[] avals = sblock.values(n);
+                                       final SparseRow elm = out.get(n);
+                                       
+                                       final int apos = sblock.pos(n);
+                                       final int alen = sblock.size(n);
+                                       final int[] aix = sblock.indexes(n);
+                                       final double[] avals = sblock.values(n);
+
+                                       int oldP = 0;
+                                       int pointer = 0;
+                                       final int nCHW = n*CHW;
+
                                        for(int j = apos; j < apos+alen; j++) {
-                                               ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], P, Q, ix);
-                                               final int inputOffset = n*CHW + ix.ix1*HW;
-                                               int maxIndex = getMaxIndex(ix.ix2, ix.ix3,
-                                                       inputOffset, inputArray, _params, performReluBackward);
-                                               if(maxIndex != -1)
-                                                       add(elm, maxIndex - n*CHW, avals[j]);
+                                               final int tmp = aix[j] / Q;
+                                               final int inputOffset = nCHW + (tmp / P) * HW;
+                                               final int p = tmp % P;
+                                               final int q = aix[j] % Q;
+                                               if(p != oldP){
+                                                       add(elm, i, v, pointer);
+                                                       oldP = p;
+                                                       pointer = 0;
+                                               }
+                                               int maxIndex = getMaxIndex(p, q, inputOffset, inputArray, _params, performReluBackward);
+                                               if(maxIndex != -1){
+                                                       i[pointer] = maxIndex - nCHW;
+                                                       v[pointer] = avals[j];
+                                                       pointer++;
+                                               }
                                        }
+                                       add(elm, i, v, pointer);
                                }
                        }
                        else {
+                               CellIndex3 ix = new CellIndex3();
                                double[] out = output.getDenseBlockValues();
                                for(int n = _rl; n < _ru; n++)  {
                                        if( sblock.isEmpty(n) ) continue;
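
The replaced index computation above decomposes a flat column index of the
sparse dout row, aix[j] = c*P*Q + p*Q + q, into the channel c and the output
coordinates (p, q) using two divisions and two modulo operations, avoiding the
CellIndex3 helper object. A small worked example (dimensions made up for
illustration):

    public class TensorIndexSketch {
        public static void main(String[] args) {
            int P = 3, Q = 4;                 // pooled output height and width
            int flat = 2 * P * Q + 1 * Q + 3; // encodes c=2, p=1, q=3
            int tmp = flat / Q;               // tmp = c*P + p
            int c = tmp / P, p = tmp % P, q = flat % Q;
            System.out.println("c=" + c + " p=" + p + " q=" + q); // c=2 p=1 q=3
        }
    }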
@@ -475,7 +499,7 @@ public class LibMatrixDNNPooling {
                        return P*Q*C*(long)(_ru - _rl);
                }
        }
-       
+
        /**
         * Performs the avgpooling backward operation for sparse error (dout)
         */
@@ -533,6 +557,10 @@ public class LibMatrixDNNPooling {
        
        /**
         * Performs the maxpooling backward operation for sparse input and dense error (dout)
+        * 
+        * Currently this is NOT IN USE since the sparse left-hand side is forced to dense,
+        * because this method is inefficient compared to the dense version.
+        * 
         */
        private static class PoolingBackwardSparseDense implements Callable<Long> 
        {
@@ -572,27 +600,26 @@ public class LibMatrixDNNPooling {
                        //allocate auxiliary data structures
                        double[] maxVal = new double[PQ];
                        int[] maxIx = new int[PQ];
-                       
                        for(int n = _rl; n < _ru; n++)  {
                                for (int c = 0; c < C; c++) {
                                        //step 1: perform maxpooling w/ index maintenance in a
                                        //single, sequential pass over the sparse input matrix
-                                       maxpoolingForward(maxVal, maxIx, n, c,
+                                       boolean empty = maxpoolingForward(maxVal, maxIx, n, c,
                                                padh, padw, strideh, stridew, C, P, Q, R, S, HW, W);
-                                       
-                                       //step 2: perform maxpooling backward
-                                       if(output.isInSparseFormat())
-                                               maxpoolingBackwardSparse(maxIx, c*HW, n, c, C, Q, PQ, CPQ);
-                                       else
-                                               maxpoolingBackwardDense(maxIx, n*CHW + c*HW, n, c, C, Q, PQ, CPQ);
-                                       
+                                       if(!empty){
+                                               //step 2: perform maxpooling backward
+                                               if(output.isInSparseFormat())
+                                                       maxpoolingBackwardSparse(maxIx, c*HW, n, c, C, Q, P, CPQ);
+                                               else
+                                                       maxpoolingBackwardDense(maxIx, n*CHW + c*HW, n, c, C, Q, PQ, CPQ);
+                                       }
                                }
                        }
                        //thread-local nnz maintenance
                        return P*Q*C*(long)(_ru - _rl);
                }
                
-               protected void maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int Q, int R, int S, int HW, int W) {
+               protected boolean maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P, int Q, int R, int S, int HW, int W) {
                        SparseBlock sblock = _params.input1.getSparseBlock();
                        if( !sblock.isEmpty(n) ) {
                                Arrays.fill(maxVal, -Double.MAX_VALUE);
@@ -619,17 +646,10 @@ public class LibMatrixDNNPooling {
                                }
                                //handle skipped zero values at end of row
                                update0(lastix+1, (c+1)*HW, maxVal, maxIx, padh, padw, strideh, stridew, P, Q, R, S, HW, W);
+                               return false;
                        }
                        else {
-                               //handle empty row
-                               Arrays.fill(maxVal, 0);
-                               for(int p = 0, ix=0; p < P; p++) {
-                                       int h = Math.max(-padh+p*strideh, 0);
-                                       for(int q = 0; q < Q; q++, ix++) {
-                                               int w = Math.max(-padw+q*stridew, 0);
-                                               maxIx[ix] = h * W + w;
-                                       }
-                               }
+                               return true;
                        }
                }
                
@@ -641,14 +661,19 @@ public class LibMatrixDNNPooling {
                                out[ outOffset + maxIx[pq] ] += dout[ doutOffset + pq ];
                }
 
-               protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int PQ, int CPQ) {
+               protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P, int CPQ) {
                        double[] dout = doutput.getDenseBlockValues();
                        SparseBlock out = output.getSparseBlock();
-                       out.allocate(n, PQ);
+                       out.allocate(n, P * Q);
                        SparseRow row = out.get(n);
-                       final int doutOffset = n*CPQ + c*PQ;
-                       for( int pq = 0; pq < PQ; pq++ )
-                               row.add(maxIx[pq] + offset ,dout[ doutOffset + pq ]);
+                       final int doutOffset = n*CPQ + c*P * Q;
+                       int pq = 0;
+                       for( int p = 0; p < P; p++ ){
+                               for(int q = 0; q < Q; q++){
+                                       row.add(maxIx[pq] + offset ,dout[ doutOffset + pq ]);
+                                       pq++;
+                               }
+                       }
                }
                
                private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P, int Q, int R, int S, int HW, int W) {
@@ -680,6 +705,10 @@ public class LibMatrixDNNPooling {
        
        /**
         * Performs the maxpooling backward operation for sparse input and sparse error (dout)
+        * 
+        * Currently this is NOT IN USE since the sparse left-hand side is forced to dense,
+        * because this method is inefficient compared to the dense version.
+        * 
         */
        private static class PoolingBackwardSparseSparse extends PoolingBackwardSparseDense
        {
@@ -713,10 +742,11 @@ public class LibMatrixDNNPooling {
                }
 
                @Override
-               protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int PQ, int CPQ) {
+               protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P, int CPQ) {
                        SparseBlock sblock = doutput.getSparseBlock();
                        if( sblock.isEmpty(n) )
                                return;
+                       final int PQ = P*Q;
                        SparseBlock out = output.getSparseBlock();
                        out.allocate(n, PQ);
                        SparseRow row = out.get(n);
@@ -769,44 +799,199 @@ public class LibMatrixDNNPooling {
                int end_index_w = params.end_indexes_w[q];
                
                int maxIndex = -1; 
-               double maxVal = -Double.MAX_VALUE;
+               double maxVal = performReluBackward ? 0 : Double.NEGATIVE_INFINITY;
                
                // Note: We do not treat pad as zero and hence we don't do:  
                // maxVal = 0 
                // if start_index_h < 0 || start_index_w < 0 || end_index_h >= params.H || end_index_w >= params.W
                
                // Find maxIndex
-               double currDoutVal = -1;
                for (int h = start_index_h; h < end_index_h; h++) {
                        for (int w = start_index_w; w < end_index_w; w++) {
                                final int idx = inputOffset +  h*params.W + w;
-                               currDoutVal = inputArray[idx];
-                               currDoutVal = performReluBackward && currDoutVal < 0 ? 0 : currDoutVal;
+                               final double currDoutVal = inputArray[idx];
                                if(maxVal < currDoutVal) {
                                        maxIndex = idx;
                                        maxVal = currDoutVal;
                                }
                        }
                }
-               return maxIndex;
+               return maxVal == 0 && performReluBackward ? -1 : maxIndex;
+       }
+
+       /**
+        * Add all elements in the arrays to the sparse row. It is guaranteed that
+        * all indexes in i are larger than all indexes already contained in the row.
+        * 
+        * @param row the row to append to
+        * @param i the indexes to append
+        * @param v the values to append
+        * @param size the number of valid entries in i and v
+        */
+       private static void add(SparseRow row, int[] i, double[] v, int size){
+               // sort based on the i array.
+               sort(i,v, size);
+               for(int x = 0; x < size; x++){
+                       row.append(i[x], v[x]);
+               }
        }
 
 
+
        /**
-        * Add to sparse row assuming that most of the time we would append to the end of the sparse row.
+        * Use sorting networks for small arrays.
+        * Note that "small" here means fewer than 32 elements.
         * 
-        * @param row row to add to.
-        * @param index the index in the row to add to
-        * @param v the value to add.
+        * The basic idea is network sorting, which uses the theoretically
+        * fewest compare-and-swap operations possible for a given array size.
+        * 
+        * @param i the indexes to sort by
+        * @param v the values to sort alongside
+        * @param size the number of valid entries in i and v
         */
-       private static void add(SparseRow row, int index, double v){
-               final int size = row.size();
+       private static void sort(int[] i , double[] v, int size){
+               if(size > 32)
+                       LOG.warn("Not an optimal size for small array sort " + size);
+               switch (size) {
+                       case 1: break;
+                       case 2: comp(i,v,0,1); break;
+                       case 3: sort3(i,v); break;
+                       case 4: sort4(i,v); break;
+                       case 5: sort5(i,v); break;
+                       case 6: sort6(i,v); break;
+                       case 7: sort7(i,v); break;
+                       default:
+                               // Most cases are handled by the sorting networks for
+                               // smaller arrays, but just in case we have an insertion
+                               // sort here. Since the array is already semi-sorted,
+                               // this is okay, but not ideal once we see larger arrays.
+                               // Larger arrays only occur if the input data allow many
+                               // kernels in the horizontal dimension.
+                               insertSort(i,v, size);
+                               break;
+               }
+       }
+
+       private static void sort3(int[] i, double[] v){
+               // 3 moves
+               comp(i,v,0,2);
+               comp(i,v,0,1);
+               comp(i,v,1,2);
+       }
+
+       private static void sort4(int[] i, double[] v){
+               // 5 moves
+               // block 1
+               comp(i,v,0,2);
+               comp(i,v,1,3);
+               // block 2
+               comp(i,v,0,1);
+               comp(i,v,2,3);
+               // block 3
+               comp(i,v,1,2);
+       }
+
+       private static void sort5(int[] i, double[] v){
+               // 9 moves
+               // block 1
+               comp(i,v,0,1);
+               comp(i,v,2,3);
+               // block 2
+               comp(i,v,1,3);
+               comp(i,v,2,4);
+               // block 3
+               comp(i,v,1,4);
+               comp(i,v,0,2);
+               // block 4
+               comp(i,v,1,2);
+               comp(i,v,3,4);
+               // block 5
+               comp(i,v,2,3);
+       }
+
+       private static void sort6(int[] i, double[] v){
+               // 12 moves
+               // block 1
+               comp(i,v,0,1);
+               comp(i,v,2,3);
+               comp(i,v,4,5);
+               // block 2
+               comp(i,v,1,3);
+               // block 3
+               comp(i,v,0,4);
+               // block 4
+               comp(i,v,1,3);
+               // block 5
+               comp(i,v,1,5);
+               // block 6
+               comp(i,v,2,4);
+               // block 7
+               comp(i,v,1,2);
+               comp(i,v,3,5);
+               // block 8
+               comp(i,v,3,4);
+               // block 9
+               comp(i,v,2,3);
+       }
+
+       private static void sort7(int[] i, double[] v){
+               // 16 moves.
+               // block 1
+               comp(i,v,0,1);
+               comp(i,v,2,3);
+               comp(i,v,4,5);
+               // block 2
+               comp(i,v,0,6);
+               // block 3
+               comp(i,v,2,4);
+               // block 4
+               comp(i,v,0,2);
+               // block 5
+               comp(i,v,1,3);
+               comp(i,v,5,6);
+               // block 6
+               comp(i,v,1,4);
+               // block 7
+               comp(i,v,2,5);
+               // block 8
+               comp(i,v,1,2);
+               comp(i,v,4,5);
+               // block 9
+               comp(i,v,2,4);
+               // block 10
+               comp(i,v,3,6);
+               // block 11
+               comp(i,v,3,5);
+               // block 12
+               comp(i,v,3,4);
+       }
+
+       private static void insertSort(int[] i, double[] v, int size){
+               int p, k, j;
+               double t;
+               for(p  = 1; p < size; p++){
+                       k = i[p];
+                       t = v[p];
+                       j = p -1;
+                       while(j >= 0 && i[j] > k){
+                               i[j+1] = i[j];
+                               v[j+1] = v[j];
+                               j = j-1;
+                       }
+                       i[j+1] = k;
+                       v[j+1] = t;
+               }
+       }
+
+       private static void comp(int[] i , double[] v, int f, int t){
+               if(i[f] > i[t])
+                       swap(i,v,f,t);
+       }
 
-               if(size <= 1)
-                       row.add(index, v);
-               else if( row.indexes()[size-1] < index)
-                       row.append(index, v);
-               else
-                       row.add(index, v);
+       private static void swap(int[] i , double[] v, int f, int t){
+               int tmpI = i[f];
+               double tmpV = v[f];
+               i[f] = i[t];
+               v[f] = v[t];
+               i[t] = tmpI;
+               v[t] = tmpV; 
        }
 }
+
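
For reference, the hand-written sorting networks above apply a fixed,
data-independent sequence of compare-and-swap operations, so any input of a
given size goes through exactly the same comparisons. A standalone demo of the
size-3 network (illustrative only, with made-up data):

    import java.util.Arrays;

    public class SortingNetworkDemo {
        public static void main(String[] args) {
            int[] i = { 9, 1, 5 };
            double[] v = { 0.9, 0.1, 0.5 };
            comp(i, v, 0, 2); comp(i, v, 0, 1); comp(i, v, 1, 2); // the sort3 network
            System.out.println(Arrays.toString(i)); // [1, 5, 9]
            System.out.println(Arrays.toString(v)); // [0.1, 0.5, 0.9]
        }
        // compare the indexes at f and t, swapping both arrays to keep them aligned
        static void comp(int[] i, double[] v, int f, int t) {
            if (i[f] > i[t]) {
                int ti = i[f]; i[f] = i[t]; i[t] = ti;
                double tv = v[f]; v[f] = v[t]; v[t] = tv;
            }
        }
    }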
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
index dfdacb8..0be02b6 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNMaxPool2dComponentTest.java
@@ -48,7 +48,7 @@ public class NNMaxPool2dComponentTest extends BaseTest {
        @Parameterized.Parameter(1)
        public int w;
 
-       final static String[] argNames =  new String[] {"$h", "$w"};
+       final static String[] argNames = new String[] {"$h", "$w"};
 
        @Test
        public void max_pool2d_padh_padw() {
diff --git a/src/test/scripts/applications/nn/component/max_pool2d.dml b/src/test/scripts/applications/nn/component/max_pool2d.dml
index c13bb91..0ec075a 100644
--- a/src/test/scripts/applications/nn/component/max_pool2d.dml
+++ b/src/test/scripts/applications/nn/component/max_pool2d.dml
@@ -76,7 +76,9 @@ max_pool2d_pad = function(Integer h, Integer w) {
                                           Hf, Wf, stride, stride, padh, padw)
  dX_builtin = max_pool2d_builtin::backward(dout, Hout_builtin, Wout_builtin, X, C, Hin, Win,
                                             Hf, Wf, stride, stride, padh, padw)
-
+  print(toString(dX))
+  print(toString(dX_simple))
+  print(toString(dX_builtin))
   # Equivalency check
   dX = matrix(dX, rows=1, cols=N*C*Hin*Win)
   dX_simple = matrix(dX_simple, rows=1, cols=N*C*Hin*Win)
