[SYSTEMML-1372] Extended code generator (celltmpl w/ matmult as root)

In order to handle the case where previous rewrites rewrote sum(X^2) - which is fusable with any cellwise operations that created X - into t(X)%*%X, we now also support dot products as roots of cellwise templates (compiled as a cellwise multiply followed by a final full aggregation). Furthermore, this patch adds tests for this template and fixes various issues where we unnecessarily compiled fused operators for single binary and ternary operations.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/31cb2531 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/31cb2531 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/31cb2531 Branch: refs/heads/master Commit: 31cb253135ce550ab1d48aee464c94aa3ec41eef Parents: 96505b1 Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Mar 16 21:24:50 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Mar 17 12:52:39 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/UnaryOp.java | 5 +- .../sysml/hops/codegen/SpoofCompiler.java | 26 ++++++--- .../sysml/hops/codegen/cplan/CNodeCell.java | 15 ++++- .../sysml/hops/codegen/template/CellTpl.java | 36 ++++++++++-- .../hops/codegen/template/TemplateUtils.java | 21 ++++++- .../sysml/hops/rewrite/HopRewriteUtils.java | 6 +- .../sysml/hops/rewrite/ProgramRewriter.java | 14 +++-- .../org/apache/sysml/parser/DMLTranslator.java | 60 ++++---------------- .../functions/codegen/AlgorithmLinregCG.java | 4 +- .../functions/codegen/CellwiseTmplTest.java | 34 ++++++++--- .../org/apache/sysml/test/utils/TestUtils.java | 2 +- .../scripts/functions/codegen/cellwisetmpl6.R | 11 ++-- .../scripts/functions/codegen/cellwisetmpl6.dml | 32 +---------- .../scripts/functions/codegen/cellwisetmpl7.R | 33 +++++++++++ .../scripts/functions/codegen/cellwisetmpl7.dml | 30 ++++++++++ 15 files changed, 206 insertions(+), 123 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/UnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java index ac306b3..c75d0e0 100644 --- 
a/src/main/java/org/apache/sysml/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java @@ -60,14 +60,11 @@ public class UnaryOp extends Hop implements MultiThreadedHop //default constructor for clone } - public UnaryOp(String l, DataType dt, ValueType vt, OpOp1 o, Hop inp) - throws HopsException - { + public UnaryOp(String l, DataType dt, ValueType vt, OpOp1 o, Hop inp) { super(l, dt, vt); getInput().add(0, inp); inp.getParent().add(this); - _op = o; //compute unknown dims and nnz http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index ec24d41..54e67b6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -46,11 +46,13 @@ import org.apache.sysml.hops.codegen.template.CPlanMemoTable; import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry; import org.apache.sysml.hops.codegen.template.TemplateUtils; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.rewrite.ProgramRewriteStatus; import org.apache.sysml.hops.rewrite.ProgramRewriter; import org.apache.sysml.hops.rewrite.RewriteCommonSubexpressionElimination; +import org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryCasts; import org.apache.sysml.parser.DMLProgram; import org.apache.sysml.parser.ForStatement; import org.apache.sysml.parser.ForStatementBlock; @@ -87,7 +89,9 @@ public class SpoofCompiler //for equal operators from (1) different hop dags and (2) repeated recompilation private static ConcurrentHashMap<CNode, Class<?>> planCache = 
new ConcurrentHashMap<CNode, Class<?>>(); - private static ProgramRewriter rewriteCSE = new ProgramRewriter(new RewriteCommonSubexpressionElimination(true)); + private static ProgramRewriter rewriteCSE = new ProgramRewriter( + new RewriteCommonSubexpressionElimination(true), + new RewriteRemoveUnnecessaryCasts()); public static void generateCode(DMLProgram dmlp) throws LanguageException, HopsException, DMLRuntimeException @@ -217,7 +221,7 @@ public class SpoofCompiler //cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping, //remove empty templates with single cnodedata input, remove spurious lookups) cplans = cleanupCPlans(cplans); - + //explain before modification if( LDEBUG && !cplans.isEmpty() ) { //existing cplans LOG.info("Codegen EXPLAIN (before optimize): \n"+Explain.explainHops(roots)); @@ -265,7 +269,7 @@ public class SpoofCompiler //generate final hop dag ret = constructModifiedHopDag(roots, cplans, clas); - //run common subexpression elimination + //run common subexpression elimination and other rewrites ret = rewriteCSE.rewriteHopDAGs(ret, new ProgramRewriteStatus()); //explain after modification @@ -438,11 +442,15 @@ public class SpoofCompiler else hnew.addInput(inHops[i]); //add inputs } - hnew.setOutputBlocksizes(hop.getRowsInBlock() , hop.getColsInBlock()); - hnew.setDim1(hop.getDim1()); - hnew.setDim2(hop.getDim2()); - if(tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct)tmpCNode).isTransposeOutput() ) { + + //modify output parameters + HopRewriteUtils.setOutputParameters(hnew, hop.getDim1(), hop.getDim2(), + hop.getRowsInBlock(), hop.getColsInBlock(), hop.getNnz()); + if(tmpCNode instanceof CNodeOuterProduct && ((CNodeOuterProduct)tmpCNode).isTransposeOutput() ) hnew = HopRewriteUtils.createTranspose(hnew); + else if( tmpCNode instanceof CNodeCell && ((CNodeCell)tmpCNode).requiredCastDtm() ) { + HopRewriteUtils.setOutputParametersForScalar(hnew); + hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX); } 
HopRewriteUtils.rewireAllParentChildReferences(hop, hnew); @@ -496,9 +504,9 @@ public class SpoofCompiler //remove cplan w/ single op and w/o agg if( tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG - && tpl.getOutput() instanceof CNodeUnary && tpl.getOutput().getInput().get(0) instanceof CNodeData) + && TemplateUtils.hasSingleOperation(tpl) ) cplans2.remove(e.getKey()); - + //remove cplan if empty if( tpl.getOutput() instanceof CNodeData ) cplans2.remove(e.getKey()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java index 493c2dd..8122078 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeCell.java @@ -48,6 +48,7 @@ public class CNodeCell extends CNodeTpl + "}"; private CellType _type = null; + private boolean _requiresCastdtm = false; private boolean _multipleConsumers = false; public CNodeCell(ArrayList<CNode> inputs, CNode output ) { @@ -71,6 +72,15 @@ public class CNodeCell extends CNodeTpl return _type; } + public void setRequiresCastDtm(boolean flag) { + _requiresCastdtm = flag; + _hash = 0; + } + + public boolean requiredCastDtm() { + return _requiresCastdtm; + } + @Override public String codegen(boolean sparse) { String tmp = TEMPLATE; @@ -126,8 +136,9 @@ public class CNodeCell extends CNodeTpl if( _hash == 0 ) { int h1 = super.hashCode(); int h2 = _type.hashCode(); + int h3 = Boolean.valueOf(_requiresCastdtm).hashCode(); //note: _multipleConsumers irrelevant for plan comparison - _hash = Arrays.hashCode(new int[]{h1,h2}); + _hash = Arrays.hashCode(new int[]{h1,h2,h3}); } return _hash; } @@ -140,6 +151,7 @@ public class CNodeCell extends CNodeTpl 
CNodeCell that = (CNodeCell)o; return super.equals(that) && _type == that._type + && _requiresCastdtm == that._requiresCastdtm && equalInputReferences( _output, that._output, _inputs, that._inputs); } @@ -149,6 +161,7 @@ public class CNodeCell extends CNodeTpl StringBuilder sb = new StringBuilder(); sb.append("SPOOF CELLWISE [type="); sb.append(_type.name()); + sb.append(", castdtm="+_requiresCastdtm); sb.append(", mc="+_multipleConsumers); sb.append("]"); return sb.toString(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java index 6a3cd83..1aa1909 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CellTpl.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; +import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; @@ -61,8 +62,9 @@ public class CellTpl extends BaseTpl @Override public boolean fuse(Hop hop, Hop input) { return !isClosed() && (isValidOperation(hop) - || ( hop instanceof AggUnaryOp && ((AggUnaryOp) hop).getOp() == AggOp.SUM - && ((AggUnaryOp) hop).getDirection()!= Direction.Col)); + || (HopRewriteUtils.isSum(hop) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) + || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) + && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))); } @Override @@ -74,9 +76,10 @@ public class CellTpl extends BaseTpl @Override public CloseType close(Hop hop) { //need to close cell tpl after aggregation, see fuse for exact properties - if( hop 
instanceof AggUnaryOp && isValidOperation(hop.getInput().get(0)) ) + if( (HopRewriteUtils.isSum(hop) && ((AggUnaryOp) hop).getDirection()!= Direction.Col) + || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1()==1 && hop.getDim2()==1) ) return CloseType.CLOSED_VALID; - else if( hop instanceof AggUnaryOp ) + else if( hop instanceof AggUnaryOp || hop instanceof AggBinaryOp ) return CloseType.CLOSED_INVALID; else return CloseType.OPEN; @@ -109,6 +112,7 @@ public class CellTpl extends BaseTpl CNode output = tmp.get(hop.getHopID()); CNodeCell tpl = new CNodeCell(inputs, output); tpl.setCellType(TemplateUtils.getCellType(hop)); + tpl.setRequiresCastDtm(hop instanceof AggBinaryOp); // return cplan instance return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl); @@ -191,12 +195,34 @@ public class CellTpl extends BaseTpl out = new CNodeTernary(cdata1, cdata2, cdata3, TernaryType.valueOf(top.getOp().toString())); } - else if (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM + else if( HopRewriteUtils.isTransposeOperation(hop) ) + { + out = tmp.get(hop.getInput().get(0).getHopID()); + } + else if( hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM && (((AggUnaryOp) hop).getDirection() == Direction.RowCol || ((AggUnaryOp) hop).getDirection() == Direction.Row) ) { out = tmp.get(hop.getInput().get(0).getHopID()); } + else if( hop instanceof AggBinaryOp ) { + //guaranteed to be a dot product, so there are two cases: + //(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y) + if( HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1)) ) { + CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID()); + out = new CNodeUnary(cdata1, UnaryType.POW2); + } + else { + CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), + hop.getInput().get(0), tmp, compileLiterals); + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); + CNode cdata2 = 
tmp.get(hop.getInput().get(1).getHopID()); + if( TemplateUtils.isColVector(cdata2) ) + cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); + out = new CNodeBinary(cdata1, cdata2, BinType.MULT); + } + } tmp.put(hop.getHopID(), out); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 80216fd..862c4f6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -31,17 +31,19 @@ import org.apache.sysml.hops.BinaryOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.TernaryOp; -import org.apache.sysml.hops.Hop.AggOp; import org.apache.sysml.hops.Hop.Direction; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.codegen.cplan.CNode; +import org.apache.sysml.hops.codegen.cplan.CNodeBinary; import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType; import org.apache.sysml.hops.codegen.cplan.CNodeData; +import org.apache.sysml.hops.codegen.cplan.CNodeTernary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary; import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType; import org.apache.sysml.hops.codegen.template.BaseTpl.TemplateType; import org.apache.sysml.hops.rewrite.HopRewriteUtils; import org.apache.sysml.hops.codegen.cplan.CNodeTernary.TernaryType; +import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType; import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType; @@ -188,8 +190,8 @@ public class TemplateUtils } public static CellType 
getCellType(Hop hop) { - return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp() == AggOp.SUM) ? - ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? + return (hop instanceof AggBinaryOp) ? CellType.FULL_AGG : + HopRewriteUtils.isSum(hop) ? ((((AggUnaryOp) hop).getDirection() == Direction.RowCol) ? CellType.FULL_AGG : CellType.ROW_AGG) : CellType.NO_AGG; } @@ -248,4 +250,17 @@ public class TemplateUtils return true; return false; } + + public static boolean hasSingleOperation(CNodeTpl tpl) { + CNode output = tpl.getOutput(); + return (output instanceof CNodeUnary || output instanceof CNodeBinary + || output instanceof CNodeTernary) && hasOnlyDataNodeInputs(output); + } + + public static boolean hasOnlyDataNodeInputs(CNode node) { + boolean ret = true; + for( CNode c : node.getInput() ) + ret &= (c instanceof CNodeData); + return ret; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index c644773..8478fca 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -467,7 +467,6 @@ public class HopRewriteUtils } public static UnaryOp createUnary(Hop input, OpOp1 type) - throws HopsException { DataType dt = (type==OpOp1.CAST_AS_SCALAR) ? DataType.SCALAR : (type==OpOp1.CAST_AS_MATRIX) ? 
DataType.MATRIX : input.getDataType(); @@ -593,6 +592,7 @@ public class HopRewriteUtils } public static void setOutputParametersForScalar( Hop hop ) { + hop.setDataType(DataType.SCALAR); hop.setDim1( 0 ); hop.setDim2( 0 ); hop.setOutputBlocksizes(-1, -1 ); @@ -795,6 +795,10 @@ public class HopRewriteUtils return hop instanceof AggBinaryOp && ((AggBinaryOp)hop).isMatrixMultiply(); } + public static boolean isSum(Hop hop) { + return (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getOp()==AggOp.SUM); + } + public static boolean isNonZeroIndicator(Hop pred, Hop hop ) { if( pred instanceof BinaryOp && ((BinaryOp)pred).getOp()==OpOp2.NOTEQUAL http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 9a36d0a..a34ccbe 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -142,13 +142,14 @@ public class ProgramRewriter /** * Construct a program rewriter for a given rewrite which is passed from outside. * - * @param rewrite the HOP rewrite rule + * @param rewrites the HOP rewrite rules */ - public ProgramRewriter( HopRewriteRule rewrite ) + public ProgramRewriter( HopRewriteRule... rewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); - _dagRuleSet.add( rewrite ); + for( HopRewriteRule rewrite : rewrites ) + _dagRuleSet.add( rewrite ); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); } @@ -156,15 +157,16 @@ public class ProgramRewriter /** * Construct a program rewriter for a given rewrite which is passed from outside. 
* - * @param rewrite the statement block rewrite rule + * @param rewrites the statement block rewrite rules */ - public ProgramRewriter( StatementBlockRewriteRule rewrite ) + public ProgramRewriter( StatementBlockRewriteRule... rewrites ) { //initialize HOP DAG rewrite ruleSet (with fixed rewrite order) _dagRuleSet = new ArrayList<HopRewriteRule>(); _sbRuleSet = new ArrayList<StatementBlockRewriteRule>(); - _sbRuleSet.add( rewrite ); + for( StatementBlockRewriteRule rewrite : rewrites ) + _sbRuleSet.add( rewrite ); } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index c0eb30f..b99dd37 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1496,16 +1496,9 @@ public class DMLTranslator { if ( target.getDim1() != -1 ) rowUpperHops = new LiteralOp(target.getOrigDim1()); - else - { - try { - //currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hops.OpOp1.NROW, expr); - rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(target.getName())); - rowUpperHops.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); - } catch (HopsException e) { - LOG.error(target.printErrorLocation() + "error processing row upper index for indexed expression " + target.toString()); - throw new RuntimeException(target.printErrorLocation() + "error processing row upper index for indexed expression " + target.toString()); - } + else { + rowUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(target.getName())); + 
rowUpperHops.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); } } if (target.getColLowerBound() != null) @@ -1520,20 +1513,9 @@ public class DMLTranslator if ( target.getDim2() != -1 ) colUpperHops = new LiteralOp(target.getOrigDim2()); else - { - try { - colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(target.getName())); - } catch (HopsException e) { - LOG.error(target.printErrorLocation() + " error processing column upper index for indexed expression " + target.toString()); - throw new RuntimeException(target.printErrorLocation() + " error processing column upper index for indexed expression " + target.toString(), e); - } - } + colUpperHops = new UnaryOp(target.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(target.getName())); } - //if (target == null) { - // target = createTarget(source); - //} - // process the source expression to get source Hops Hop sourceOp = processExpression(source, target, hops); @@ -1579,16 +1561,9 @@ public class DMLTranslator { if ( source.getOrigDim1() != -1 ) rowUpperHops = new LiteralOp(source.getOrigDim1()); - else - { - try { - //currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hops.OpOp1.NROW, expr); - rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, hops.get(source.getName())); - rowUpperHops.setAllPositions(source.getBeginLine(),source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); - } catch (HopsException e) { - LOG.error(source.printErrorLocation() + "error processing row upper index for indexed identifier " + source.toString()); - throw new RuntimeException(source.printErrorLocation() + "error processing row upper index for indexed identifier " + source.toString() + e); - } + else { + rowUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NROW, 
hops.get(source.getName())); + rowUpperHops.setAllPositions(source.getBeginLine(),source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); } } if (source.getColLowerBound() != null) @@ -1603,14 +1578,7 @@ public class DMLTranslator if ( source.getOrigDim2() != -1 ) colUpperHops = new LiteralOp(source.getOrigDim2()); else - { - try { - colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName())); - } catch (HopsException e) { - LOG.error(source.printErrorLocation() + "error processing column upper index for indexed indentifier " + source.toString(), e); - throw new RuntimeException(source.printErrorLocation() + "error processing column upper index for indexed indentifier " + source.toString(), e); - } - } + colUpperHops = new UnaryOp(source.getName(), DataType.SCALAR, ValueType.INT, Hop.OpOp1.NCOL, hops.get(source.getName())); } if (target == null) { @@ -1759,15 +1727,11 @@ public class DMLTranslator target.setValueType(ValueType.BOOLEAN); if (source.getRight() == null) { - Hop currUop = null; - try { - currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, left); - currUop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); - } catch (HopsException e) { - throw new ParseException(e.getMessage()); - } + Hop currUop = new UnaryOp(target.getName(), target.getDataType(), target.getValueType(), Hop.OpOp1.NOT, left); + currUop.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); return currUop; - } else { + } + else { Hop currBop = null; OpOp2 op = null; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java index e5b4a83..6e3549e 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/AlgorithmLinregCG.java @@ -43,8 +43,8 @@ public class AlgorithmLinregCG extends AutomatedTestBase private final static double eps = 1e-5; - private final static int rows = 1468; - private final static int cols = 1007; + private final static int rows = 2468; + private final static int cols = 507; private final static double sparsity1 = 0.7; //dense private final static double sparsity2 = 0.1; //sparse http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java index 4a6259a..08edd03 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java @@ -40,7 +40,8 @@ public class CellwiseTmplTest extends AutomatedTestBase private static final String TEST_NAME3 = "cellwisetmpl3"; private static final String TEST_NAME4 = "cellwisetmpl4"; private static final String TEST_NAME5 = "cellwisetmpl5"; - private static final String TEST_NAME6 = "cellwisetmpl6"; //sum + private static final String TEST_NAME6 = "cellwisetmpl6"; + private static final String TEST_NAME7 = "cellwisetmpl7"; private static final String TEST_DIR = "functions/codegen/"; private static final String TEST_CLASS_DIR = TEST_DIR + 
CellwiseTmplTest.class.getSimpleName() + "/"; @@ -58,6 +59,7 @@ public class CellwiseTmplTest extends AutomatedTestBase addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "4" }) ); addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "5" }) ); addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "6" }) ); + addTestConfiguration( TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] { "7" }) ); } @Test @@ -90,6 +92,11 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwiseRewrite6() { testCodegenIntegration( TEST_NAME6, true, ExecType.CP ); } + + @Test + public void testCodegenCellwiseRewrite7() { + testCodegenIntegration( TEST_NAME7, true, ExecType.CP ); + } @Test public void testCodegenCellwise1() { @@ -121,12 +128,22 @@ public class CellwiseTmplTest extends AutomatedTestBase public void testCodegenCellwise6() { testCodegenIntegration( TEST_NAME6, false, ExecType.CP ); } + + @Test + public void testCodegenCellwise7() { + testCodegenIntegration( TEST_NAME7, false, ExecType.CP ); + } @Test public void testCodegenCellwiseRewrite1_sp() { testCodegenIntegration( TEST_NAME1, true, ExecType.SPARK ); } + @Test + public void testCodegenCellwiseRewrite7_sp() { + testCodegenIntegration( TEST_NAME7, true, ExecType.SPARK ); + } + private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType ) { @@ -158,23 +175,24 @@ public class CellwiseTmplTest extends AutomatedTestBase runTest(true, false, null, -1); runRScript(true); - if(testname.equals(TEST_NAME6)) //tak+ - { + if(testname.equals(TEST_NAME6) || testname.equals(TEST_NAME7) ) { //compare scalars HashMap<CellIndex, Double> dmlfile = readDMLScalarFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRScalarFromFS("S"); TestUtils.compareScalars((Double) dmlfile.values().toArray()[0], (Double) 
rfile.values().toArray()[0],0); } - else - { + else { //compare matrices HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - if( !(rewrites && testname.equals(TEST_NAME2)) ) //sigmoid - Assert.assertTrue(heavyHittersContainsSubString("spoofCell") - || heavyHittersContainsSubString("sp_spoofCell")); } + + if( !(rewrites && testname.equals(TEST_NAME2)) ) //sigmoid + Assert.assertTrue(heavyHittersContainsSubString("spoofCell") + || heavyHittersContainsSubString("sp_spoofCell")); + if( testname.equals(TEST_NAME7) ) //ensure matrix mult is fused + Assert.assertTrue(!heavyHittersContainsSubString("tsmm")); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/java/org/apache/sysml/test/utils/TestUtils.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/utils/TestUtils.java b/src/test/java/org/apache/sysml/test/utils/TestUtils.java index 4cb7df0..980095a 100644 --- a/src/test/java/org/apache/sysml/test/utils/TestUtils.java +++ b/src/test/java/org/apache/sysml/test/utils/TestUtils.java @@ -423,7 +423,7 @@ public class TestUtils matrixType = 2; if ( matrixType == -1 ) - throw new RuntimeException("unknown matrix type while reading R matrix: ." 
+ line); + throw new RuntimeException("unknown matrix type while reading R matrix: " + line); line = reader.readLine(); // header line with dimension and nnz information http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/scripts/functions/codegen/cellwisetmpl6.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl6.R b/src/test/scripts/functions/codegen/cellwisetmpl6.R index 669e76f..1990a09 100644 --- a/src/test/scripts/functions/codegen/cellwisetmpl6.R +++ b/src/test/scripts/functions/codegen/cellwisetmpl6.R @@ -22,12 +22,11 @@ args<-commandArgs(TRUE) options(digits=22) library("Matrix") -#X= matrix( seq(1,25), 5, 5, byrow = TRUE) + X = matrix( c(1,2,3), nrow=3, ncol=1, byrow = TRUE) -y=matrix( c(1,1,1), nrow=3, ncol=1, byrow = TRUE) -z=matrix( c(3,3,3), nrow=3, ncol=1, byrow = TRUE) -#S= X*as.matrix(X>0) -#S=7 + (1 / exp(X) ) -S=sum(X*y*z) +y = matrix( c(1,1,1), nrow=3, ncol=1, byrow = TRUE) +z = matrix( c(3,3,3), nrow=3, ncol=1, byrow = TRUE) + +S = sum(X*y*z) print(S) write(S,paste(args[2],"S",sep="")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/scripts/functions/codegen/cellwisetmpl6.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl6.dml b/src/test/scripts/functions/codegen/cellwisetmpl6.dml index 7ff5124..db64459 100644 --- a/src/test/scripts/functions/codegen/cellwisetmpl6.dml +++ b/src/test/scripts/functions/codegen/cellwisetmpl6.dml @@ -19,36 +19,10 @@ # #------------------------------------------------------------- -#------------------------------------------------------------- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -#------------------------------------------------------------- - -X= matrix( "1 2 3", rows=3, cols=1) -y=matrix( "1 1 1", rows=3, cols=1) -z=matrix( "3 3 3", rows=3, cols=1) - +X = matrix( "1 2 3", rows=3, cols=1) +y = matrix( "1 1 1", rows=3, cols=1) +z = matrix( "3 3 3", rows=3, cols=1) S = sum(X*y*z) print(S) write(S,$1) -#S=10 + floor(round(abs((X+w)+z))) -#G = abs(exp(X)) - -#print(sum(G)) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/scripts/functions/codegen/cellwisetmpl7.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl7.R b/src/test/scripts/functions/codegen/cellwisetmpl7.R new file mode 100644 index 0000000..eb75d4c --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl7.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = seq(7, 1007); +Y = seq(6, 1006); + +Z = X + Y - 7 + abs(X); +R = t(Z) %*% Z + +print(R) +write(R, paste(args[2],"S",sep="")) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/31cb2531/src/test/scripts/functions/codegen/cellwisetmpl7.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/cellwisetmpl7.dml b/src/test/scripts/functions/codegen/cellwisetmpl7.dml new file mode 100644 index 0000000..6b2afd3 --- /dev/null +++ b/src/test/scripts/functions/codegen/cellwisetmpl7.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = seq(7, 1007); +Y = seq(6, 1006); + +Z = X + Y - 7 + abs(X); +R = as.scalar(t(Z) %*% Z) +if(1==1){} + +print(R) +write(R, $1)