Repository: systemml Updated Branches: refs/heads/master 47ce14fc6 -> c1db484d6
[SYSTEMML-1933] Generalized codegen cbind handling in row-wise ops This patch generalizes the compilation of cbind operations in codegen row templates. So far, we only supported cbind with a vector of zeros, and cbind closed the row template. We now support cbind operations (with vectors of arbitrary constants) in the middle of row templates, which also allows for multiple cbind operations in a single fused operator. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c1db484d Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c1db484d Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c1db484d Branch: refs/heads/master Commit: c1db484d6119459f7ef6a566ff2663cca286f7ab Parents: 47ce14f Author: Matthias Boehm <[email protected]> Authored: Sun Sep 24 18:29:17 2017 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sun Sep 24 18:29:17 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/DataGenOp.java | 4 +++ .../sysml/hops/codegen/SpoofCompiler.java | 2 ++ .../apache/sysml/hops/codegen/SpoofFusedOp.java | 21 ++++++++----- .../sysml/hops/codegen/cplan/CNodeBinary.java | 28 ++++++++++++++--- .../sysml/hops/codegen/cplan/CNodeRow.java | 30 +++++++++++------- .../sysml/hops/codegen/cplan/CNodeUnary.java | 6 +--- .../hops/codegen/template/TemplateRow.java | 15 ++++++--- .../hops/codegen/template/TemplateUtils.java | 6 ++-- .../sysml/hops/rewrite/HopRewriteUtils.java | 10 ++++++ .../runtime/codegen/LibSpoofPrimitives.java | 24 +++++++++++++++ .../sysml/runtime/codegen/SpoofRowwise.java | 20 +++++++----- .../instructions/spark/SpoofSPInstruction.java | 5 +-- .../functions/codegen/RowAggTmplTest.java | 20 +++++++++++- .../scripts/functions/codegen/rowAggPattern31.R | 32 ++++++++++++++++++++ .../functions/codegen/rowAggPattern31.dml | 27 +++++++++++++++++ 15 files changed, 202 insertions(+), 48 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/DataGenOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java index 89a5814..eb04ed3 100644 --- a/src/main/java/org/apache/sysml/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java @@ -434,6 +434,10 @@ public class DataGenOp extends Hop implements MultiThreadedHop return ret; } + public Hop getConstantValue() { + return getInput().get(_paramIndexMap.get(DataExpression.RAND_MIN)); + } + public void setIncrementValue(double incr) { _incr = incr; http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index b98342c..a374cf1 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -645,6 +645,8 @@ public class SpoofCompiler HopRewriteUtils.setOutputParametersForScalar(hnew); hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX); } + else if( tmpCNode instanceof CNodeRow && ((CNodeRow)tmpCNode).getRowType()==RowType.NO_AGG_CONST ) + ((SpoofFusedOp)hnew).setConstDim2(((CNodeRow)tmpCNode).getConstDim2()); if( !(tmpCNode instanceof CNodeMultiAgg) ) HopRewriteUtils.rewireAllParentChildReferences(hop, hnew); http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java index 247a142..81b226d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java @@ -38,8 +38,8 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop { public enum SpoofOutputDimsType { INPUT_DIMS, + INPUT_DIMS_CONST2, ROW_DIMS, - ROW_DIMS2, COLUMN_DIMS_ROWS, COLUMN_DIMS_COLS, SCALAR, @@ -52,6 +52,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop private Class<?> _class = null; private boolean _distSupported = false; private int _numThreads = -1; + private long _constDim2 = -1; private SpoofOutputDimsType _dimsType; public SpoofFusedOp ( ) { @@ -82,6 +83,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop public boolean allowsAllExecTypes() { return _distSupported; } + + public void setConstDim2(long constDim2) { + _constDim2 = constDim2; + } @Override protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) { @@ -152,9 +157,6 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop case ROW_DIMS: ret = new long[]{mc.getRows(), 1, -1}; break; - case ROW_DIMS2: - ret = new long[]{mc.getRows(), 2, -1}; - break; case COLUMN_DIMS_ROWS: ret = new long[]{mc.getCols(), 1, -1}; break; @@ -164,6 +166,9 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop case INPUT_DIMS: ret = new long[]{mc.getRows(), mc.getCols(), -1}; break; + case INPUT_DIMS_CONST2: + ret = new long[]{mc.getRows(), _constDim2, -1}; + break; case SCALAR: ret = new long[]{0, 0, -1}; break; @@ -206,10 +211,6 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop setDim1(getInput().get(0).getDim1()); setDim2(1); break; - case ROW_DIMS2: - setDim1(getInput().get(0).getDim1()); - setDim2(2); - break; case COLUMN_DIMS_ROWS: setDim1(getInput().get(0).getDim2()); setDim2(1); @@ -222,6 +223,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop setDim1(getInput().get(0).getDim1()); setDim2(getInput().get(0).getDim2()); break; + case INPUT_DIMS_CONST2: + setDim1(getInput().get(0).getDim1()); + setDim2(_constDim2); + break; case SCALAR: setDim1(0); setDim2(0); http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java index 926dd4d..bff044d 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java @@ -40,6 +40,7 @@ public class CNodeBinary extends CNode VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR, VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR, VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR, + VECT_CBIND, //vector-vector operations VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, VECT_EQUAL, VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, VECT_GREATEREQUAL, @@ -67,7 +68,7 @@ public class CNodeBinary extends CNode return ssComm || vsComm || vvComm; } - public String getTemplate(boolean sparse, boolean scalarVector) { + public String getTemplate(boolean sparse, boolean scalarVector, boolean scalarInput) { switch (this) { case DOT_PRODUCT: return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : @@ -125,6 +126,14 @@ public class CNodeBinary extends CNode " double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n"; } + case VECT_CBIND: + if( scalarInput ) + return " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%);\n"; + else + return sparse ? + " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" : + " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n"; + //vector-vector operations case VECT_MULT: case VECT_DIV: @@ -202,7 +211,8 @@ public class CNodeBinary extends CNode || this == VECT_MIN_SCALAR || this == VECT_MAX_SCALAR || this == VECT_EQUAL_SCALAR || this == VECT_NOTEQUAL_SCALAR || this == VECT_LESS_SCALAR || this == VECT_LESSEQUAL_SCALAR - || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR; + || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR + || this == VECT_CBIND; } public boolean isVectorVectorPrimitive() { return this == VECT_DIV || this == VECT_MULT @@ -262,10 +272,11 @@ public class CNodeBinary extends CNode boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData && !_inputs.get(0).getVarname().startsWith("b") && !_inputs.get(0).isLiteral()); + boolean scalarInput = _inputs.get(0).getDataType().isScalar(); boolean scalarVector = (_inputs.get(0).getDataType().isScalar() && _inputs.get(1).getDataType().isMatrix()); String var = createVarname(); - String tmp = _type.getTemplate(lsparse, scalarVector); + String tmp = _type.getTemplate(lsparse, scalarVector, scalarInput); tmp = tmp.replace("%TMP%", var); //replace input references and start indexes @@ -354,6 +365,7 @@ public class CNodeBinary extends CNode case VECT_LESSEQUAL: return "b(v2lte)"; case VECT_GREATEREQUAL: return "b(v2gte)"; case VECT_GREATER: return "b(v2gt)"; + case VECT_CBIND: return "b(cbind)"; case MULT: return "b(*)"; case DIV: return "b(/)"; case PLUS: return "b(+)"; @@ -399,6 +411,12 @@ public class CNodeBinary extends CNode _dataType = DataType.MATRIX; break; + case VECT_CBIND: + _rows = _inputs.get(0)._rows; + _cols = _inputs.get(0)._cols+1; + _dataType = DataType.MATRIX; + break; + case VECT_OUTERMULT_ADD: _rows = _inputs.get(0)._cols; _cols = _inputs.get(1)._cols; @@ -465,9 +483,9 @@ public class CNodeBinary extends CNode case MIN: case MAX: case AND: - case OR: + case OR: case LOG: - case LOG_NZ: + case LOG_NZ: case POW: _rows = 0; _cols = 0; http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java index b74b79d..07822d9 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; -import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysml.runtime.util.UtilFunctions; @@ -40,7 +39,7 @@ public class CNodeRow extends CNodeTpl + "\n" + "public final class %TMP% extends SpoofRowwise { \n" + " public %TMP%() {\n" - + " super(RowType.%TYPE%, %CBIND0%, %TB1%, %VECT_MEM%);\n" + + " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n" + " }\n" + " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int len, int rowIndex) { \n" + "%BODY_dense%" @@ -59,6 +58,7 @@ public class CNodeRow extends CNodeTpl } private RowType _type = null; //access pattern + private long _constDim2 = -1; //constant number of output columns private int _numVectors = -1; //number of intermediate vectors public void setRowType(RowType type) { @@ -79,6 +79,14 @@ public class CNodeRow extends CNodeTpl return _numVectors; } + public void setConstDim2(long dim2) { + _constDim2 = dim2; + } + + public long getConstDim2() { + return _constDim2; + } + @Override public void renameInputs() { rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix @@ -109,8 +117,7 @@ public class CNodeRow extends CNodeTpl //replace colvector information and number of vector intermediates tmp = tmp.replace("%TYPE%", _type.name()); - tmp = tmp.replace("%CBIND0%", String.valueOf( - TemplateUtils.isUnary(_output, UnaryType.CBIND0))); + tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2)); tmp = tmp.replace("%TB1%", String.valueOf( TemplateUtils.containsBinary(_output, BinType.VECT_MATRIXMULT))); tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors)); @@ -122,6 +129,7 @@ public class CNodeRow extends CNodeTpl switch( _type ) { case NO_AGG: case NO_AGG_B1: + case NO_AGG_CONST: return TEMPLATE_NOAGG_OUT.replace("%IN%", varName) .replace("%LEN%", _output.getVarname()+".length"); case FULL_AGG: @@ -142,13 +150,13 @@ public class CNodeRow extends CNodeTpl @Override public SpoofOutputDimsType getOutputDimType() { switch( _type ) { - case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS; - case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS; - case FULL_AGG: return SpoofOutputDimsType.SCALAR; - case ROW_AGG: return TemplateUtils.isUnary(_output, UnaryType.CBIND0) ? - SpoofOutputDimsType.ROW_DIMS2 : SpoofOutputDimsType.ROW_DIMS; - case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector - case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector + case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS; + case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS; + case NO_AGG_CONST: return SpoofOutputDimsType.INPUT_DIMS_CONST2; + case FULL_AGG: return SpoofOutputDimsType.SCALAR; + case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS; + case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector + case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS; case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T; default: http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java index 4bfb74b..343efb5 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java @@ -28,7 +28,7 @@ import org.apache.sysml.runtime.util.UtilFunctions; public class CNodeUnary extends CNode { public enum UnaryType { - LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, CBIND0, //codegen specific + LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific ROW_SUMS, ROW_MINS, ROW_MAXS, //codegen specific VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG, VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN, @@ -94,8 +94,6 @@ public class CNodeUnary extends CNode return " double %TMP% = getValue(%IN1%, n, rowIndex, colIndex);\n"; case LOOKUP0: return " double %TMP% = %IN1%[0];\n" ; - case CBIND0: - return " double %TMP% = %IN1%; rowIndex *= 2;\n" ; case POW2: return " double %TMP% = %IN1% * %IN1%;\n" ; case MULT2: @@ -266,7 +264,6 @@ public class CNodeUnary extends CNode case LOOKUP_C: return "u(ixc)"; case LOOKUP_RC: return "u(ixrc)"; case LOOKUP0: return "u(ix0)"; - case CBIND0: return "u(cbind0)"; case POW2: return "^2"; default: return "u("+_type.name().toLowerCase()+")"; } @@ -310,7 +307,6 @@ public class CNodeUnary extends CNode case LOOKUP_C: case LOOKUP_RC: case LOOKUP0: - case CBIND0: case POW2: case MULT2: case ABS: http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index de94969..d9209be 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -50,6 +50,7 @@ import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType; import org.apache.sysml.runtime.matrix.data.LibMatrixMult; import org.apache.sysml.runtime.matrix.data.Pair; @@ -76,6 +77,8 @@ public class TemplateRow extends TemplateBase public boolean open(Hop hop) { return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop) && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) + || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() + && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide( @@ -98,8 +101,7 @@ public class TemplateRow extends TemplateBase return !isClosed() && ( (hop instanceof BinaryOp && isValidBinaryOperation(hop) ) || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0 - && input.getDim2()==1 && hop.getInput().get(1).getDim2()==1 - && HopRewriteUtils.isEmpty(hop.getInput().get(1))) + && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1))) || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) && TemplateCell.isValidOperation(hop)) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol @@ -130,8 +132,7 @@ public class TemplateRow extends TemplateBase public CloseType close(Hop hop) { //close on column or full aggregate (e.g., colSums, t(X)%*%y) if( (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.Row) - || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) - || HopRewriteUtils.isBinary(hop, OpOp2.CBIND) ) + || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))) return CloseType.CLOSED_VALID; else return CloseType.OPEN; @@ -192,6 +193,8 @@ public class TemplateRow extends TemplateBase CNodeRow tpl = new CNodeRow(inputs, output); tpl.setRowType(TemplateUtils.getRowType(hop, inHops2.get("X"), inHops2.get("B1"))); + if( tpl.getRowType()==RowType.NO_AGG_CONST ) + tpl.setConstDim2(hop.getDim2()); tpl.setNumVectorIntermediates(TemplateUtils .determineMinVectorIntermediates(output)); tpl.getOutput().resetVisitStatus(); @@ -323,7 +326,9 @@ public class TemplateRow extends TemplateBase { //special case for cbind with zeros CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); - out = new CNodeUnary(cdata1, UnaryType.CBIND0); + CNode cdata2 = TemplateUtils.createCNodeData( + HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true); + out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND); inHops.remove(hop.getInput().get(1)); //rm 0-matrix } else if(hop instanceof BinaryOp) http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 1924914..95383e6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -32,7 +32,6 @@ import org.apache.sysml.hops.ParameterizedBuiltinOp; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; -import org.apache.sysml.hops.Hop.OpOp2; import org.apache.sysml.hops.IndexingOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; @@ -190,8 +189,7 @@ public class TemplateUtils || (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output))) && !(output instanceof AggBinaryOp && HopRewriteUtils.isTransposeOfItself(output.getInput().get(0),X)) ) return RowType.NO_AGG_B1; - else if( output.getDim1()==X.getDim1() && (output.getDim2()==1 - || HopRewriteUtils.isBinary(output, OpOp2.CBIND)) + else if( output.getDim1()==X.getDim1() && (output.getDim2()==1) && !(output instanceof AggBinaryOp && HopRewriteUtils .isTransposeOfItself(output.getInput().get(0),X))) return RowType.ROW_AGG; @@ -206,6 +204,8 @@ public class TemplateUtils return RowType.COL_AGG_B1_T; else if( B1 != null && output.getDim1()==B1.getDim2() && output.getDim2()==X.getDim2()) return RowType.COL_AGG_B1; + else if( X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2() ) + return RowType.NO_AGG_CONST; else throw new RuntimeException("Unknown row type for hop "+output.getHopID()+"."); } http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 7093b0e..af9d593 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -474,12 +474,22 @@ public class HopRewriteUtils && ArrayUtils.contains(ops, ((DataGenOp)hop).getOp())); } + public static boolean isDataGenOpWithConstantValue(Hop hop) { + return hop instanceof DataGenOp + && ((DataGenOp)hop).getOp()==DataGenMethod.RAND + && ((DataGenOp)hop).hasConstantValue(); + } + public static boolean isDataGenOpWithConstantValue(Hop hop, double value) { return hop instanceof DataGenOp && ((DataGenOp)hop).getOp()==DataGenMethod.RAND && ((DataGenOp)hop).hasConstantValue(value); } + public static Hop getDataGenOpConstantValue(Hop hop) { + return ((DataGenOp) hop).getConstantValue(); + } + public static ReorgOp createTranspose(Hop input) { return createReorg(input, ReOrgOp.TRANSPOSE); } http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java index 1a17793..6b4aad7 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java @@ -193,6 +193,30 @@ public class LibSpoofPrimitives for( int i=0; i<aix.length; i++ ) c[aix[i]] = a[i]; } + + // cbind handling + + public static double[] vectCBindWrite(double a, double b) { + double[] c = allocVector(2, false); + c[0] = a; + c[1] = b; + return c; + } + + public static double[] vectCBindWrite(double[] a, double b, int aix, int len) { + double[] c = allocVector(len+1, false); + System.arraycopy(a, aix, c, 0, len); + c[len] = b; + return c; + } + + public static double[] vectCBindWrite(double[] a, double b, int[] aix, int ai, int alen, int len) { + double[] c = allocVector(len+1, false); + for( int j = ai; j < ai+alen; j++ ) + c[aix[j]] = a[j]; + c[len] = b; + return c; + } // custom vector sums, mins, maxs http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java index 659059e..2464b15 100644 --- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java +++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java @@ -49,6 +49,7 @@ public abstract class SpoofRowwise extends SpoofOperator public enum RowType { NO_AGG, //no aggregation NO_AGG_B1, //no aggregation w/ matrix mult B1 + NO_AGG_CONST, //no aggregation w/ expansion/contraction FULL_AGG, //full row/col aggregation ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v) COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X) @@ -69,13 +70,13 @@ public abstract class SpoofRowwise extends SpoofOperator } protected final RowType _type; - protected final boolean _cbind0; + protected final long _constDim2; protected final boolean _tB1; protected final int _reqVectMem; - public SpoofRowwise(RowType type, boolean cbind0, boolean tB1, int reqVectMem) { + public SpoofRowwise(RowType type, long constDim2, boolean tB1, int reqVectMem) { _type = type; - _cbind0 = cbind0; + _constDim2 = constDim2; _tB1 = tB1; _reqVectMem = reqVectMem; } @@ -84,8 +85,8 @@ public abstract class SpoofRowwise extends SpoofOperator return _type; } - public boolean isCBind0() { - return _cbind0; + public long getConstDim2() { + return _constDim2; } public int getNumIntermediates() { @@ -124,7 +125,8 @@ public abstract class SpoofRowwise extends SpoofOperator //result allocation and preparations final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); - final int n2 = _type.isRowTypeB1() || hasMatrixSideInput(inputs) ? + final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 : + _type.isRowTypeB1() || hasMatrixSideInput(inputs) ? getMinColsMatrixSideInputs(inputs) : -1; if( !aggIncr || !out.isAllocated() ) allocateOutputMatrix(m, n, n2, out); @@ -179,7 +181,8 @@ public abstract class SpoofRowwise extends SpoofOperator //result allocation and preparations final int m = inputs.get(0).getNumRows(); final int n = inputs.get(0).getNumColumns(); - final int n2 = _type.isRowTypeB1() || hasMatrixSideInput(inputs) ? + final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 : + _type.isRowTypeB1() || hasMatrixSideInput(inputs) ? getMinColsMatrixSideInputs(inputs) : -1; allocateOutputMatrix(m, n, n2, out); final boolean flipOut = _type.isRowTypeB1ColumnAgg() @@ -258,8 +261,9 @@ public abstract class SpoofRowwise extends SpoofOperator switch( _type ) { case NO_AGG: out.reset(m, n, false); break; case NO_AGG_B1: out.reset(m, n2, false); break; + case NO_AGG_CONST: out.reset(m, (int)_constDim2, false); break; case FULL_AGG: out.reset(1, 1, false); break; - case ROW_AGG: out.reset(m, 1+(_cbind0?1:0), false); break; + case ROW_AGG: out.reset(m, 1, false); break; case COL_AGG: out.reset(1, n, false); break; case COL_AGG_T: out.reset(n, 1, false); break; case COL_AGG_B1: out.reset(n2, n, false); break; http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index 2d609aa..a628dc0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -354,7 +354,7 @@ public class SpoofSPInstruction extends SPInstruction { if( type == RowType.NO_AGG ) mcOut.set(mcIn); else if( type == RowType.ROW_AGG ) - mcOut.set(mcIn.getRows(), ((SpoofRowwise)op).isCBind0()? 2:1, + mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); else if( type == RowType.COL_AGG ) mcOut.set(1, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock()); @@ -454,7 +454,8 @@ public class SpoofSPInstruction extends SPInstruction { } //setup local memory for reuse - int clen2 = (int) (_op.getRowType().isRowTypeB1() ? _inputs.get(0).getNumCols() : -1); + int clen2 = (int) ((_op.getRowType()==RowType.NO_AGG_CONST) ? _op.getConstDim2() : + _op.getRowType().isRowTypeB1() ? _inputs.get(0).getNumCols() : -1); LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, clen2); ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>(); http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java index 3ecfd6b..d4f87b3 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java @@ -67,6 +67,7 @@ public class RowAggTmplTest extends AutomatedTestBase private static final String TEST_NAME28 = TEST_NAME+"28"; //Kmeans, final eval private static final String TEST_NAME29 = TEST_NAME+"29"; //sum(rowMins(X)) private static final String TEST_NAME30 = TEST_NAME+"30"; //Mlogreg inner core, multi-class + private static final String TEST_NAME31 = TEST_NAME+"31"; //MLogreg - matrix-vector cbind 0s generalized private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/"; @@ -78,7 +79,7 @@ public class RowAggTmplTest extends AutomatedTestBase @Override public void setUp() { TestUtils.clearAssertionInformation(); - for(int i=1; i<=30; i++) + for(int i=1; i<=31; i++) addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) ); } @@ -532,6 +533,21 @@ public class RowAggTmplTest extends AutomatedTestBase testCodegenIntegration( TEST_NAME30, false, ExecType.SPARK ); } + @Test + public void testCodegenRowAggRewrite31CP() { + testCodegenIntegration( TEST_NAME31, true, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg31CP() { + testCodegenIntegration( TEST_NAME31, false, ExecType.CP ); + } + + @Test + public void testCodegenRowAgg31SP() { + testCodegenIntegration( TEST_NAME31, false, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; @@ -581,6 +597,8 @@ public class RowAggTmplTest extends AutomatedTestBase if( testname.equals(TEST_NAME30) ) Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2) && !heavyHittersContainsSubString(RightIndex.OPCODE)); + if( testname.equals(TEST_NAME31) ) + Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)); } finally { rtplatform = platformOld; http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.R b/src/test/scripts/functions/codegen/rowAggPattern31.R new file mode 100644 index 0000000..036a3e2 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern31.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = matrix(seq(1,1500), 150, 10, byrow=TRUE); +v = seq(1, ncol(X)); +R = cbind((X %*% v), matrix (7, nrow(X), 1)) +R = R - rowMaxs(R) %*% matrix(1, 1, ncol(R)); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.dml b/src/test/scripts/functions/codegen/rowAggPattern31.dml new file mode 100644 index 0000000..8bdefc4 --- /dev/null +++ b/src/test/scripts/functions/codegen/rowAggPattern31.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(seq(1,1500), 150, 10); +v = seq(1, ncol(X)); +R = cbind((X %*% v), matrix (7, nrow(X), 1)) +R = R - rowMaxs (R) + +write(R, $1)
