Repository: systemml Updated Branches: refs/heads/master 4f29b3485 -> 323dd72a8
[SYSTEMML-1903] Fix robustness codegen row ops w/ unknowns This patch fixes special cases of codegen row templates with partial unknowns, which is important for robustness during initial compilation even though the unknowns lead to dynamic recompilation during runtime. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/323dd72a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/323dd72a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/323dd72a Branch: refs/heads/master Commit: 323dd72a8ed18687aa3019387c4ab7b0598bd9d5 Parents: 4f29b34 Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Oct 19 15:07:54 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Oct 19 16:06:14 2017 -0700 ---------------------------------------------------------------------- .../hops/codegen/template/TemplateRow.java | 38 ++++++++++---------- .../hops/codegen/template/TemplateUtils.java | 2 +- 2 files changed, 21 insertions(+), 19 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index 0389983..e664b9f 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -250,7 +250,7 @@ public class TemplateRow extends TemplateBase else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) { //vector add without temporary copy if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() ) - out = new 
CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), + out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1), ((CNodeBinary)cdata1).getType().getVectorAddPrimitive()); else out = cdata1; @@ -269,7 +269,7 @@ public class TemplateRow extends TemplateBase { //correct input under transpose cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals); - inHops.remove(hop.getInput().get(0)); + inHops.remove(hop.getInput().get(0)); inHops.add(hop.getInput().get(0).getInput().get(0)); //note: vectorMultAdd applicable to vector-scalar, and vector-vector @@ -310,7 +310,8 @@ public class TemplateRow extends TemplateBase CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); // if one input is a matrix then we need to do vector by scalar operations - if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) + if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 + || (!hop.dimsKnown() && cdata1.getDataType()==DataType.MATRIX ) ) { if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) { String opname = "VECT_"+((UnaryOp)hop).getOp().name(); @@ -320,12 +321,11 @@ public class TemplateRow extends TemplateBase } else throw new RuntimeException("Unsupported unary matrix " - + "operation: " + ((UnaryOp)hop).getOp().name()); + + "operation: " + ((UnaryOp)hop).getOp().name()); } else //general scalar case { cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); - String primitiveOpName = ((UnaryOp)hop).getOp().toString(); out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName)); } @@ -355,7 +355,9 @@ public class TemplateRow extends TemplateBase // if one input is a matrix then we need to do vector by scalar operations if( (hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1) - || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1)) + || (hop.getInput().get(1).getDim1() > 1 && 
hop.getInput().get(1).getDim2() > 1) + || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()) + && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix()))) { if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) { if( TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2) @@ -371,14 +373,14 @@ public class TemplateRow extends TemplateBase cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); } - if( cdata1 instanceof CNodeData && inHops2.isEmpty() + if( cdata1 instanceof CNodeData && inHops2.isEmpty() && !(cdata1.getDataType()==DataType.SCALAR) ) { inHops2.put("X", hop.getInput().get(0)); } } else throw new RuntimeException("Unsupported binary matrix " - + "operation: " + ((BinaryOp)hop).getOp().name()); + + "operation: " + ((BinaryOp)hop).getOp().name()); } else //one input is a vector/scalar other is a scalar { @@ -389,7 +391,7 @@ public class TemplateRow extends TemplateBase || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData && hop.getInput().get(1).getDataType().isMatrix())) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); - out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); + out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } } else if(hop instanceof TernaryOp) @@ -405,16 +407,16 @@ public class TemplateRow extends TemplateBase //construct ternary cnode, primitive operation derived from OpOp3 out = new CNodeTernary(cdata1, cdata2, cdata3, - TernaryType.valueOf(top.getOp().toString())); + TernaryType.valueOf(top.getOp().toString())); } - else if( hop instanceof ParameterizedBuiltinOp ) + else if( hop instanceof ParameterizedBuiltinOp ) { CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID()); cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0)); CNode cdata2 = 
tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID()); CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID()); - TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? + TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? TernaryType.REPLACE_NAN : TernaryType.REPLACE; out = new CNodeTernary(cdata1, cdata2, cdata3, ttype); } @@ -422,7 +424,7 @@ public class TemplateRow extends TemplateBase { CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID()); out = new CNodeTernary(cdata1, - TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), + TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), (!hop.dimsKnown()||hop.getDim2()>1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1); } @@ -456,13 +458,13 @@ public class TemplateRow extends TemplateBase @Override public int compare(Hop h1, Hop h2) { - long ncells1 = h1.isScalar() ? Long.MIN_VALUE : - (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? Long.MAX_VALUE-1 : + long ncells1 = h1.isScalar() ? Long.MIN_VALUE : + (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? Long.MAX_VALUE-1 : h1.dimsKnown() ? h1.getLength() : Long.MAX_VALUE-2; - long ncells2 = h2.isScalar() ? Long.MIN_VALUE : - (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? Long.MAX_VALUE-1 : + long ncells2 = h2.isScalar() ? Long.MIN_VALUE : + (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? Long.MAX_VALUE-1 : h2.dimsKnown() ? h2.getLength() : Long.MAX_VALUE-2; - return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 1 : 0; + return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 
1 : 0; } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 497dae0..96e15cb 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -184,7 +184,7 @@ public class TemplateUtils public static RowType getRowType(Hop output, Hop... inputs) { Hop X = inputs[0]; Hop B1 = (inputs.length>1) ? inputs[1] : null; - if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || X==null ) + if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || X==null || !X.dimsKnown() ) return RowType.NO_AGG; else if( ((B1!=null && output.getDim1()==X.getDim1() && output.getDim2()==B1.getDim2()) || (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output)))