Repository: systemml
Updated Branches:
  refs/heads/master 15ecb723e -> 99b1c2e25


[SYSTEMML-2109] Codegen support for maxpool/avgpool DNN operations

This patch adds code generation support for maxpool and avgpool DNN
operations (restricted to stride 1 and zero padding) to the codegen
row template. This way, entire chains of conv/maxpool/relu operations
can often be executed as fused operators without parallelization
barriers between the individual operators.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/99b1c2e2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/99b1c2e2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/99b1c2e2

Branch: refs/heads/master
Commit: 99b1c2e252f935bbce5f53574d3e245221da3e68
Parents: 15ecb72
Author: Matthias Boehm <[email protected]>
Authored: Tue Jul 24 18:32:50 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Tue Jul 24 18:57:00 2018 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  6 ++
 .../apache/sysml/hops/codegen/cplan/CNode.java  | 24 ++++++++
 .../sysml/hops/codegen/cplan/CNodeNary.java     | 58 ++++++++++++++++++--
 .../sysml/hops/codegen/cplan/CNodeUnary.java    | 22 +-------
 .../hops/codegen/template/TemplateRow.java      | 14 ++++-
 .../runtime/codegen/LibSpoofPrimitives.java     | 40 ++++++++++++++
 .../matrix/data/LibMatrixDNNPooling.java        | 16 +++---
 .../functions/codegen/RowAggTmplTest.java       |  6 +-
 .../scripts/functions/codegen/rowAggPattern44.R |  1 +
 .../functions/codegen/rowAggPattern44.dml       |  1 +
 10 files changed, 154 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 8dbbeda..3b48371 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -217,6 +217,17 @@ public class DnnOp extends MultiThreadedHop
                        isEqualAndKnown(param1.H, param2.H) && isEqualAndKnown(param1.W, param2.W);
        }
        
+       /**
+        * Checks if the parameters returned by parseInput() specify a stride
+        * of 1 and a padding of 0 in both dimensions.
+        * @return true if stride_h and stride_w are 1 and pad_h and pad_w are 0
+        */
+       public boolean isStride1Pad0() {
+               DnnParameters tmp = parseInput();
+               return tmp.stride_h == 1 && tmp.stride_w == 1
+                       && tmp.pad_h == 0 && tmp.pad_w == 0;
+       }
+       
        private static boolean isEqualAndKnown(int val1, int val2) {
                return val1 >= 0 && val2 >= 0 && val1 == val2;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
index b0efb42..6dab878 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.cplan;
 
 import java.util.ArrayList;
 
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.runtime.util.UtilFunctions;
@@ -204,4 +205,38 @@ public abstract class CNode
                        && _dataType == cthat._dataType
                        && _literal == cthat._literal;
        }
+       
+       /**
+        * Replaces the placeholders of the first input (%IN1v%, %IN1i%, %IN1%,
+        * %POS1%, %POS2%, and %LEN% if the first input is a matrix) in the
+        * given code template.
+        * 
+        * @param tmp code template
+        * @param varj variable name of the first input
+        * @param vectIn if true, %IN1% becomes varj.values(rix) or varj.values(0)
+        *        for matrix or row vector inputs, respectively; otherwise varj
+        * @return code template with the placeholders replaced
+        */
+       protected String replaceUnaryPlaceholders(String tmp, String varj, boolean vectIn) {
+               //replace sparse and dense inputs
+               tmp = tmp.replace("%IN1v%", varj+"vals");
+               tmp = tmp.replace("%IN1i%", varj+"ix");
+               tmp = tmp.replace("%IN1%", 
+                       (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" :
+                       (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj));
+               
+               //replace start position of main input
+               String spos = (_inputs.get(0) instanceof CNodeData 
+                       && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? 
+                       varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0";
+               
+               tmp = tmp.replace("%POS1%", spos);
+               tmp = tmp.replace("%POS2%", spos);
+               
+               //replace length
+               if( _inputs.get(0).getDataType().isMatrix() )
+                       tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength());
+               
+               return tmp;
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
index 7f19194..e720601 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -20,15 +20,21 @@
 package org.apache.sysml.hops.codegen.cplan;
 
 import java.util.ArrayList;
+import java.util.List;
 
+import org.apache.commons.lang3.StringUtils;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.util.DnnUtils;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class CNodeNary extends CNode
 {
        public enum NaryType {
-               VECT_CBIND;
+               VECT_CBIND,
+               VECT_MAX_POOL,
+               VECT_AVG_POOL;
+               
                public static boolean contains(String value) {
                        for( NaryType bt : values() )
                                if( bt.name().equals(value) )
@@ -56,12 +62,19 @@ public class CNodeNary extends CNode
                                                off += input._cols;
                                        }
                                        return sb.toString();
+                               case VECT_MAX_POOL:
+                               case VECT_AVG_POOL:
+                                       String vectName = (this==VECT_MAX_POOL) ? "Maxpool" : "Avgpool";
+                                       String paramStr = getPoolingParameterString(inputs);
+                                       return sparseGen ?
+                                               "    double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len, "+paramStr+");\n" : 
+                                               "    double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%, "+paramStr+");\n";
                                default:
                                        throw new RuntimeException("Invalid nary type: "+this.toString());
                        }
                }
                public boolean isVectorPrimitive() {
-                       return this == VECT_CBIND;
+                       return this == VECT_CBIND || this == VECT_MAX_POOL || this == VECT_AVG_POOL;
                }
        }
        
@@ -90,10 +103,17 @@ public class CNodeNary extends CNode
                        sb.append(in.codegen(sparse));
                
                //generate nary operation (use sparse template, if data input)
+               boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData
+                       && _inputs.get(0).getVarname().startsWith("a")
+                       && !_inputs.get(0).isLiteral());
                String var = createVarname();
-               String tmp = _type.getTemplate(sparse, _cols, _inputs);
+               String tmp = _type.getTemplate(lsparse, _cols, _inputs);
                tmp = tmp.replace("%TMP%", var);
                
+               //replace sparse and dense inputs
+               String varj = _inputs.get(0).getVarname();
+               tmp = replaceUnaryPlaceholders(tmp, varj, false);
+               
                sb.append(tmp);
                
                //mark as generated
@@ -105,7 +125,9 @@ public class CNodeNary extends CNode
        @Override
        public String toString() {
                switch(_type) {
-                       case VECT_CBIND: return "n(cbind)";
+                       case VECT_CBIND:    return "n(cbind)";
+                       case VECT_MAX_POOL: return "n(maxpool)";
+                       case VECT_AVG_POOL: return "n(avgpool)";
                        default:
                                return "m("+_type.name().toLowerCase()+")";
                }
@@ -121,6 +143,19 @@ public class CNodeNary extends CNode
                                        _cols += in._cols;
                                _dataType = DataType.MATRIX;
                                break;
+                       case VECT_MAX_POOL:
+                       case VECT_AVG_POOL: //only stride 1, pad 0
+                               int C = Integer.parseInt(_inputs.get(6).getVarname());
+                               int H = Integer.parseInt(_inputs.get(7).getVarname());
+                               int W = Integer.parseInt(_inputs.get(8).getVarname());
+                               int R = Integer.parseInt(_inputs.get(11).getVarname());
+                               int S = Integer.parseInt(_inputs.get(12).getVarname());
+                               long P = DnnUtils.getP(H, R, 1, 0);
+                               long Q = DnnUtils.getQ(W, S, 1, 0);
+                               _rows = _inputs.get(0)._rows;
+                               _cols =  C * P * Q;
+                               _dataType = DataType.MATRIX;
+                               break;
                }
        }
        
@@ -142,4 +177,25 @@ public class CNodeNary extends CNode
                return super.equals(that)
                        && _type == that._type;
        }
+       
+       /**
+        * Creates the trailing arguments "rix, C,P,Q,R,S,H,W" of the generated
+        * vectMaxpoolWrite/vectAvgpoolWrite calls, with P and Q derived for
+        * stride 1 and pad 0. Inputs 6, 7, 8, 11, and 12 (C, H, W, R, S) must be
+        * integer literals, otherwise Integer.parseInt throws.
+        */
+       private static String getPoolingParameterString(List<CNode> inputs) {
+               //extract and derive individual parameters
+               int C = Integer.parseInt(inputs.get(6).getVarname());
+               int H = Integer.parseInt(inputs.get(7).getVarname());
+               int W = Integer.parseInt(inputs.get(8).getVarname());
+               int R = Integer.parseInt(inputs.get(11).getVarname());
+               int S = Integer.parseInt(inputs.get(12).getVarname());
+               int P = (int) DnnUtils.getP(H, R, 1, 0);
+               int Q = (int) DnnUtils.getQ(W, S, 1, 0);
+               
+               //construct parameter string
+               return "rix, " + StringUtils.join(
+                       new int[]{C, P, Q, R, S, H, W}, ',');
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index ba41fad..21f7fe7 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -23,7 +23,6 @@ import java.util.Arrays;
 
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang.StringUtils;
-import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -214,27 +213,10 @@ public class CNodeUnary extends CNode
                String tmp = _type.getTemplate(lsparse);
                tmp = tmp.replace("%TMP%", var);
                
-               String varj = _inputs.get(0).getVarname();
-               
                //replace sparse and dense inputs
+               String varj = _inputs.get(0).getVarname();
                boolean vectIn = varj.startsWith("b") && !_type.isScalarLookup();
-               tmp = tmp.replace("%IN1v%", varj+"vals");
-               tmp = tmp.replace("%IN1i%", varj+"ix");
-               tmp = tmp.replace("%IN1%", 
-                       (vectIn && TemplateUtils.isMatrix(_inputs.get(0))) ? varj + ".values(rix)" :
-                       (vectIn && TemplateUtils.isRowVector(_inputs.get(0)) ? varj + ".values(0)" : varj));
-               
-               //replace start position of main input
-               String spos = (_inputs.get(0) instanceof CNodeData 
-                       && _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? 
-                       varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".pos(rix)" : "0" : "0";
-               
-               tmp = tmp.replace("%POS1%", spos);
-               tmp = tmp.replace("%POS2%", spos);
-               
-               //replace length
-               if( _inputs.get(0).getDataType().isMatrix() )
-                       tmp = tmp.replace("%LEN%", _inputs.get(0).getVectorLength());
+               tmp = replaceUnaryPlaceholders(tmp, varj, vectIn);
                
                sb.append(tmp);
                
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 9df67d0..9885909 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -115,7 +115,9 @@ public class TemplateRow extends TemplateBase
                                && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop))
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
-                               && hop.getInput().get(0).getDim2()>1);
+                               && hop.getInput().get(0).getDim2()>1)
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+                               && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0());
        }
 
        @Override
@@ -140,6 +142,8 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
                                && hop.getInput().get(0).getDim2()>1)
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+                                       && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0())
                        || isPartOfValidCumAggChain(hop) //cum* with transpose
                        || isPartOfValidTransposeMMChain(hop)); //t(f(X))%*%X
        }
@@ -156,6 +160,8 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
                                && hop.getInput().get(0).getDim2()>1 )
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL)
+                                       && hop.getInput().get(0).dimsKnown() && ((DnnOp)hop).isStride1Pad0())
                        || (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, DataGenMethod.SEQ)
                                && HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
                        || (hop instanceof AggBinaryOp
@@ -476,6 +482,12 @@ public class TemplateRow extends TemplateBase
                        out = new CNodeBinary(cdata1, cdata2,
                                BinType.valueOf("VECT_"+((DnnOp)hop).getOp().name()));
                }
+               else if( HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, OpOpDnn.AVG_POOL) ) {
+                       CNode[] in = hop.getInput().stream().map(h ->
+                               tmp.get(h.getHopID())).toArray(CNode[]::new);
+                       out = new CNodeNary(in, CNodeNary.NaryType
+                               .valueOf("VECT_"+((DnnOp)hop).getOp().name()));
+               }
                else if( hop instanceof NaryOp ) {
                        CNode[] inputs = new CNode[hop.getInput().size()];
                        for( int i=0; i<hop.getInput().size(); i++ ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index fc0c1d2..c1460ce 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -26,7 +26,9 @@ import org.apache.sysml.runtime.functionobjects.BitwAnd;
 import org.apache.sysml.runtime.functionobjects.IntegerDivide;
 import org.apache.sysml.runtime.functionobjects.Modulus;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling;
 import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
 
 /**
  * This library contains all vector primitives that are used in 
@@ -2052,7 +2054,45 @@ public class LibSpoofPrimitives
                LibMatrixDNN.multBias(c, b, 1, b.length, len/b.length);
                return c;
        }
+       
+       //maxpool (stride 1, pad 0), single row into a new vector of length C*P*Q
+       
+       public static double[] vectMaxpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+               double[] c = allocVector(C*P*Q, true);
+               LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX,
+                       -Double.MAX_VALUE, 1, a, c, rix, rix+1, ai, 0, C, P, Q, R, S, H, W);
+               return c;
+       } 
+       
+       public static double[] vectMaxpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+               double[] a = allocVector(len, true);
+               double[] c = allocVector(C*P*Q, true);
+               for(int k=ai; k<ai+alen; k++)
+                       a[aix[k]] = avals[k];
+               LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX,
+                       -Double.MAX_VALUE, 1, a, c, rix, rix+1, 0, 0, C, P, Q, R, S, H, W);
+               return c;
+       }
+       
+       //avgpool (stride 1, pad 0), pooling factor 1/(R*S) via floating-point division
 
+       public static double[] vectAvgpoolWrite(double[] a, int ai, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+               double[] c = allocVector(C*P*Q, true);
+               LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG,
+                       0, 1d/(R*S), a, c, rix, rix+1, ai, 0, C, P, Q, R, S, H, W);
+               return c;
+       } 
+       
+       public static double[] vectAvgpoolWrite(double[] avals, int[] aix, int ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int W) {
+               double[] a = allocVector(len, true);
+               double[] c = allocVector(C*P*Q, true);
+               for(int k=ai; k<ai+alen; k++)
+                       a[aix[k]] = avals[k];
+               LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG,
+                       0, 1d/(R*S), a, c, rix, rix+1, 0, 0, C, P, Q, R, S, H, W);
+               return c;
+       }
+       
        //complex builtin functions that are not directly generated
        //(included here in order to reduce the number of imports)
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
index 4ff8e5e..7fb33a4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNPooling.java
@@ -97,8 +97,13 @@ public class LibMatrixDNNPooling {
                return ret;
        }
        
-       public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact,
-                       double[] in, double[] out, int rl, int ru, int C, int P, int Q, int R, int S, int H, int W) {
+       /**
+        * Max or avg pooling with stride 1 and pad 0 over the dense rows [rl, ru).
+        * Row i is read from in at offset ii+(i-rl)*C*H*W and written to out
+        * at offset oi+(i-rl)*C*P*Q.
+        */
+       public static void poolingDenseStride1Pad0(PoolingType pType, double minVal, double pFact, double[] in,
+                       double[] out, int rl, int ru, int ii, int oi, int C, int P, int Q, int R, int S, int H, int W) {
                boolean max = (pType == PoolingType.MAX);
                int CHW = C * H * W;
                
@@ -106,9 +111,9 @@ public class LibMatrixDNNPooling {
                        //quick-path w/o materialized index arrays and 
                        //simplified inner loops for P = 1, Q = 1, W = 1
                        int lenh = Math.min(R,H);
-                       for(int i = rl, oix=rl*C; i < ru; i++, oix+=C)
-                               for (int c = 0, off=i*CHW; c < C; c++, off+=H) {
-                                       out[oix+c] = max ? max(minVal, in, off, lenh) :
+                       for(int i = rl; i < ru; i++, oi+=C)
+                               for (int c = 0, off=ii+(i-rl)*CHW; c < C; c++, off+=H) {
+                                       out[oi+c] = max ? max(minVal, in, off, lenh) :
                                                avg(minVal, in, off, lenh, pFact);
                                }
                }
@@ -117,7 +122,7 @@ public class LibMatrixDNNPooling {
-                       Arrays.fill(out, rl*CPQ, ru*CPQ, minVal);
+                       Arrays.fill(out, oi, oi+(ru-rl)*CPQ, minVal);
                        //quick-path w/o materialized index arrays 
                        for(int i = rl; i < ru; i++)
-                               for (int c = 0, off=i*CHW, oix=i*CPQ; c < C; c++, off+=HW)
+                               for (int c = 0, off=ii+(i-rl)*CHW, oix=oi+(i-rl)*CPQ; c < C; c++, off+=HW)
                                        for (int p = 0; p < P; p++, oix+=Q)
                                                for (int h = p; h < Math.min(p+R,H); h++)
                                                        for (int q = 0, off2=off+h*W; q < Q; q++) {
@@ -139,7 +144,7 @@ public class LibMatrixDNNPooling {
                        _rl = rl; _ru = ru;
                        _params = params;
                        _poolingType = poolingType;
-                       _poolingMultiplier = Math.pow(params.R*params.S, -1);
+                       _poolingMultiplier = 1d/(params.R*params.S);
                }
                
                @Override
@@ -157,7 +162,7 @@ public class LibMatrixDNNPooling {
                        
                        if( _params.isStride1Pad0() ) {
                                poolingDenseStride1Pad0(_poolingType, minValForMaxPoolOperations,
-                                       _poolingMultiplier, in, out, _rl, _ru, C, P, Q, R, S, H, W);
+                                       _poolingMultiplier, in, out, _rl, _ru, _rl*CHW, _rl*CPQ, C, P, Q, R, S, H, W);
                        }
                        else { //general case
                                //thread-local initialization of output block 

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 04891d0..48555ae 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -80,7 +80,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME41 = TEST_NAME+"41"; //X*rowSums(X/seq(1,N)+t(seq(M,1)))
        private static final String TEST_NAME42 = TEST_NAME+"42"; //X/rowSums(min(X, Y, Z))
        private static final String TEST_NAME43 = TEST_NAME+"43"; //bias_add(X,B) + bias_mult(X,B)
-       private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X));
+       private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - rowMeans(X)) + 7;
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -817,6 +817,10 @@ public class RowAggTmplTest extends AutomatedTestBase
                        if( testname.equals(TEST_NAME42) )
                                Assert.assertTrue(!heavyHittersContainsSubString("min","nmin") 
                                        && !heavyHittersContainsSubString("spoof", 2));
+                       if( testname.equals(TEST_NAME44) )
+                               Assert.assertTrue(!heavyHittersContainsSubString("maxpooling") 
+                                       && !heavyHittersContainsSubString("spoof", 2));
+                       
                }
                finally {
                        rtplatform = platformOld;

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.R b/src/test/scripts/functions/codegen/rowAggPattern44.R
index 99ba0b0..7269df3 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern44.R
+++ b/src/test/scripts/functions/codegen/rowAggPattern44.R
@@ -95,5 +95,6 @@ max_pool <- function(X, N, C, Hin, Win, Hf, Wf,
 }
 
 R = max_pool(X, numImg, numChannels, imgSize*imgSize, 1, poolSize1, poolSize2, stride, stride)
+R = R + 7;
 
 writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""))

http://git-wip-us.apache.org/repos/asf/systemml/blob/99b1c2e2/src/test/scripts/functions/codegen/rowAggPattern44.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern44.dml b/src/test/scripts/functions/codegen/rowAggPattern44.dml
index f236451..f5e7b6c 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern44.dml
+++ b/src/test/scripts/functions/codegen/rowAggPattern44.dml
@@ -31,5 +31,6 @@ while(FALSE){}
 
 X = X - rowMeans(X);
 R = max_pool(X, stride=[stride, stride], padding=[pad, pad], input_shape=[numImg, numChannels, imgSize*imgSize, 1], pool_size=[poolSize1, poolSize2]);
+R = R + 7;
 
 write(R, $1, format="text");

Reply via email to