[SYSTEMML-2067] Codegen support for im2col/conv2d DNN operations

This patch adds codegen support in row templates for DNN conv2d
operations. Specifically, we generate row-wise im2col and conv2d-mm
operations, which allows for CSE of im2col if multiple conv2d operations
are fused into the same row-wise operator.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/fce9d978
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/fce9d978
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/fce9d978

Branch: refs/heads/master
Commit: fce9d978d00bf5d42d6be9a475dee346e1844e00
Parents: b3b50c0
Author: Matthias Boehm <[email protected]>
Authored: Thu Aug 2 22:16:23 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Thu Aug 2 22:54:08 2018 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/DnnOp.java  |  26 ++---
 .../sysml/hops/codegen/cplan/CNodeNary.java     |  95 +++++++++++++---
 .../hops/codegen/template/TemplateRow.java      |  23 ++--
 .../runtime/codegen/LibSpoofPrimitives.java     |  38 ++++++-
 .../runtime/matrix/data/LibMatrixDNNIm2Col.java |   7 +-
 .../runtime/matrix/data/LibMatrixMult.java      |   7 +-
 .../functions/codegen/RowAggTmplTest.java       |  29 ++++-
 .../codegen/RowConv2DOperationsTest.java        |   5 +-
 .../scripts/functions/codegen/rowAggPattern46.R | 107 +++++++++++++++++++
 .../functions/codegen/rowAggPattern46.dml       |  43 ++++++++
 10 files changed, 326 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java 
b/src/main/java/org/apache/sysml/hops/DnnOp.java
index 3b48371..4e22f59 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -601,15 +601,13 @@ public class DnnOp extends MultiThreadedHop
                                || op == OpOpDnn.CONV2D 
                                || op == OpOpDnn.CONV2D_BACKWARD_FILTER
                                || op == OpOpDnn.CONV2D_BACKWARD_DATA) {
-                       imageHeightHop = getInput().get(8);
-                       filterHeightHop = getInput().get(12);
                        _cachedParams.setIfUnknown(
                                        getInput().get(6),  // N
                                        getInput().get(7),  // C
-                                       imageHeightHop,     // H
+                                       getInput().get(8),  // H
                                        getInput().get(9),  // W
                                        getInput().get(10), // K
-                                       filterHeightHop,    // R
+                                       getInput().get(12), // R
                                        getInput().get(13), // S
                                        getInput().get(2),  // stride_h
                                        getInput().get(3),  // stride_w
@@ -617,19 +615,17 @@ public class DnnOp extends MultiThreadedHop
                                        getInput().get(5), _maxNumThreads);
                }
                else {
-                       imageHeightHop = getInput().get(7);
-                       filterHeightHop = getInput().get(11);
                        _cachedParams.setIfUnknown(
                                        getInput().get(5),
-                                       getInput().get(6), 
-                                       imageHeightHop, 
-                                       getInput().get(8), 
-                                       getInput().get(9), 
-                                       filterHeightHop, 
-                                       getInput().get(12), 
-                                       getInput().get(1), 
-                                       getInput().get(2), 
-                                       getInput().get(3), 
+                                       getInput().get(6),
+                                       getInput().get(7),
+                                       getInput().get(8),
+                                       getInput().get(9),
+                                       getInput().get(11),
+                                       getInput().get(12),
+                                       getInput().get(1),
+                                       getInput().get(2),
+                                       getInput().get(3),
                                        getInput().get(4), _maxNumThreads);
                }
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
index e720601..28e47f4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -33,7 +33,9 @@ public class CNodeNary extends CNode
        public enum NaryType {
                VECT_CBIND,
                VECT_MAX_POOL,
-               VECT_AVG_POOL;
+               VECT_AVG_POOL,
+               VECT_IM2COL,
+               VECT_CONV2DMM;
                
                public static boolean contains(String value) {
                        for( NaryType bt : values() )
@@ -63,18 +65,30 @@ public class CNodeNary extends CNode
                                        }
                                        return sb.toString();
                                case VECT_MAX_POOL:
-                               case VECT_AVG_POOL:
+                               case VECT_AVG_POOL: {
                                        String vectName = (this==VECT_MAX_POOL) 
? "Maxpool" : "Avgpool";
-                                       String paramStr = 
getPoolingParameterString(inputs);
+                                       String paramStr = 
getDnnParameterString(inputs, true);
                                        return sparseGen ?
                                                "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1v%, %IN1i%, %POS1%, alen, len, 
"+paramStr+");\n" : 
                                                "    double[] %TMP% = 
LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %POS1%, %LEN%, 
"+paramStr+");\n";
+                               }
+                               case VECT_IM2COL: {
+                                       String paramStr = 
getDnnParameterString(inputs, true);
+                                       return sparseGen ?
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vectIm2colWrite(%IN1v%, %IN1i%, %POS1%, alen, len, 
"+paramStr+");\n" : 
+                                               "    double[] %TMP% = 
LibSpoofPrimitives.vectIm2colWrite(%IN1%, %POS1%, %LEN%, "+paramStr+");\n";
+                               }
+                               case VECT_CONV2DMM: {
+                                       return "    double[] %TMP% = 
LibSpoofPrimitives.vectConv2dmmWrite(%IN2%, %IN1%, %POS2%, %POS1%, %LEN%, "
+                                               + getDnnParameterString(inputs, 
false) +");\n";
+                               }
                                default:
                                        throw new RuntimeException("Invalid 
nary type: "+this.toString());
                        }
                }
                public boolean isVectorPrimitive() {
-                       return this == VECT_CBIND || this == VECT_MAX_POOL || 
this == VECT_AVG_POOL;
+                       return this == VECT_CBIND || this == VECT_MAX_POOL || 
this == VECT_AVG_POOL
+                               || this == VECT_IM2COL || this == 
NaryType.VECT_CONV2DMM;
                }
        }
        
@@ -111,8 +125,11 @@ public class CNodeNary extends CNode
                tmp = tmp.replace("%TMP%", var);
                
                //replace sparse and dense inputs
-               String varj = _inputs.get(0).getVarname();
-               tmp = replaceUnaryPlaceholders(tmp, varj, false);
+               String varj1 = _inputs.get(0).getVarname();
+               String varj2 = _inputs.get(1).getVarname();
+               tmp = (_type == NaryType.VECT_CONV2DMM) ?
+                       replaceBinaryPlaceholders(tmp, new 
String[]{varj1,varj2}, false) :
+                       replaceUnaryPlaceholders(tmp, varj1, false);
                
                sb.append(tmp);
                
@@ -128,6 +145,8 @@ public class CNodeNary extends CNode
                        case VECT_CBIND:    return "n(cbind)";
                        case VECT_MAX_POOL: return "n(maxpool)";
                        case VECT_AVG_POOL: return "n(avgpool)";
+                       case VECT_IM2COL:   return "n(im2col)";
+                       case VECT_CONV2DMM: return "n(conv2dmm)";
                        default:
                                return "m("+_type.name().toLowerCase()+")";
                }
@@ -144,7 +163,7 @@ public class CNodeNary extends CNode
                                _dataType = DataType.MATRIX;
                                break;
                        case VECT_MAX_POOL:
-                       case VECT_AVG_POOL: //only stride 1, pad 0
+                       case VECT_AVG_POOL: { //only stride 1, pad 0
                                int C = 
Integer.parseInt(_inputs.get(6).getVarname());
                                int H = 
Integer.parseInt(_inputs.get(7).getVarname());
                                int W = 
Integer.parseInt(_inputs.get(8).getVarname());
@@ -152,10 +171,28 @@ public class CNodeNary extends CNode
                                int S = 
Integer.parseInt(_inputs.get(12).getVarname());
                                long P = DnnUtils.getP(H, R, 1, 0);
                                long Q = DnnUtils.getQ(W, S, 1, 0);
-                               _rows = _inputs.get(0)._rows;
+                               _rows = _inputs.get(0)._rows; //N
                                _cols =  C * P * Q;
                                _dataType = DataType.MATRIX;
                                break;
+                       }
+                       case VECT_IM2COL:
+                               _rows = 1;
+                               _cols = -1;
+                               _dataType = DataType.MATRIX;
+                               break;
+                       case VECT_CONV2DMM: {
+                               int H = 
Integer.parseInt(_inputs.get(8).getVarname());
+                               int W = 
Integer.parseInt(_inputs.get(9).getVarname());
+                               int K = 
Integer.parseInt(_inputs.get(10).getVarname());
+                               int R = 
Integer.parseInt(_inputs.get(12).getVarname());
+                               int S = 
Integer.parseInt(_inputs.get(13).getVarname());
+                               long P = DnnUtils.getP(H, R, 1, 0);
+                               long Q = DnnUtils.getQ(W, S, 1, 0);
+                               _rows = _inputs.get(0)._rows; //N
+                               _cols = K * P * Q;
+                               _dataType = DataType.MATRIX;
+                       }
                }
        }
        
@@ -178,18 +215,46 @@ public class CNodeNary extends CNode
                        && _type == that._type;
        }
        
-       private static String getPoolingParameterString(List<CNode> inputs) {
+       private static String getDnnParameterString(List<CNode> inputs, boolean 
unary) {
+               int off = unary ? 0 : 1;
+               
                //extract and derive individual parameters
-               int C = Integer.parseInt(inputs.get(6).getVarname());
-               int H = Integer.parseInt(inputs.get(7).getVarname());
-               int W = Integer.parseInt(inputs.get(8).getVarname());
-               int R = Integer.parseInt(inputs.get(11).getVarname());
-               int S = Integer.parseInt(inputs.get(12).getVarname());
+               int C = Integer.parseInt(inputs.get(off+6).getVarname());
+               int H = Integer.parseInt(inputs.get(off+7).getVarname());
+               int W = Integer.parseInt(inputs.get(off+8).getVarname());
+               int K = Integer.parseInt(inputs.get(off+9).getVarname());
+               int R = Integer.parseInt(inputs.get(off+11).getVarname());
+               int S = Integer.parseInt(inputs.get(off+12).getVarname());
                int P = (int) DnnUtils.getP(H, R, 1, 0);
                int Q = (int) DnnUtils.getQ(W, S, 1, 0);
                
                //construct parameter string
                return "rix, " + StringUtils.join(
-                       new int[]{C, P, Q, R, S, H, W}, ',');
+                       new int[]{C, P, Q, K, R, S, H, W}, ',');
+       }
+       
+
+       private String replaceBinaryPlaceholders(String tmp, String[] vars, 
boolean vectIn) {
+               //replace sparse and dense inputs
+               for( int j=0; j<2; j++ ) {
+                       String varj = vars[j];
+                       
+                       //replace sparse and dense inputs
+                       tmp = tmp.replace("%IN"+(j+1)+"v%", varj+"vals");
+                       tmp = tmp.replace("%IN"+(j+1)+"i%", varj+"ix");
+                       tmp = tmp.replace("%IN"+(j+1)+"%", 
+                               varj.startsWith("b") ? varj + ".values(rix)" : 
varj );
+                       
+                       //replace start position of main input
+                       tmp = tmp.replace("%POS"+(j+1)+"%", (_inputs.get(j) 
instanceof CNodeData 
+                               && _inputs.get(j).getDataType().isMatrix()) ? 
!varj.startsWith("b") ? varj+"i" : 
+                               (TemplateUtils.isMatrix(_inputs.get(j)) && 
_type!=NaryType.VECT_CONV2DMM) ? varj + ".pos(rix)" : "0" : "0");
+               }
+               
+               //replace length
+               if( _inputs.get(0).getDataType().isMatrix() )
+                       tmp = tmp.replace("%LEN%", 
_inputs.get(0).getVectorLength());
+               
+               return tmp;
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 9885909..8f7ea1a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -116,8 +116,9 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, 
OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown()
                                && hop.getInput().get(0).getDim2()>1)
-                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL)
-                               && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0());
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL, OpOpDnn.CONV2D)
+                               && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0()
+                               && hop.getInput().get(1).dimsKnown()); //for 
conv2d
        }
 
        @Override
@@ -142,8 +143,9 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, 
OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown()
                                && hop.getInput().get(0).getDim2()>1)
-                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL)
-                                       && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0())
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL, OpOpDnn.CONV2D)
+                               && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0()
+                               && hop.getInput().get(1).dimsKnown() && 
hop.getInput().get(1)!=input) //for conv2d
                        || isPartOfValidCumAggChain(hop) //cum* with transpose
                        || isPartOfValidTransposeMMChain(hop)); //t(f(X))%*%X
        }
@@ -160,8 +162,9 @@ public class TemplateRow extends TemplateBase
                        || (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, 
OpOpDnn.BIASMULT)
                                && hop.getInput().get(0).dimsKnown() && 
hop.getInput().get(1).dimsKnown()
                                && hop.getInput().get(0).getDim2()>1 )
-                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL)
-                                       && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0())
+                       || (HopRewriteUtils.isDnn(hop, OpOpDnn.MAX_POOL, 
OpOpDnn.AVG_POOL, OpOpDnn.CONV2D)
+                               && hop.getInput().get(0).dimsKnown() && 
((DnnOp)hop).isStride1Pad0()
+                               && hop.getInput().get(1).dimsKnown() && 
hop.getInput().get(1)!=input) //for conv2d
                        || (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, 
DataGenMethod.SEQ)
                                && 
HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
                        || (hop instanceof AggBinaryOp
@@ -488,6 +491,14 @@ public class TemplateRow extends TemplateBase
                        out = new CNodeNary(in, CNodeNary.NaryType
                                .valueOf("VECT_"+((DnnOp)hop).getOp().name()));
                }
+               else if( HopRewriteUtils.isDnn(hop, OpOpDnn.CONV2D) ) {
+                       CNode[] in1 = hop.getInput().stream().filter(h -> 
h!=hop.getInput().get(1))
+                               .map(h 
->tmp.get(h.getHopID())).toArray(CNode[]::new);
+                       CNode im2col = new CNodeNary(in1, 
CNodeNary.NaryType.VECT_IM2COL);
+                       CNode[] in2 = hop.getInput().stream().map(h -> 
(h==hop.getInput().get(0)) ?
+                               im2col : 
tmp.get(h.getHopID())).toArray(CNode[]::new);
+                       out = new CNodeNary(in2, 
CNodeNary.NaryType.VECT_CONV2DMM);
+               }
                else if( hop instanceof NaryOp ) {
                        CNode[] inputs = new CNode[hop.getInput().size()];
                        for( int i=0; i<hop.getInput().size(); i++ ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index c1460ce..e0c1c57 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -25,10 +25,12 @@ import org.apache.commons.math3.util.FastMath;
 import org.apache.sysml.runtime.functionobjects.BitwAnd;
 import org.apache.sysml.runtime.functionobjects.IntegerDivide;
 import org.apache.sysml.runtime.functionobjects.Modulus;
+import org.apache.sysml.runtime.matrix.data.DenseBlockDRB;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPooling;
 import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysml.runtime.matrix.data.LibMatrixDNN.PoolingType;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2Col;
 
 /**
  * This library contains all vector primitives that are used in 
@@ -2057,14 +2059,14 @@ public class LibSpoofPrimitives
        
        //maxpool
        
-       public static double[] vectMaxpoolWrite(double[] a, int ai, int len, 
int rix, int C, int P, int Q, int R, int S, int H, int W) {
+       public static double[] vectMaxpoolWrite(double[] a, int ai, int len, 
int rix, int C, int P, int Q, int K, int R, int S, int H, int W) {
                double[] c = allocVector(C*P*Q, true);
                LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.MAX,
                        -Double.MAX_VALUE, 1, a, c, rix, rix+1, ai, 0, C, P, Q, 
R, S, H, W);
                return c;
        } 
        
-       public static double[] vectMaxpoolWrite(double[] avals, int[] aix, int 
ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int 
W) {
+       public static double[] vectMaxpoolWrite(double[] avals, int[] aix, int 
ai, int alen, int len, int rix, int C, int P, int Q, int K, int R, int S, int 
H, int W) {
                double[] a = allocVector(len, true);
                double[] c = allocVector(C*P*Q, true);
                for(int k=ai; k<ai+alen; k++)
@@ -2076,14 +2078,14 @@ public class LibSpoofPrimitives
        
        //avgpool
 
-       public static double[] vectAvgpoolWrite(double[] a, int ai, int len, 
int rix, int C, int P, int Q, int R, int S, int H, int W) {
+       public static double[] vectAvgpoolWrite(double[] a, int ai, int len, 
int rix, int C, int P, int Q, int K, int R, int S, int H, int W) {
                double[] c = allocVector(C*P*Q, true);
                LibMatrixDNNPooling.poolingDenseStride1Pad0(PoolingType.AVG,
                        0, 1/(R*S), a, c, rix, rix+1, ai, 0, C, P, Q, R, S, H, 
W);
                return c;
        } 
        
-       public static double[] vectAvgpoolWrite(double[] avals, int[] aix, int 
ai, int alen, int len, int rix, int C, int P, int Q, int R, int S, int H, int 
W) {
+       public static double[] vectAvgpoolWrite(double[] avals, int[] aix, int 
ai, int alen, int len, int rix, int C, int P, int Q, int K, int R, int S, int 
H, int W) {
                double[] a = allocVector(len, true);
                double[] c = allocVector(C*P*Q, true);
                for(int k=ai; k<ai+alen; k++)
@@ -2093,6 +2095,34 @@ public class LibSpoofPrimitives
                return c;
        }
        
+       //im2col
+       
+       public static double[] vectIm2colWrite(double[] a, int ai, int len, int 
rix, int C, int P, int Q, int K, int R, int S, int H, int W) {
+               double[] c = allocVector(C*R*S * P*Q, true);
+               LibMatrixDNNIm2Col.im2colDenseStride1Pad0(a, c, ai, C, R, S, H, 
W, P, Q);
+               return c;
+       }
+       
+       public static double[] vectIm2colWrite(double[] avals, int[] aix, int 
ai, int alen, int len, int rix, int C, int P, int Q, int K, int R, int S, int 
H, int W) {
+               double[] a = allocVector(len, true);
+               double[] c = allocVector(C*R*S * P*Q, true);
+               for(int k=ai; k<ai+alen; k++)
+                       a[aix[k]] = avals[k];
+               LibMatrixDNNIm2Col.im2colDenseStride1Pad0(a, c, ai, C, R, S, H, 
W, P, Q);
+               return c;
+       }
+       
+       //conv2d matrix mult
+       
+       public static double[] vectConv2dmmWrite(double[] a, double[] b, int 
ai, int bi, int len, int rix, int C, int P, int Q, int K, int R, int S, int H, 
int W) {
+               double[] c = allocVector(K*P*Q, true);
+               int CRS = C*R*S, PQ = P*Q;
+               LibMatrixMult.matrixMultDenseDenseMM(
+                       new DenseBlockDRB(a, K, CRS), new DenseBlockDRB(b, CRS, 
PQ),
+                       new DenseBlockDRB(c, K, PQ), PQ, CRS, 0, K, 0, PQ);
+               return c;
+       } 
+       
        //complex builtin functions that are not directly generated
        //(included here in order to reduce the number of imports)
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
index d344c93..dd2473e 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNNIm2Col.java
@@ -41,7 +41,7 @@ public class LibMatrixDNNIm2Col
                //dense and sparse operation dispatch
                if( !in.sparse && stride1Pad0 && !trans )
                        im2colDenseStride1Pad0(in.getDenseBlockValues(),
-                               out.getDenseBlockValues(), r, C, R, S, H, W, P, 
Q);
+                               out.getDenseBlockValues(), r*C*H*W, C, R, S, H, 
W, P, Q);
                else if( !in.sparse )
                        im2colDense(in.getDenseBlockValues(), 
out.getDenseBlockValues(),
                                r, C, R, S, H, W, P, Q, stride_h, stride_w, 
pad_h, pad_w, trans);
@@ -50,8 +50,7 @@ public class LibMatrixDNNIm2Col
                                stride_h, stride_w, pad_h, pad_w, trans);
        }
        
-       public static void im2colDenseStride1Pad0(double[] in, double[] out, 
int r, int C, int R, int S, int H, int W, int P, int Q) {
-               int nOffset = r * C * H * W;
+       public static void im2colDenseStride1Pad0(double[] in, double[] out, 
int ai, int C, int R, int S, int H, int W, int P, int Q) {
                int CRS = C * R * S;
                for (int c = 0; c < CRS; ++c) {
                        int wOffset = c % S;
@@ -60,7 +59,7 @@ public class LibMatrixDNNIm2Col
                        for (int h = 0; h < P; ++h) {
                                int hPadded = h + hOffset;
                                int outOffset = (c * P + h) * Q;
-                               int inputOffset = nOffset + (cInput * H + 
hPadded) * W;
+                               int inputOffset = ai + (cInput * H + hPadded) * 
W;
                                System.arraycopy(in, inputOffset + wOffset, 
out, outOffset, Q);
                                int w = Q - 1;
                                int wPadded = w + wOffset;

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 8f60003..ad3e3b2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -1046,9 +1046,10 @@ public class LibMatrixMult
                }
        }
        
-       private static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
+       //note: public for use by codegen for consistency
+       public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, 
DenseBlock c, int n, int cd, int rl, int ru, int cl, int cu) {
                //1) Unrolled inner loop (for better instruction-level 
parallelism)
-               //2) Blocked execution (for less cache trashing in parallel 
exec)       
+               //2) Blocked execution (for less cache trashing in parallel 
exec) 
                //3) Asymmetric block sizes (for less misses in inner loop, yet 
blocks in L1/L2)
                
                final int blocksizeI = 32; //64//256KB c block (typical L2 size 
per core), 32KB a block 
@@ -3093,7 +3094,7 @@ public class LibMatrixMult
        {
                double val = 0;
                final int bn = len%8;
-                               
+               
                //compute rest
                for( int i = 0; i < bn; i++, ai++, bi++ )
                        val += a[ ai ] * b[ bi ];

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index d0a417d..2360092 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -82,6 +82,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME43 = TEST_NAME+"43"; 
//bias_add(X,B) + bias_mult(X,B)
        private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - 
mean(X)) + 7;
        private static final String TEST_NAME45 = TEST_NAME+"45"; //vector 
allocation;
+       private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X - 
mean(X), F1) + conv2d(X - mean(X), F2);
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -93,7 +94,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=45; i++)
+               for(int i=1; i<=46; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -771,6 +772,21 @@ public class RowAggTmplTest extends AutomatedTestBase
        public void testCodegenRowAgg45SP() {
                testCodegenIntegration( TEST_NAME45, false, ExecType.SPARK );
        }
+       
+       @Test
+       public void testCodegenRowAggRewrite46CP() {
+               testCodegenIntegration( TEST_NAME46, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg46CP() {
+               testCodegenIntegration( TEST_NAME46, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg46SP() {
+               testCodegenIntegration( TEST_NAME46, false, ExecType.SPARK );
+       }
 
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {
@@ -793,17 +809,17 @@ public class RowAggTmplTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{"-explain", 
"recompile_runtime", "-stats", "-args", output("S") };
+                       programArgs = new String[]{"-explain", "-stats", 
"-args", output("S") };
                        
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = getRCmd(inputDir(), expectedDir());
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
                        
-                       runTest(true, false, null, -1); 
-                       runRScript(true); 
+                       runTest(true, false, null, -1);
+                       runRScript(true);
                        
-                       //compare matrices 
+                       //compare matrices
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("S");
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("S");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
@@ -836,6 +852,9 @@ public class RowAggTmplTest extends AutomatedTestBase
                        if( testname.equals(TEST_NAME44) )
                                
Assert.assertTrue(!heavyHittersContainsSubString("maxpooling") 
                                        && 
!heavyHittersContainsSubString("spoof", 2));
+                       if( testname.equals(TEST_NAME46) )
+                               
Assert.assertTrue(!heavyHittersContainsSubString("conv2d") 
+                                       && 
!heavyHittersContainsSubString("spoof", 2));
                        
                }
                finally {

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowConv2DOperationsTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowConv2DOperationsTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowConv2DOperationsTest.java
index 6910c03..46e10a3 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowConv2DOperationsTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowConv2DOperationsTest.java
@@ -29,6 +29,7 @@ import 
org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
 import org.junit.Test;
 
 public class RowConv2DOperationsTest extends AutomatedTestBase
@@ -111,8 +112,8 @@ public class RowConv2DOperationsTest extends 
AutomatedTestBase
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("B");
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("B");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
-                       
//Assert.assertTrue(heavyHittersContainsSubString("spoofRA") 
-                       //      || heavyHittersContainsSubString("sp_spoofRA"));
+                       
Assert.assertTrue(heavyHittersContainsSubString("spoofRA") 
+                               || heavyHittersContainsSubString("sp_spoofRA"));
                }
                finally {
                        rtplatform = platformOld;

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/test/scripts/functions/codegen/rowAggPattern46.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern46.R 
b/src/test/scripts/functions/codegen/rowAggPattern46.R
new file mode 100644
index 0000000..e40ebaa
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern46.R
@@ -0,0 +1,107 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+library("Matrix")
+library("matrixStats") 
+
+pad_image <- function(img, Hin, Win, padh, padw){
+  C = nrow(img)
+  img_padded = matrix(0, C, (Hin+2*padh)*(Win+2*padw), byrow=TRUE)  # zeros
+  for (c in 1:C) {
+    img_slice = matrix(img[c,], Hin, Win, byrow=TRUE)  # depth slice C reshaped
+    img_padded_slice = matrix(0, Hin+2*padh, Win+2*padw)
+    img_padded_slice[(padh+1):(padh+Hin), (padw+1):(padw+Win)] = img_slice
+    img_padded[c,] = matrix(t(img_padded_slice), 1, (Hin+2*padh)*(Win+2*padw)) 
 # reshape
+  }
+  img_padded
+}
+
+im2col <- function(img, Hin, Win, Hf, Wf, strideh, stridew) {
+  C = nrow(img)
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  img_cols = matrix(0, C*Hf*Wf, Hout*Wout, byrow=TRUE)  # zeros
+  for (hout in 1:Hout) {  # all output rows
+    hin = (hout-1) * strideh + 1
+    for (wout in 1:Wout) {  # all output columns
+      win = (wout-1) * stridew + 1
+      # Extract a local patch of the input image corresponding spatially to 
the filter sizes.
+      img_patch = matrix(0, C, Hf*Wf, byrow=TRUE)  # zeros
+      for (c in 1:C) {  # all channels
+        img_slice = matrix(img[c,], Hin, Win, byrow=TRUE)  # reshape
+        img_patch[c,] = matrix(t(img_slice[hin:(hin+Hf-1), win:(win+Wf-1)]), 
1, Hf*Wf)
+      }
+      img_cols[,(hout-1)*Wout + wout] = matrix(t(img_patch), C*Hf*Wf, 1)  # 
reshape
+    }
+  }
+  img_cols
+}
+               
+conv2d <- function(X, W, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw) {
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
+  
+  # Create output volume
+  out = matrix(0, N, F*Hout*Wout, byrow=TRUE)
+
+  # Convolution - im2col implementation
+  for (n in 1:N) {  # all examples
+    Xn = matrix(X[n,], C, Hin*Win, byrow=TRUE)  # reshape
+
+    # Pad image
+    Xn_padded = pad_image(Xn, Hin, Win, padh, padw)  # shape (C, 
(Hin+2*padh)*(Win+2*padw))
+
+    # Extract local image patches into columns with im2col, of shape (C*Hf*Wf, 
Hout*Wout)
+    Xn_padded_cols = im2col(Xn_padded, Hin+2*padh, Win+2*padw, Hf, Wf, 
strideh, stridew)
+
+    # Convolve patches with filters
+    outn = W %*% Xn_padded_cols   # shape (F, Hout*Wout)
+    out[n,] = matrix(t(outn), 1, F*Hout*Wout)  # reshape
+  }
+  
+  out
+}
+
+imgSize=8
+numImg=16
+numChannels=4
+numFilters=3
+filterSize=4
+stride=1
+pad=0
+
+Hout = as.integer((imgSize + 2 * pad - filterSize) / stride + 1)
+
+X = matrix(seq(1, numImg*numChannels*imgSize*imgSize), numImg, 
numChannels*imgSize*imgSize, byrow=TRUE);
+W1 = matrix(seq(1, numFilters*numChannels*filterSize*filterSize), numFilters, 
numChannels*filterSize*filterSize, byrow=TRUE)
+W2 = matrix(seq(1, numFilters*numChannels*filterSize*filterSize)+7, 
numFilters, numChannels*filterSize*filterSize, byrow=TRUE)
+b = matrix(seq(1, numFilters), numFilters, 1, byrow=TRUE) 
+
+X = X - rowMeans(X)
+
+R1 = conv2d(X, W1, numChannels, imgSize, imgSize, filterSize, filterSize, 
stride, stride, pad, pad);
+R2 = conv2d(X, W2, numChannels, imgSize, imgSize, filterSize, filterSize, 
stride, stride, pad, pad);
+R = R1 + R2;
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""))

http://git-wip-us.apache.org/repos/asf/systemml/blob/fce9d978/src/test/scripts/functions/codegen/rowAggPattern46.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern46.dml 
b/src/test/scripts/functions/codegen/rowAggPattern46.dml
new file mode 100644
index 0000000..745ff2f
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern46.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# 
+#-------------------------------------------------------------
+
+imgSize=8
+numImg=16
+numChannels=4
+numFilters=3
+filterSize=4
+stride=1
+pad=0
+
+X = matrix(seq(1, numImg*numChannels*imgSize*imgSize), rows=numImg, 
cols=numChannels*imgSize*imgSize);
+W1 = matrix(seq(1, numFilters*numChannels*filterSize*filterSize), 
rows=numFilters, cols=numChannels*filterSize*filterSize)
+W2 = matrix(seq(1, numFilters*numChannels*filterSize*filterSize)+7, 
rows=numFilters, cols=numChannels*filterSize*filterSize)
+b = matrix(seq(1, numFilters), rows=numFilters, cols=1) 
+
+while(FALSE){}
+
+X = X - rowMeans(X);
+
+R1 = conv2d(X, W1, padding=[pad, pad], stride=[stride, stride], 
input_shape=[numImg, numChannels, imgSize, imgSize], filter_shape=[numFilters, 
numChannels, filterSize, filterSize])
+R2 = conv2d(X, W2, padding=[pad, pad], stride=[stride, stride], 
input_shape=[numImg, numChannels, imgSize, imgSize], filter_shape=[numFilters, 
numChannels, filterSize, filterSize])
+R = R1 + R2;
+
+write(R, $1, format="text");

Reply via email to