Repository: systemml
Updated Branches:
  refs/heads/master 47ce14fc6 -> c1db484d6


[SYSTEMML-1933] Generalized codegen cbind handling in row-wise ops

This patch generalizes the compilation of cbind operations in codegen
row templates. Previously, we only supported cbind with a vector of zeros,
and such a cbind always closed the row template. We now support cbind
operations (with vectors of arbitrary constants) in the middle of row
templates, which also allows for multiple cbind operations in a single
fused operator.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c1db484d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c1db484d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c1db484d

Branch: refs/heads/master
Commit: c1db484d6119459f7ef6a566ff2663cca286f7ab
Parents: 47ce14f
Author: Matthias Boehm <[email protected]>
Authored: Sun Sep 24 18:29:17 2017 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sun Sep 24 18:29:17 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/DataGenOp.java   |  4 +++
 .../sysml/hops/codegen/SpoofCompiler.java       |  2 ++
 .../apache/sysml/hops/codegen/SpoofFusedOp.java | 21 ++++++++-----
 .../sysml/hops/codegen/cplan/CNodeBinary.java   | 28 ++++++++++++++---
 .../sysml/hops/codegen/cplan/CNodeRow.java      | 30 +++++++++++-------
 .../sysml/hops/codegen/cplan/CNodeUnary.java    |  6 +---
 .../hops/codegen/template/TemplateRow.java      | 15 ++++++---
 .../hops/codegen/template/TemplateUtils.java    |  6 ++--
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 10 ++++++
 .../runtime/codegen/LibSpoofPrimitives.java     | 24 +++++++++++++++
 .../sysml/runtime/codegen/SpoofRowwise.java     | 20 +++++++-----
 .../instructions/spark/SpoofSPInstruction.java  |  5 +--
 .../functions/codegen/RowAggTmplTest.java       | 20 +++++++++++-
 .../scripts/functions/codegen/rowAggPattern31.R | 32 ++++++++++++++++++++
 .../functions/codegen/rowAggPattern31.dml       | 27 +++++++++++++++++
 15 files changed, 202 insertions(+), 48 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/DataGenOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java 
b/src/main/java/org/apache/sysml/hops/DataGenOp.java
index 89a5814..eb04ed3 100644
--- a/src/main/java/org/apache/sysml/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java
@@ -434,6 +434,10 @@ public class DataGenOp extends Hop implements 
MultiThreadedHop
                return ret;
        }
        
+       public Hop getConstantValue() { //returns the input hop bound to the RAND_MIN parameter (presumably the fill value of a constant rand)
+               return 
getInput().get(_paramIndexMap.get(DataExpression.RAND_MIN)); //NOTE(review): assumes RAND_MIN is present; callers should check HopRewriteUtils.isDataGenOpWithConstantValue first
+       }
+       
        public void setIncrementValue(double incr)
        {
                _incr = incr;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index b98342c..a374cf1 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -645,6 +645,8 @@ public class SpoofCompiler
                                
HopRewriteUtils.setOutputParametersForScalar(hnew);
                                hnew = HopRewriteUtils.createUnary(hnew, 
OpOp1.CAST_AS_MATRIX);
                        }
+                       else if( tmpCNode instanceof CNodeRow && 
((CNodeRow)tmpCNode).getRowType()==RowType.NO_AGG_CONST )
+                               
((SpoofFusedOp)hnew).setConstDim2(((CNodeRow)tmpCNode).getConstDim2());
                        
                        if( !(tmpCNode instanceof CNodeMultiAgg) )
                                
HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 247a142..81b226d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -38,8 +38,8 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
 {
        public enum SpoofOutputDimsType {
                INPUT_DIMS,
+               INPUT_DIMS_CONST2,
                ROW_DIMS,
-               ROW_DIMS2,
                COLUMN_DIMS_ROWS,
                COLUMN_DIMS_COLS,
                SCALAR,
@@ -52,6 +52,7 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
        private Class<?> _class = null;
        private boolean _distSupported = false;
        private int _numThreads = -1;
+       private long _constDim2 = -1;
        private SpoofOutputDimsType _dimsType;
        
        public SpoofFusedOp ( ) {
@@ -82,6 +83,10 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
        public boolean allowsAllExecTypes() {
                return _distSupported;
        }
+       
+       public void setConstDim2(long constDim2) { //fixed number of output columns, used by the INPUT_DIMS_CONST2 output dims type
+               _constDim2 = constDim2;
+       }
 
        @Override
        protected double computeOutputMemEstimate(long dim1, long dim2, long 
nnz) {
@@ -152,9 +157,6 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                case ROW_DIMS:
                                        ret = new long[]{mc.getRows(), 1, -1};
                                        break;
-                               case ROW_DIMS2:
-                                       ret = new long[]{mc.getRows(), 2, -1};
-                                       break;
                                case COLUMN_DIMS_ROWS:
                                        ret = new long[]{mc.getCols(), 1, -1};
                                        break;
@@ -164,6 +166,9 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                case INPUT_DIMS:
                                        ret = new long[]{mc.getRows(), 
mc.getCols(), -1};
                                        break;
+                               case INPUT_DIMS_CONST2:
+                                       ret = new long[]{mc.getRows(), 
_constDim2, -1};
+                                       break;
                                case SCALAR:
                                        ret = new long[]{0, 0, -1};
                                        break;
@@ -206,10 +211,6 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                setDim1(getInput().get(0).getDim1());
                                setDim2(1);
                                break;
-                       case ROW_DIMS2:
-                               setDim1(getInput().get(0).getDim1());
-                               setDim2(2);
-                               break;
                        case COLUMN_DIMS_ROWS:
                                setDim1(getInput().get(0).getDim2());
                                setDim2(1);
@@ -222,6 +223,10 @@ public class SpoofFusedOp extends Hop implements 
MultiThreadedHop
                                setDim1(getInput().get(0).getDim1());
                                setDim2(getInput().get(0).getDim2());
                                break;
+                       case INPUT_DIMS_CONST2:
+                               setDim1(getInput().get(0).getDim1());
+                               setDim2(_constDim2);
+                               break;
                        case SCALAR:
                                setDim1(0);
                                setDim2(0);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 926dd4d..bff044d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -40,6 +40,7 @@ public class CNodeBinary extends CNode
                VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR,
                VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR, 
                VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, 
VECT_GREATEREQUAL_SCALAR,
+               VECT_CBIND,
                //vector-vector operations
                VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, 
VECT_EQUAL, 
                VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, 
VECT_GREATEREQUAL,
@@ -67,7 +68,7 @@ public class CNodeBinary extends CNode
                        return ssComm || vsComm || vvComm;
                }
                
-               public String getTemplate(boolean sparse, boolean scalarVector) 
{
+               public String getTemplate(boolean sparse, boolean scalarVector, 
boolean scalarInput) {
                        switch (this) {
                                case DOT_PRODUCT:   
                                        return sparse ? "    double %TMP% = 
LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
@@ -125,6 +126,14 @@ public class CNodeBinary extends CNode
                                                                                
"    double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, 
%POS1%, %LEN%);\n";
                                }
                                
+                               case VECT_CBIND:
+                                       if( scalarInput )
+                                               return  "    double[] %TMP% = 
LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%);\n";
+                                       else
+                                               return sparse ? 
+                                                               "    double[] 
%TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, 
%LEN%);\n" : 
+                                                               "    double[] 
%TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+                               
                                //vector-vector operations
                                case VECT_MULT:
                                case VECT_DIV:
@@ -202,7 +211,8 @@ public class CNodeBinary extends CNode
                                || this == VECT_MIN_SCALAR || this == 
VECT_MAX_SCALAR
                                || this == VECT_EQUAL_SCALAR || this == 
VECT_NOTEQUAL_SCALAR
                                || this == VECT_LESS_SCALAR || this == 
VECT_LESSEQUAL_SCALAR
-                               || this == VECT_GREATER_SCALAR || this == 
VECT_GREATEREQUAL_SCALAR;
+                               || this == VECT_GREATER_SCALAR || this == 
VECT_GREATEREQUAL_SCALAR
+                               || this == VECT_CBIND;
                }
                public boolean isVectorVectorPrimitive() {
                        return this == VECT_DIV || this == VECT_MULT 
@@ -262,10 +272,11 @@ public class CNodeBinary extends CNode
                boolean lsparse = sparse && (_inputs.get(0) instanceof 
CNodeData 
                        && !_inputs.get(0).getVarname().startsWith("b")
                        && !_inputs.get(0).isLiteral());
+               boolean scalarInput = _inputs.get(0).getDataType().isScalar();
                boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
                        && _inputs.get(1).getDataType().isMatrix());
                String var = createVarname();
-               String tmp = _type.getTemplate(lsparse, scalarVector);
+               String tmp = _type.getTemplate(lsparse, scalarVector, 
scalarInput);
                tmp = tmp.replace("%TMP%", var);
                
                //replace input references and start indexes
@@ -354,6 +365,7 @@ public class CNodeBinary extends CNode
                        case VECT_LESSEQUAL:           return "b(v2lte)";
                        case VECT_GREATEREQUAL:        return "b(v2gte)";
                        case VECT_GREATER:             return "b(v2gt)";
+                       case VECT_CBIND:               return "b(cbind)";
                        case MULT:                     return "b(*)";
                        case DIV:                      return "b(/)";
                        case PLUS:                     return "b(+)";
@@ -399,6 +411,12 @@ public class CNodeBinary extends CNode
                                _dataType = DataType.MATRIX;
                                break;
                        
+                       case VECT_CBIND:
+                               _rows = _inputs.get(0)._rows;
+                               _cols = _inputs.get(0)._cols+1;
+                               _dataType = DataType.MATRIX;
+                               break;
+                       
                        case VECT_OUTERMULT_ADD:
                                _rows = _inputs.get(0)._cols;
                                _cols = _inputs.get(1)._cols;
@@ -465,9 +483,9 @@ public class CNodeBinary extends CNode
                        case MIN: 
                        case MAX: 
                        case AND: 
-                       case OR:                        
+                       case OR: 
                        case LOG: 
-                       case LOG_NZ:    
+                       case LOG_NZ:
                        case POW: 
                                _rows = 0;
                                _cols = 0;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index b74b79d..07822d9 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -23,7 +23,6 @@ import java.util.ArrayList;
 
 import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
-import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.sysml.runtime.util.UtilFunctions;
@@ -40,7 +39,7 @@ public class CNodeRow extends CNodeTpl
                        + "\n"
                        + "public final class %TMP% extends SpoofRowwise { \n"
                        + "  public %TMP%() {\n"
-                       + "    super(RowType.%TYPE%, %CBIND0%, %TB1%, 
%VECT_MEM%);\n"
+                       + "    super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, 
%VECT_MEM%);\n"
                        + "  }\n"
                        + "  protected void genexec(double[] a, int ai, 
SideInput[] b, double[] scalars, double[] c, int len, int rowIndex) { \n"
                        + "%BODY_dense%"
@@ -59,6 +58,7 @@ public class CNodeRow extends CNodeTpl
        }
        
        private RowType _type = null; //access pattern 
+       private long _constDim2 = -1; //constant number of output columns
        private int _numVectors = -1; //number of intermediate vectors
        
        public void setRowType(RowType type) {
@@ -79,6 +79,14 @@ public class CNodeRow extends CNodeTpl
                return _numVectors;
        }
        
+       public void setConstDim2(long dim2) { //constant number of output columns for NO_AGG_CONST row templates
+               _constDim2 = dim2;
+       }
+       
+       public long getConstDim2() { //value substituted for %CONST_DIM2% in the generated class; -1 if never set
+               return _constDim2;
+       }
+       
        @Override
        public void renameInputs() {
                rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix
@@ -109,8 +117,7 @@ public class CNodeRow extends CNodeTpl
                
                //replace colvector information and number of vector 
intermediates
                tmp = tmp.replace("%TYPE%", _type.name());
-               tmp = tmp.replace("%CBIND0%", String.valueOf(
-                       TemplateUtils.isUnary(_output, UnaryType.CBIND0)));
+               tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2));
                tmp = tmp.replace("%TB1%", String.valueOf(
                        TemplateUtils.containsBinary(_output, 
BinType.VECT_MATRIXMULT)));
                tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors));
@@ -122,6 +129,7 @@ public class CNodeRow extends CNodeTpl
                switch( _type ) {
                        case NO_AGG:
                        case NO_AGG_B1:
+                       case NO_AGG_CONST:
                                return TEMPLATE_NOAGG_OUT.replace("%IN%", 
varName)
                                        .replace("%LEN%", 
_output.getVarname()+".length");
                        case FULL_AGG:
@@ -142,13 +150,13 @@ public class CNodeRow extends CNodeTpl
        @Override
        public SpoofOutputDimsType getOutputDimType() {
                switch( _type ) {
-                       case NO_AGG:    return SpoofOutputDimsType.INPUT_DIMS;
-                       case NO_AGG_B1: return 
SpoofOutputDimsType.ROW_RANK_DIMS;
-                       case FULL_AGG:  return SpoofOutputDimsType.SCALAR;
-                       case ROW_AGG:   return TemplateUtils.isUnary(_output, 
UnaryType.CBIND0) ?
-                                               SpoofOutputDimsType.ROW_DIMS2 : 
SpoofOutputDimsType.ROW_DIMS;
-                       case COL_AGG:   return 
SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
-                       case COL_AGG_T: return 
SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
+                       case NO_AGG:       return 
SpoofOutputDimsType.INPUT_DIMS;
+                       case NO_AGG_B1:    return 
SpoofOutputDimsType.ROW_RANK_DIMS;
+                       case NO_AGG_CONST: return 
SpoofOutputDimsType.INPUT_DIMS_CONST2; 
+                       case FULL_AGG:     return SpoofOutputDimsType.SCALAR;
+                       case ROW_AGG:      return SpoofOutputDimsType.ROW_DIMS;
+                       case COL_AGG:      return 
SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
+                       case COL_AGG_T:    return 
SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
                        case COL_AGG_B1:   return 
SpoofOutputDimsType.COLUMN_RANK_DIMS; 
                        case COL_AGG_B1_T: return 
SpoofOutputDimsType.COLUMN_RANK_DIMS_T; 
                        default:

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 4bfb74b..343efb5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -28,7 +28,7 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 public class CNodeUnary extends CNode
 {
        public enum UnaryType {
-               LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, CBIND0, //codegen 
specific
+               LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
                ROW_SUMS, ROW_MINS, ROW_MAXS, //codegen specific
                VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
                VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN, 
@@ -94,8 +94,6 @@ public class CNodeUnary extends CNode
                                return "    double %TMP% = getValue(%IN1%, n, 
rowIndex, colIndex);\n";  
                                case LOOKUP0:
                                        return "    double %TMP% = %IN1%[0];\n" 
;
-                               case CBIND0:
-                                       return "    double %TMP% = %IN1%; 
rowIndex *= 2;\n" ;
                                case POW2:
                                        return "    double %TMP% = %IN1% * 
%IN1%;\n" ;
                                case MULT2:
@@ -266,7 +264,6 @@ public class CNodeUnary extends CNode
                        case LOOKUP_C:  return "u(ixc)";
                        case LOOKUP_RC: return "u(ixrc)";
                        case LOOKUP0:   return "u(ix0)";
-                       case CBIND0:    return "u(cbind0)";
                        case POW2:      return "^2";
                        default:        return 
"u("+_type.name().toLowerCase()+")";
                }
@@ -310,7 +307,6 @@ public class CNodeUnary extends CNode
                        case LOOKUP_C:
                        case LOOKUP_RC:
                        case LOOKUP0:
-                       case CBIND0:
                        case POW2:
                        case MULT2:
                        case ABS:  

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index de94969..d9209be 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -50,6 +50,7 @@ import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
 import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
 import org.apache.sysml.runtime.matrix.data.Pair;
 
@@ -76,6 +77,8 @@ public class TemplateRow extends TemplateBase
        public boolean open(Hop hop) {
                return (hop instanceof BinaryOp && hop.dimsKnown() && 
isValidBinaryOperation(hop)
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
+                       || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().get(0).isMatrix()
+                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
hop.getDim2()==1 //MV
                                && hop.getInput().get(0).getDim1()>1 && 
hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggBinaryOp && hop.dimsKnown() && 
LibMatrixMult.isSkinnyRightHandSide(
@@ -98,8 +101,7 @@ public class TemplateRow extends TemplateBase
                return !isClosed() && 
                        (  (hop instanceof BinaryOp && 
isValidBinaryOperation(hop) ) 
                        || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && 
hop.getInput().indexOf(input)==0
-                               && input.getDim2()==1 && 
hop.getInput().get(1).getDim2()==1
-                               && 
HopRewriteUtils.isEmpty(hop.getInput().get(1)))
+                               && 
HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
                        || ((hop instanceof UnaryOp || hop instanceof 
ParameterizedBuiltinOp) 
                                        && TemplateCell.isValidOperation(hop))
                        || (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.RowCol
@@ -130,8 +132,7 @@ public class TemplateRow extends TemplateBase
        public CloseType close(Hop hop) {
                //close on column or full aggregate (e.g., colSums, t(X)%*%y)
                if(    (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getDirection()!=Direction.Row)
-                       || (hop instanceof AggBinaryOp && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
-                       || HopRewriteUtils.isBinary(hop, OpOp2.CBIND) )
+                       || (hop instanceof AggBinaryOp && 
HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))))
                        return CloseType.CLOSED_VALID;
                else
                        return CloseType.OPEN;
@@ -192,6 +193,8 @@ public class TemplateRow extends TemplateBase
                CNodeRow tpl = new CNodeRow(inputs, output);
                tpl.setRowType(TemplateUtils.getRowType(hop, 
                        inHops2.get("X"), inHops2.get("B1")));
+               if( tpl.getRowType()==RowType.NO_AGG_CONST )
+                       tpl.setConstDim2(hop.getDim2());
                tpl.setNumVectorIntermediates(TemplateUtils
                        .determineMinVectorIntermediates(output));
                tpl.getOutput().resetVisitStatus();
@@ -323,7 +326,9 @@ public class TemplateRow extends TemplateBase
                {
                        //special case for cbind with zeros
                        CNode cdata1 = 
tmp.get(hop.getInput().get(0).getHopID());
-                       out = new CNodeUnary(cdata1, UnaryType.CBIND0);
+                       CNode cdata2 = TemplateUtils.createCNodeData(
+                               
HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
+                       out = new CNodeBinary(cdata1, cdata2, 
BinType.VECT_CBIND);
                        inHops.remove(hop.getInput().get(1)); //rm 0-matrix
                }
                else if(hop instanceof BinaryOp)

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 1924914..95383e6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -32,7 +32,6 @@ import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -190,8 +189,7 @@ public class TemplateUtils
                        || (output instanceof IndexingOp && 
HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output)))
                        && !(output instanceof AggBinaryOp && 
HopRewriteUtils.isTransposeOfItself(output.getInput().get(0),X)) )
                        return RowType.NO_AGG_B1;
-               else if( output.getDim1()==X.getDim1() && (output.getDim2()==1 
-                               || HopRewriteUtils.isBinary(output, 
OpOp2.CBIND))
+               else if( output.getDim1()==X.getDim1() && (output.getDim2()==1)
                        && !(output instanceof AggBinaryOp && HopRewriteUtils
                                
.isTransposeOfItself(output.getInput().get(0),X)))
                        return RowType.ROW_AGG;
@@ -206,6 +204,8 @@ public class TemplateUtils
                        return RowType.COL_AGG_B1_T;
                else if( B1 != null && output.getDim1()==B1.getDim2() && 
output.getDim2()==X.getDim2())
                        return RowType.COL_AGG_B1;
+               else if( X.getDim1() == output.getDim1() && X.getDim2() != 
output.getDim2() )
+                       return RowType.NO_AGG_CONST;
                else
                        throw new RuntimeException("Unknown row type for hop 
"+output.getHopID()+".");
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 7093b0e..af9d593 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -474,12 +474,22 @@ public class HopRewriteUtils
                        && ArrayUtils.contains(ops, ((DataGenOp)hop).getOp()));
        }
        
+       public static boolean isDataGenOpWithConstantValue(Hop hop) { //value-agnostic overload: true for a RAND DataGenOp whose hasConstantValue() holds
+               return hop instanceof DataGenOp
+                       && ((DataGenOp)hop).getOp()==DataGenMethod.RAND
+                       && ((DataGenOp)hop).hasConstantValue();
+       }
+       
        public static boolean isDataGenOpWithConstantValue(Hop hop, double 
value) {
                return hop instanceof DataGenOp
                        && ((DataGenOp)hop).getOp()==DataGenMethod.RAND
                        && ((DataGenOp)hop).hasConstantValue(value);
        }
        
+       public static Hop getDataGenOpConstantValue(Hop hop) { //unchecked cast to DataGenOp: guard with isDataGenOpWithConstantValue(hop) before calling
+               return ((DataGenOp) hop).getConstantValue();
+       } 
+       
        public static ReorgOp createTranspose(Hop input) {
                return createReorg(input, ReOrgOp.TRANSPOSE);
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index 1a17793..6b4aad7 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -193,6 +193,30 @@ public class LibSpoofPrimitives
                for( int i=0; i<aix.length; i++ )
                        c[aix[i]] = a[i];
        }
+       
+       // cbind handling: copy a scalar or row a and append scalar b as one extra last column
+       
+       public static double[] vectCBindWrite(double a, double b) { //scalar a: returns [a, b]
+               double[] c = allocVector(2, false);
+               c[0] = a;
+               c[1] = b;
+               return c;
+       }
+       
+       public static double[] vectCBindWrite(double[] a, double b, int aix, int len) { //dense row a[aix, aix+len), then b
+               double[] c = allocVector(len+1, false); //no reset needed: all len+1 cells are overwritten below
+               System.arraycopy(a, aix, c, 0, len);
+               c[len] = b;
+               return c;
+       }
+       
+       public static double[] vectCBindWrite(double[] a, double b, int[] aix, int ai, int alen, int len) { //sparse row (values a, column indexes aix over [ai, ai+alen)), then b
+               double[] c = allocVector(len+1, true); //reset: only the alen non-zeros are scattered, so all other cells must be zero, not stale (NOTE(review): assumes allocVector(..,false) may return a reused buffer)
+               for( int j = ai; j < ai+alen; j++ )
+                       c[aix[j]] = a[j];
+               c[len] = b;
+               return c;
+       }
 
        // custom vector sums, mins, maxs
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java 
b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 659059e..2464b15 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -49,6 +49,7 @@ public abstract class SpoofRowwise extends SpoofOperator
        public enum RowType {
                NO_AGG,    //no aggregation
                NO_AGG_B1, //no aggregation w/ matrix mult B1
+               NO_AGG_CONST, //no aggregation w/ expansion/contraction
                FULL_AGG,  //full row/col aggregation
                ROW_AGG,   //row aggregation (e.g., rowSums() or X %*% v)
                COL_AGG,   //col aggregation (e.g., colSums() or t(y) %*% X)
@@ -69,13 +70,13 @@ public abstract class SpoofRowwise extends SpoofOperator
        }
        
        protected final RowType _type;
-       protected final boolean _cbind0;
+       protected final long _constDim2;
        protected final boolean _tB1;
        protected final int _reqVectMem;
        
-       public SpoofRowwise(RowType type, boolean cbind0, boolean tB1, int 
reqVectMem) {
+       public SpoofRowwise(RowType type, long constDim2, boolean tB1, int 
reqVectMem) {
                _type = type;
-               _cbind0 = cbind0;
+               _constDim2 = constDim2;
                _tB1 = tB1;
                _reqVectMem = reqVectMem;
        }
@@ -84,8 +85,8 @@ public abstract class SpoofRowwise extends SpoofOperator
                return _type;
        }
        
-       public boolean isCBind0() {
-               return _cbind0;
+       public long getConstDim2() {
+               return _constDim2;
        }
        
        public int getNumIntermediates() {
@@ -124,7 +125,8 @@ public abstract class SpoofRowwise extends SpoofOperator
                //result allocation and preparations
                final int m = inputs.get(0).getNumRows();
                final int n = inputs.get(0).getNumColumns();
-               final int n2 = _type.isRowTypeB1() || 
hasMatrixSideInput(inputs) ?
+               final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 
: 
+                       _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
                        getMinColsMatrixSideInputs(inputs) : -1;
                if( !aggIncr || !out.isAllocated() )
                        allocateOutputMatrix(m, n, n2, out);
@@ -179,7 +181,8 @@ public abstract class SpoofRowwise extends SpoofOperator
                //result allocation and preparations
                final int m = inputs.get(0).getNumRows();
                final int n = inputs.get(0).getNumColumns();
-               final int n2 = _type.isRowTypeB1() || 
hasMatrixSideInput(inputs) ?
+               final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 
: 
+                       _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
                        getMinColsMatrixSideInputs(inputs) : -1;
                allocateOutputMatrix(m, n, n2, out);
                final boolean flipOut = _type.isRowTypeB1ColumnAgg()
@@ -258,8 +261,9 @@ public abstract class SpoofRowwise extends SpoofOperator
                switch( _type ) {
                        case NO_AGG:       out.reset(m, n, false); break;
                        case NO_AGG_B1:    out.reset(m, n2, false); break;
+                       case NO_AGG_CONST: out.reset(m, (int)_constDim2, 
false); break;
                        case FULL_AGG:     out.reset(1, 1, false); break;
-                       case ROW_AGG:      out.reset(m, 1+(_cbind0?1:0), 
false); break;
+                       case ROW_AGG:      out.reset(m, 1, false); break;
                        case COL_AGG:      out.reset(1, n, false); break;
                        case COL_AGG_T:    out.reset(n, 1, false); break;
                        case COL_AGG_B1:   out.reset(n2, n, false); break;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index 2d609aa..a628dc0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -354,7 +354,7 @@ public class SpoofSPInstruction extends SPInstruction {
                        if( type == RowType.NO_AGG )
                                mcOut.set(mcIn);
                        else if( type == RowType.ROW_AGG )
-                               mcOut.set(mcIn.getRows(), 
((SpoofRowwise)op).isCBind0()? 2:1, 
+                               mcOut.set(mcIn.getRows(), 1, 
                                        mcIn.getRowsPerBlock(), 
mcIn.getColsPerBlock());
                        else if( type == RowType.COL_AGG )
                                mcOut.set(1, mcIn.getCols(), 
mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
@@ -454,7 +454,8 @@ public class SpoofSPInstruction extends SPInstruction {
                        }
                        
                        //setup local memory for reuse
-                       int clen2 = (int) (_op.getRowType().isRowTypeB1() ? 
_inputs.get(0).getNumCols() : -1);
+                       int clen2 = (int) 
((_op.getRowType()==RowType.NO_AGG_CONST) ? _op.getConstDim2() :
+                               _op.getRowType().isRowTypeB1() ? 
_inputs.get(0).getNumCols() : -1);
                        
LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, 
clen2);
                        
                        ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new 
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 3ecfd6b..d4f87b3 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -67,6 +67,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME28 = TEST_NAME+"28"; //Kmeans, 
final eval
        private static final String TEST_NAME29 = TEST_NAME+"29"; 
//sum(rowMins(X))
        private static final String TEST_NAME30 = TEST_NAME+"30"; //Mlogreg 
inner core, multi-class
+       private static final String TEST_NAME31 = TEST_NAME+"31"; //MLogreg - 
matrix-vector cbind 0s generalized
        
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
@@ -78,7 +79,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=30; i++)
+               for(int i=1; i<=31; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -532,6 +533,21 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME30, false, ExecType.SPARK );
        }
        
+       @Test
+       public void testCodegenRowAggRewrite31CP() {
+               testCodegenIntegration( TEST_NAME31, true, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg31CP() {
+               testCodegenIntegration( TEST_NAME31, false, ExecType.CP );
+       }
+       
+       @Test
+       public void testCodegenRowAgg31SP() {
+               testCodegenIntegration( TEST_NAME31, false, ExecType.SPARK );
+       }
+       
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -581,6 +597,8 @@ public class RowAggTmplTest extends AutomatedTestBase
                        if( testname.equals(TEST_NAME30) )
                                
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)
                                        && 
!heavyHittersContainsSubString(RightIndex.OPCODE));
+                       if( testname.equals(TEST_NAME31) )
+                               
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2));
                }
                finally {
                        rtplatform = platformOld;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.R 
b/src/test/scripts/functions/codegen/rowAggPattern31.R
new file mode 100644
index 0000000..036a3e2
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern31.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (7, nrow(X), 1))
+R = R - rowMaxs(R) %*% matrix(1, 1, ncol(R));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.dml 
b/src/test/scripts/functions/codegen/rowAggPattern31.dml
new file mode 100644
index 0000000..8bdefc4
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern31.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), 150, 10);
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (7, nrow(X), 1))
+R = R - rowMaxs (R)
+
+write(R, $1)

Reply via email to