[SYSTEMML-282] UpdateInPlace for parfor intermediate matrix objects. Closes #27.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/f8f423c3 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/f8f423c3 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/f8f423c3 Branch: refs/heads/master Commit: f8f423c3b615e7040af4d4f34c009b74ed4b49f7 Parents: 7f8716b Author: Arvind Surve <[email protected]> Authored: Mon Feb 15 10:22:37 2016 -0800 Committer: Deron Eriksson <[email protected]> Committed: Mon Feb 15 10:22:37 2016 -0800 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/DataOp.java | 11 +- src/main/java/org/apache/sysml/hops/Hop.java | 14 +- .../RewriteSplitDagDataDependentOperators.java | 11 +- .../rewrite/RewriteSplitDagUnknownCSVRead.java | 5 +- src/main/java/org/apache/sysml/lops/Data.java | 4 +- .../org/apache/sysml/lops/OutputParameters.java | 16 + .../java/org/apache/sysml/lops/compile/Dag.java | 19 +- .../org/apache/sysml/parser/DMLTranslator.java | 10 +- .../controlprogram/parfor/opt/OptNode.java | 10 + .../parfor/opt/OptimizerRuleBased.java | 786 ++++++++++++++++++- .../instructions/cp/VariableCPInstruction.java | 72 +- .../java/org/apache/sysml/utils/Explain.java | 9 +- .../java/org/apache/sysml/utils/Statistics.java | 12 + .../test/integration/AutomatedTestBase.java | 1 - .../updateinplace/UpdateInPlaceTest.java | 278 +++++++ .../functions/updateinplace/updateinplace1.dml | 34 + .../functions/updateinplace/updateinplace10.dml | 36 + .../functions/updateinplace/updateinplace11.dml | 41 + .../functions/updateinplace/updateinplace12.dml | 44 ++ .../functions/updateinplace/updateinplace13.dml | 45 ++ .../functions/updateinplace/updateinplace14.dml | 45 ++ .../functions/updateinplace/updateinplace15.dml | 39 + .../functions/updateinplace/updateinplace2.dml | 36 + .../functions/updateinplace/updateinplace3.dml | 38 + 
.../functions/updateinplace/updateinplace4.dml | 37 + .../functions/updateinplace/updateinplace5.dml | 34 + .../functions/updateinplace/updateinplace6.dml | 40 + .../functions/updateinplace/updateinplace7.dml | 39 + .../functions/updateinplace/updateinplace8.dml | 39 + .../functions/updateinplace/updateinplace9.dml | 38 + 30 files changed, 1792 insertions(+), 51 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/hops/DataOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataOp.java b/src/main/java/org/apache/sysml/hops/DataOp.java index 0008a34..925d13a 100644 --- a/src/main/java/org/apache/sysml/hops/DataOp.java +++ b/src/main/java/org/apache/sysml/hops/DataOp.java @@ -79,6 +79,12 @@ public class DataOp extends Hop setInputFormatType(FileFormatTypes.BINARY); } + public DataOp(String l, DataType dt, ValueType vt, DataOpTypes dop, + String fname, long dim1, long dim2, long nnz, boolean updateInPlace, long rowsPerBlock, long colsPerBlock) { + this(l, dt, vt, dop, fname, dim1, dim2, nnz, rowsPerBlock, colsPerBlock); + setUpdateInPlace(updateInPlace); + } + /** * READ operation for Matrix * This constructor supports expressions in parameters @@ -178,10 +184,11 @@ public class DataOp extends Hop _dataop = type; } - public void setOutputParams(long dim1, long dim2, long nnz, long rowsPerBlock, long colsPerBlock) { + public void setOutputParams(long dim1, long dim2, long nnz, boolean updateInPlace, long rowsPerBlock, long colsPerBlock) { setDim1(dim1); setDim2(dim2); setNnz(nnz); + setUpdateInPlace(updateInPlace); setRowsInBlock(rowsPerBlock); setColsInBlock(colsPerBlock); } @@ -229,7 +236,7 @@ public class DataOp extends Hop case PERSISTENTREAD: l = new Data(HopsData2Lops.get(_dataop), null, inputLops, getName(), null, getDataType(), getValueType(), false, 
getInputFormatType()); - l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inRowsInBlock, _inColsInBlock, getNnz()); + l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inRowsInBlock, _inColsInBlock, getNnz(), getUpdateInPlace()); break; case PERSISTENTWRITE: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 7ddc995..a468ed7 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -83,6 +83,7 @@ public abstract class Hop protected long _rows_in_block = -1; protected long _cols_in_block = -1; protected long _nnz = -1; + protected boolean _updateInPlace = false; protected ArrayList<Hop> _parent = new ArrayList<Hop>(); protected ArrayList<Hop> _input = new ArrayList<Hop>(); @@ -841,6 +842,14 @@ public abstract class Hop return _nnz; } + public void setUpdateInPlace(boolean updateInPlace){ + _updateInPlace = updateInPlace; + } + + public boolean getUpdateInPlace(){ + return _updateInPlace; + } + public abstract Lop constructLops() throws HopsException, LopsException; @@ -954,7 +963,7 @@ public abstract class Hop s.append(h.getHopID() + "; "); } - s.append("\n dims [" + _dim1 + "," + _dim2 + "] blk [" + _rows_in_block + "," + _cols_in_block + "] nnz " + _nnz); + s.append("\n dims [" + _dim1 + "," + _dim2 + "] blk [" + _rows_in_block + "," + _cols_in_block + "] nnz: " + _nnz + " UpdateInPlace: " + _updateInPlace); s.append(" MemEstimate = Out " + (_outputMemEstimate/1024/1024) + " MB, In&Out " + (_memEstimate/1024/1024) + " MB" ); LOG.debug(s.toString()); } @@ -980,7 +989,7 @@ public abstract class Hop throws HopsException { lop.getOutputParameters().setDimensions( - getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz()); + getDim1(), 
getDim2(), getRowsInBlock(), getColsInBlock(), getNnz(), getUpdateInPlace()); } public Lop getLops() { @@ -1820,6 +1829,7 @@ public abstract class Hop _rows_in_block = that._rows_in_block; _cols_in_block = that._cols_in_block; _nnz = that._nnz; + _updateInPlace = that._updateInPlace; //no copy of lops (regenerated) _parent = new ArrayList<Hop>(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 000f45a..1785a51 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -111,6 +111,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite long rlen = c.getDim1(); long clen = c.getDim2(); long nnz = c.getNnz(); + boolean updateInPlace = c.getUpdateInPlace(); long brlen = c.getRowsInBlock(); long bclen = c.getColsInBlock(); @@ -121,7 +122,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //create new transient read DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, updateInPlace, brlen, bclen); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); @@ -151,7 +152,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //create new transient read DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, 
rlen, clen, nnz, updateInPlace, brlen, bclen); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); @@ -172,7 +173,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); twrite.setVisited(VisitStatus.DONE); - twrite.setOutputParams(rlen, clen, nnz, brlen, bclen); + twrite.setOutputParams(rlen, clen, nnz, updateInPlace, brlen, bclen); HopRewriteUtils.copyLineNumbers(c, twrite); sb1hops.add(twrite); } @@ -385,13 +386,13 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite Hop c = p.getValue(); DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, - null, c.getDim1(), c.getDim2(), c.getNnz(), c.getRowsInBlock(), c.getColsInBlock()); + null, c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateInPlace(), c.getRowsInBlock(), c.getColsInBlock()); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); twrite.setVisited(VisitStatus.DONE); - twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getRowsInBlock(), c.getColsInBlock()); + twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateInPlace(), c.getRowsInBlock(), c.getColsInBlock()); HopRewriteUtils.copyLineNumbers(c, twrite); //create additional cut by rewriting both hop dags http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index 7a2db5a..78d0de7 100644 --- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -72,12 +72,13 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule long rlen = reblock.getDim1(); long clen = reblock.getDim2(); long nnz = reblock.getNnz(); + boolean updateInPlace = c.getUpdateInPlace(); long brlen = reblock.getRowsInBlock(); long bclen = reblock.getColsInBlock(); //create new transient read DataOp tread = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, updateInPlace, brlen, bclen); HopRewriteUtils.copyLineNumbers(reblock, tread); //replace reblock with transient read @@ -93,7 +94,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule //add reblock sub dag to first statement block DataOp twrite = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), reblock, DataOpTypes.TRANSIENTWRITE, null); - twrite.setOutputParams(rlen, clen, nnz, brlen, bclen); + twrite.setOutputParams(rlen, clen, nnz, updateInPlace, brlen, bclen); HopRewriteUtils.copyLineNumbers(reblock, twrite); sb1hops.add(twrite); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/lops/Data.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Data.java b/src/main/java/org/apache/sysml/lops/Data.java index 6a24191..99949c3 100644 --- a/src/main/java/org/apache/sysml/lops/Data.java +++ b/src/main/java/org/apache/sysml/lops/Data.java @@ -221,7 +221,7 @@ public class Data extends Lop return getID() + ":" + "File_Name: " + this.getOutputParameters().getFile_name() + " " + "Label: " + this.getOutputParameters().getLabel() + " " + "Operation: = " + operation + " " + "Format: " + 
this.outParams.getFormat() + " Datatype: " + getDataType() + " Valuetype: " + getValueType() + " num_rows = " + this.getOutputParameters().getNumRows() + " num_cols = " + - this.getOutputParameters().getNumCols(); + this.getOutputParameters().getNumCols() + " UpdateInPlace: " + this.getOutputParameters().getUpdateInPlace(); } /** @@ -550,6 +550,8 @@ public class Data extends Lop sb.append( oparams.getColsInBlock() ); sb.append( OPERAND_DELIMITOR ); sb.append( oparams.getNnz() ); + sb.append( OPERAND_DELIMITOR ); + sb.append( oparams.getUpdateInPlace() ); /* Format-specific properties */ if ( oparams.getFormat() == Format.CSV ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/lops/OutputParameters.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/OutputParameters.java b/src/main/java/org/apache/sysml/lops/OutputParameters.java index 854e25b..974a603 100644 --- a/src/main/java/org/apache/sysml/lops/OutputParameters.java +++ b/src/main/java/org/apache/sysml/lops/OutputParameters.java @@ -37,6 +37,7 @@ public class OutputParameters private long _num_rows = -1; private long _num_cols = -1; private long _nnz = -1; + private boolean _updateInPlace = false; private long _num_rows_in_block = -1; private long _num_cols_in_block = -1; private String _file_name = null; @@ -81,6 +82,11 @@ public class OutputParameters } } + public void setDimensions(long rows, long cols, long rows_per_block, long cols_per_block, long nnz, boolean updateInPlace) throws HopsException { + _updateInPlace = updateInPlace; + setDimensions(rows, cols, rows_per_block, cols_per_block, nnz); + } + public Format getFormat() { return matrix_format; } @@ -126,6 +132,15 @@ public class OutputParameters { _nnz = nnz; } + + public boolean getUpdateInPlace() { + return _updateInPlace; + } + + public void setUpdateInPlace(boolean updateInPlace) + { + _updateInPlace = updateInPlace; + 
} public long getRowsInBlock() { return _num_rows_in_block; @@ -149,6 +164,7 @@ public class OutputParameters sb.append("rows=" + getNumRows() + Lop.VALUETYPE_PREFIX); sb.append("cols=" + getNumCols() + Lop.VALUETYPE_PREFIX); sb.append("nnz=" + getNnz() + Lop.VALUETYPE_PREFIX); + sb.append("updateInPlace=" + getUpdateInPlace() + Lop.VALUETYPE_PREFIX); sb.append("rowsInBlock=" + getRowsInBlock() + Lop.VALUETYPE_PREFIX); sb.append("colsInBlock=" + getColsInBlock() + Lop.VALUETYPE_PREFIX); sb.append("isBlockedRepresentation=" + isBlocked() + Lop.VALUETYPE_PREFIX); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index f5693ef..6296d92 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -2281,7 +2281,7 @@ public class Dag<N extends Lop> // TODO: ideally, this should be done by having a member variable in Lop // which stores the outputInfo. 
try { - oparams.setDimensions(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz()); + oparams.setDimensions(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz(), oparams.getUpdateInPlace()); } catch(HopsException e) { throw new LopsException(node.printErrorLocation() + "error in getOutputInfo in Dag ", e); } @@ -2399,7 +2399,7 @@ public class Dag<N extends Lop> oparams.getFile_name(), true, OutputInfo.outputInfoToString(OutputInfo.CSVOutputInfo), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz()), + new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz()), oparams.getUpdateInPlace(), false, delimLop.getStringValue(), true ); @@ -2439,7 +2439,8 @@ public class Dag<N extends Lop> oparams.getFile_name(), true, OutputInfo.outputInfoToString(getOutputInfo(node, false)), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()) + new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), + oparams.getUpdateInPlace() ); createvarInst.setLocation(node); @@ -2470,7 +2471,8 @@ public class Dag<N extends Lop> getFilePath() + fnOutParams.getLabel(), true, OutputInfo.outputInfoToString(getOutputInfo((N)fnOut, false)), - new MatrixCharacteristics(fnOutParams.getNumRows(), fnOutParams.getNumCols(), (int)fnOutParams.getRowsInBlock(), (int)fnOutParams.getColsInBlock(), fnOutParams.getNnz()) + new MatrixCharacteristics(fnOutParams.getNumRows(), fnOutParams.getNumCols(), (int)fnOutParams.getRowsInBlock(), (int)fnOutParams.getColsInBlock(), fnOutParams.getNnz()), + oparams.getUpdateInPlace() ); if (node._beginLine != 0) @@ -2585,7 +2587,8 @@ public class Dag<N extends Lop> tempFileName, true, OutputInfo.outputInfoToString(out.getOutInfo()), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()) + new MatrixCharacteristics(oparams.getNumRows(), 
oparams.getNumCols(), rpb, cpb, oparams.getNnz()), + oparams.getUpdateInPlace() ); createvarInst.setLocation(node); @@ -2694,7 +2697,8 @@ public class Dag<N extends Lop> tempFileName, false, OutputInfo.outputInfoToString(getOutputInfo(node, false)), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()) + new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), + oparams.getUpdateInPlace() ); //NOTE: no instruction patching because final write from cp instruction @@ -2721,7 +2725,8 @@ public class Dag<N extends Lop> fnameStr, false, OutputInfo.outputInfoToString(getOutputInfo(node, false)), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()) + new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), + oparams.getUpdateInPlace() ); // remove the variable CPInstruction currInstr = CPInstructionParser.parseSingleInstruction( http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 017c246..543acf0 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -1009,7 +1009,7 @@ public class DMLTranslator ae.setInputFormatType(Expression.convertFormatType(formatName)); if (ae.getDataType() == DataType.SCALAR ) { - ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), -1, -1); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), -1, -1); } else { switch(ae.getInputFormatType()) { @@ -1017,12 +1017,12 @@ public class DMLTranslator case MM: case CSV: // write output in textcell format - 
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), -1, -1); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), -1, -1); break; case BINARY: // write output in binary block format - ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), DMLTranslator.DMLBlockSize, DMLTranslator.DMLBlockSize); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), DMLTranslator.DMLBlockSize, DMLTranslator.DMLBlockSize); break; default: @@ -1075,7 +1075,7 @@ public class DMLTranslator Integer statementId = liveOutToTemp.get(target.getName()); if ((statementId != null) && (statementId.intValue() == i)) { DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); - transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getRowsInBlock(), ae.getColsInBlock()); + transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), ae.getRowsInBlock(), ae.getColsInBlock()); transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndLine()); updatedLiveOut.addVariable(target.getName(), target); output.add(transientwrite); @@ -1107,7 +1107,7 @@ public class DMLTranslator Integer statementId = liveOutToTemp.get(target.getName()); if ((statementId != null) && (statementId.intValue() == i)) { DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); - transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getRowsInBlock(), ae.getColsInBlock()); + transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateInPlace(), ae.getRowsInBlock(), ae.getColsInBlock()); transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); updatedLiveOut.addVariable(target.getName(), target); 
output.add(transientwrite); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java index bb774a8..20c08c6 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java @@ -173,11 +173,21 @@ public class OptNode return ret; } + public int getBeginLine() + { + return _beginLine; + } + public void setBeginLine( int line ) { _beginLine = line; } + public int getEndLine() + { + return _endLine; + } + public void setEndLine( int line ) { _endLine = line; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 4eee876..9754756 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -24,6 +24,9 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; import java.util.Map.Entry; import java.util.Set; @@ -36,6 +39,7 @@ import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.AggBinaryOp.MMultMethod; +import 
org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.MultiThreadedHop; import org.apache.sysml.hops.Hop.ParamBuiltinOp; import org.apache.sysml.hops.Hop.ReOrgOp; @@ -54,16 +58,19 @@ import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.LopProperties; import org.apache.sysml.lops.LopsException; import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.Expression; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.FunctionStatementBlock; import org.apache.sysml.parser.LanguageException; import org.apache.sysml.parser.ParForStatement; import org.apache.sysml.parser.ParForStatementBlock; import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.VariableSet; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLUnsupportedOperationException; import org.apache.sysml.runtime.controlprogram.ForProgramBlock; import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; +import org.apache.sysml.runtime.controlprogram.IfProgramBlock; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock; import org.apache.sysml.runtime.controlprogram.Program; @@ -74,6 +81,7 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PExecMode; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.POptMode; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PResultMerge; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PTaskPartitioner; +import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; @@ -147,6 +155,8 @@ public class OptimizerRuleBased extends Optimizer public static final 
boolean APPLY_REWRITE_NESTED_PARALLELISM = false; public static final String FUNCTION_UNFOLD_NAMEPREFIX = "__unfold_"; + public static final boolean APPLY_REWRITE_UPDATE_INPLACE_INTERMEDIATE = true; + public static final double PAR_K_FACTOR = OptimizationWrapper.PAR_FACTOR_INFRASTRUCTURE; public static final double PAR_K_MR_FACTOR = 1.0 * OptimizationWrapper.PAR_FACTOR_INFRASTRUCTURE; @@ -167,8 +177,11 @@ public class OptimizerRuleBased extends Optimizer protected CostEstimator _cost = null; - + protected static ThreadLocal<ArrayList<String>> listUIPRes = new ThreadLocal<ArrayList<String>>() { + @Override protected ArrayList<String> initialValue() { return new ArrayList<String>(); } + }; + @Override public CostModelType getCostModelType() { @@ -1872,7 +1885,8 @@ public class OptimizerRuleBased extends Optimizer //NOTE: currently this rule is too conservative (the result variable is assumed to be dense and //most importantly counted twice if this is part of the maximum operation) - double totalMem = Math.max((M+sum), rComputeSumMemoryIntermediates(pn, new HashSet<String>())); + HashMap <String, ArrayList <UIPCandidateHop>> uipCandHopHM = new HashMap <String, ArrayList<UIPCandidateHop>>(); + double totalMem = Math.max((M+sum), rComputeSumMemoryIntermediates(pn, new HashSet<String>(), uipCandHopHM)); //optimization decision if( rHasOnlyInPlaceSafeLeftIndexing(pn, retVars) ) //basic correctness constraint @@ -1893,6 +1907,9 @@ public class OptimizerRuleBased extends Optimizer } } + if(APPLY_REWRITE_UPDATE_INPLACE_INTERMEDIATE && LOG.isDebugEnabled()) + listUIPRes.remove(); + //modify result variable meta data, if rewrite applied if( apply ) { @@ -1904,12 +1921,673 @@ public class OptimizerRuleBased extends Optimizer ((MatrixObject)dat).enableUpdateInPlace(true); } inPlaceResultVars.addAll(retVars); + + if(APPLY_REWRITE_UPDATE_INPLACE_INTERMEDIATE) + { + isUpdateInPlaceApplicable(pn, uipCandHopHM); + + boolean bAnyUIPApplicable = false; + for(Entry<String, ArrayList 
<UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + + if (uipCandHopList != null) { + for (UIPCandidateHop uipCandHop: uipCandHopList) + if(uipCandHop.isLoopApplicable() && uipCandHop.isUpdateInPlace()) + { + uipCandHop.getHop().setUpdateInPlace(true); + bAnyUIPApplicable = true; + + if(LOG.isDebugEnabled()) + listUIPRes.get().add(uipCandHop.getHop().getName()); + } + } + } + if(bAnyUIPApplicable) + try { + //Recompile this block if there is any update in place applicable. + Recompiler.recompileProgramBlockInstructions(pfpb); //TODO: Recompile @ very high level ok? + } + catch(Exception ex){ + throw new DMLRuntimeException(ex); + } + } } - + + if(APPLY_REWRITE_UPDATE_INPLACE_INTERMEDIATE && LOG.isTraceEnabled()) + { + LOG.trace("UpdateInPlace = " + apply + " for lines between " + pn.getBeginLine() + " and " + pn.getEndLine()); + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + + if (uipCandHopList != null) { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + LOG.trace("Matrix Object: Name: " + uipCandHop.getHop().getName() + "<" + uipCandHop.getHop().getBeginLine() + "," + uipCandHop.getHop().getEndLine()+ ">, InLoop:" + + uipCandHop.isLoopApplicable() + ", UIPApplicable:" + uipCandHop.isUpdateInPlace() + ", HopUIPApplicable:" + uipCandHop.getHop().getUpdateInPlace()); + } + } + } + } + LOG.debug(getOptMode()+" OPT: rewrite 'set in-place result indexing' - result="+ apply+" ("+ProgramConverter.serializeStringCollection(inPlaceResultVars)+", M="+toMB(totalMem)+")" ); } + /* + * Algorithm: isUpdateInPlaceApplicable() + * + * Purpose of this algorithm to identify intermediate hops containing matrix objects to be marked as "UpdateInPlace" from ParforProgramBlock. + * First, list of candidates are identified. Then list is pruned based on conditions descibed below. 
+ * + * A. Identification of candidates: + * 1. A candidate's identity is defined by name, beginline, endline, and hop. + * 2. The operation is of type LeftIndexingOp. + * 3. The hop's first input should have exactly one consumer. + * 4. The matrix object on which the left-indexing operation is performed is defined outside "Loop" A. + * 5. The LeftIndexingOp operation is within "Loop" A. + * + * Notes: 1. A loop is a while loop, a for loop, or a parfor loop with a degree of parallelism of one. + * 2. Some conditions are ignored at this point, as listed below: + * 2.1 General instructions: it is unclear how to handle them, and they are hard to identify and iterate over. + * 2.2 LeftIndexing outside the "loop". + * + * + * B. Pruning the list: + * A candidate is pruned if it meets any of conditions 1 to 3 below. + * 0. A candidate's identity is defined by name, beginline, endline, and hop. + * 1. Scope of the candidate: the variable (name) is defined in the liveout set of the loop's statement block. + * 2. Operation type and order: + * 2.1 The hop's input variable name is the same as the candidate's name. + * 2.2 The hop is located before the candidate. + * 2.3 The hop's operator type is any of the following: + * 2.3.1 DataOp (with operation type TransientWrite or TransientRead) + * 2.3.2 ReorgOp (with operation type Reshape or Transpose) + * 2.3.3 FunctionOp + * 3. Location of the affected consumer: + * 3.1 The consumer is defined before left indexing through one of the operations listed in 2.3 above. + * 3.2 The consumer is used after left indexing on the candidate. + * + * Notes: + * 1. No interleaved operations. + * 2. Functions containing the actual operation must be scanned for the candidate exclusion list. + * 3. Operations that do not include data updated through updateinplace. + * + * + * @param pn: OptNode of parfor loop + * @param uipCandHopHM: Hashmap of UIPCandidateHop with name as a key.
+ * @throws DMLRuntimeException + */ + private void isUpdateInPlaceApplicable(OptNode pn, HashMap <String, ArrayList <UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + rIsInLoop(pn, uipCandHopHM, false); + + // Prune candidate list based on non-existance of candidate in the loop + Iterator<Map.Entry<String, ArrayList <UIPCandidateHop>>> uipCandHopHMIter = uipCandHopHM.entrySet().iterator(); + while(uipCandHopHMIter.hasNext()) + { + Map.Entry<String, ArrayList <UIPCandidateHop>> uipCandHopHMentry = uipCandHopHMIter.next(); + ArrayList <UIPCandidateHop> uipCandHopList = uipCandHopHMentry.getValue(); + + if (uipCandHopList != null) { + for (Iterator<UIPCandidateHop> uipCandHopListIter = uipCandHopList.iterator(); uipCandHopListIter.hasNext();) + { + UIPCandidateHop uipCandHop = uipCandHopListIter.next(); + if (!uipCandHop.isLoopApplicable()) //If Loop is not applicable then remove it from the list. + { + uipCandHopListIter.remove(); + if(LOG.isTraceEnabled()) + LOG.trace("Matrix Object: Name: " + uipCandHop.getHop().getName() + "<" + uipCandHop.getHop().getBeginLine() + "," + uipCandHop.getHop().getEndLine()+ + ">, removed from the candidate list as it does not have loop criteria applicable."); + } + } + if(uipCandHopList.isEmpty()) + uipCandHopHMIter.remove(); + } + } + + if(!uipCandHopHM.isEmpty()) + { + // Get consumer list + rResetVisitStatus(pn); + rGetUIPConsumerList(pn, uipCandHopHM); + + // Prune candidate list if consumer is in function call. 
+ uipCandHopHMIter = uipCandHopHM.entrySet().iterator(); + while(uipCandHopHMIter.hasNext()) + { + Map.Entry<String, ArrayList <UIPCandidateHop>> uipCandHopHMentry = uipCandHopHMIter.next(); + ArrayList <UIPCandidateHop> uipCandHopList = uipCandHopHMentry.getValue(); + + if (uipCandHopList != null) { + for (Iterator<UIPCandidateHop> uipCandHopListIter = uipCandHopList.iterator(); uipCandHopListIter.hasNext();) + { + UIPCandidateHop uipCandHop = uipCandHopListIter.next(); + // if one of the consumer is FunctionOp then remove it. + ArrayList<Hop> consHops = uipCandHop.getConsumerHops(); + if(consHops != null) + for (Hop hop: consHops) + { + if(hop instanceof FunctionOp) + { + uipCandHopListIter.remove(); + if(LOG.isTraceEnabled()) + LOG.trace("Matrix Object: Name: " + uipCandHop.getHop().getName() + "<" + uipCandHop.getHop().getBeginLine() + "," + uipCandHop.getHop().getEndLine()+ + ">, removed from the candidate list as one of the consumer is FunctionOp."); + break; + } + } + } + if(uipCandHopList.isEmpty()) + uipCandHopHMIter.remove(); + } + } + + //Validate the consumer list + rResetVisitStatus(pn); + rValidateUIPConsumerList(pn, uipCandHopHM); + } + } + + + + /* + * This will check if candidate LeftIndexingOp are in loop (while, for or parfor). + * + * @param pn: OpNode of parfor loop + * @param uipCandHopHM: Hashmap of UIPCandidateHop with name as a key. 
+ * @throws DMLRuntimeException + */ + private void rIsInLoop(OptNode pn, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM, boolean bInLoop) + throws DMLRuntimeException + { + if(!pn.isLeaf()) + { + ProgramBlock pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID())[1]; + + VariableSet varUpdated = pb.getStatementBlock().variablesUpdated(); + boolean bUIPCandHopUpdated = false; + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + String uipCandHopID = entry.getKey(); + + if (varUpdated.containsVariable(uipCandHopID)) + { + bUIPCandHopUpdated = true; + break; + } + } + + // As none of the UIP candidates updated in this DAG, no need for further processing within this DAG + if(!bUIPCandHopUpdated) + return; + + boolean bLoop = false; + if( bInLoop || pb instanceof WhileProgramBlock || + (pb instanceof ParForProgramBlock && ((ParForProgramBlock)pb).getDegreeOfParallelism() == 1) || + (pb instanceof ForProgramBlock && !(pb instanceof ParForProgramBlock))) + bLoop = true; + + for (OptNode optNode: pn.getChilds()) + { + rIsInLoop(optNode, uipCandHopHM, bLoop); + } + } + else if(bInLoop) + { + Hop hop = (Hop) OptTreeConverter.getAbstractPlanMapping().getMappedHop(pn.getID()); + + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + //Update if candiate hop defined outside this loop, and leftindexing is within this loop. + if (uipCandHop.getLocation() <= hop.getBeginLine() && uipCandHop.getHop().getBeginLine() <= hop.getEndLine()) + uipCandHop.setIsLoopApplicable(true); + } + } + } + } + + } + + + + /* + * This will get consumer list for candidate LeftIndexingOp. + * + * @param pn: OpNode of parfor loop + * @param uipCandHopHM: Hashmap of UIPCandidateHop with name as a key. 
+ * @throws DMLRuntimeException + */ + private void rGetUIPConsumerList(OptNode pn, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(!pn.isLeaf()) + { + if(pn.getNodeType() == OptNode.NodeType.FUNCCALL) + return; + + ProgramBlock pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID())[1]; + + VariableSet varRead = pb.getStatementBlock().variablesRead(); + boolean bUIPCandHopRead = false; + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + String uipCandHopID = entry.getKey(); + + if (varRead.containsVariable(uipCandHopID)) + { + bUIPCandHopRead = true; + break; + } + } + + // As none of the UIP candidates updated in this DAG, no need for further processing within this DAG + if(!bUIPCandHopRead) + return; + + for (OptNode optNode: pn.getChilds()) + rGetUIPConsumerList(optNode, uipCandHopHM); + } + else + { + OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping(); + long ppid = map.getMappedParentID(map.getMappedParentID(pn.getID())); + Object[] o = map.getMappedProg(ppid); + ProgramBlock pb = (ProgramBlock) o[1]; + + Hop hop = (Hop) OptTreeConverter.getAbstractPlanMapping().getMappedHop(pn.getID()); + rGetUIPConsumerList(hop, uipCandHopHM); + + if(pb instanceof IfProgramBlock || pb instanceof WhileProgramBlock || + (pb instanceof ForProgramBlock && !(pb instanceof ParForProgramBlock))) //TODO + rGetUIPConsumerList(pb, uipCandHopHM); + } + } + + + private void rGetUIPConsumerList(ProgramBlock pb, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + ArrayList<ProgramBlock> childBlocks = null; + ArrayList<ProgramBlock> elseBlocks = null; + if (pb instanceof WhileProgramBlock) + childBlocks = ((WhileProgramBlock)pb).getChildBlocks(); + else if (pb instanceof ForProgramBlock) + childBlocks = ((ForProgramBlock)pb).getChildBlocks(); + else if (pb instanceof IfProgramBlock) + { + childBlocks = 
((IfProgramBlock)pb).getChildBlocksIfBody(); + elseBlocks = ((IfProgramBlock)pb).getChildBlocksElseBody(); + } + + if(childBlocks != null) + for (ProgramBlock childBlock: childBlocks) + { + rGetUIPConsumerList(childBlock, uipCandHopHM); + try + { + rGetUIPConsumerList(childBlock.getStatementBlock().get_hops(), uipCandHopHM); + } + catch (Exception e) { + throw new DMLRuntimeException(e); + } + } + if(elseBlocks != null) + for (ProgramBlock childBlock: elseBlocks) + { + rGetUIPConsumerList(childBlock, uipCandHopHM); + try + { + rGetUIPConsumerList(childBlock.getStatementBlock().get_hops(), uipCandHopHM); + } + catch (Exception e) { + throw new DMLRuntimeException(e); + } + } + } + + private void rGetUIPConsumerList(ArrayList<Hop> hops, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(hops != null) + for (Hop hop: hops) + rGetUIPConsumerList(hop, uipCandHopHM); + } + + + private void rGetUIPConsumerList(Hop hop, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(hop.getVisited() != Hop.VisitStatus.DONE) + { + if ((!(!hop.getParent().isEmpty() && hop.getParent().get(0) instanceof LeftIndexingOp)) && + ((hop instanceof DataOp && ((DataOp)hop).getDataOpType() == DataOpTypes.TRANSIENTREAD ) || + (hop instanceof ReorgOp && (((ReorgOp)hop).getOp() == ReOrgOp.RESHAPE || ((ReorgOp)hop).getOp() == ReOrgOp.TRANSPOSE)) || + (hop instanceof FunctionOp))) + { + // If candidate's name is same as input hop. + String uipCandiateID = hop.getName(); + ArrayList <UIPCandidateHop> uipCandHopList = uipCandHopHM.get(uipCandiateID); + + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + // Add consumers for candidate hop. 
+ ArrayList<Hop> consumerHops = uipCandHop.getConsumerHops(); + if(uipCandHop.getConsumerHops() == null) + consumerHops = new ArrayList<Hop>(); + consumerHops.add(getRootHop(hop)); + uipCandHop.setConsumerHops(consumerHops); + } + } + } + + for(Hop hopIn: hop.getInput()) + { + rGetUIPConsumerList(hopIn, uipCandHopHM); + } + + hop.setVisited(Hop.VisitStatus.DONE); + } + } + + + private Hop getRootHop(Hop hop) + { + return (!hop.getParent().isEmpty())?getRootHop(hop.getParent().get(0)):hop; + } + + + private void rResetVisitStatus(OptNode pn) + throws DMLRuntimeException + { + + if(!pn.isLeaf()) + { + if(pn.getNodeType() == OptNode.NodeType.FUNCCALL) + { + Hop hopFunc = (Hop) OptTreeConverter.getAbstractPlanMapping().getMappedHop(pn.getID()); + hopFunc.resetVisitStatus(); + return; + } + ProgramBlock pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID())[1]; + ArrayList<ProgramBlock> childBlocks = null; + ArrayList<ProgramBlock> elseBlocks = null; + if (pb instanceof WhileProgramBlock) + childBlocks = ((WhileProgramBlock)pb).getChildBlocks(); + else if (pb instanceof ForProgramBlock) + childBlocks = ((ForProgramBlock)pb).getChildBlocks(); + else if (pb instanceof IfProgramBlock) { + childBlocks = ((IfProgramBlock)pb).getChildBlocksIfBody(); + elseBlocks = ((IfProgramBlock)pb).getChildBlocksElseBody(); + } + + if(childBlocks != null) + { + for (ProgramBlock childBlock: childBlocks) + { + try + { + Hop.resetVisitStatus(childBlock.getStatementBlock().get_hops()); + } + catch (Exception e) + { + throw new DMLRuntimeException(e); + } + } + } + if(elseBlocks != null) + { + for (ProgramBlock childBlock: elseBlocks) + { + try + { + Hop.resetVisitStatus(childBlock.getStatementBlock().get_hops()); + } + catch (Exception e) + { + throw new DMLRuntimeException(e); + } + } + } + + for (OptNode optNode: pn.getChilds()) + { + rResetVisitStatus(optNode); + } + } + else + { + Hop hop = (Hop) 
OptTreeConverter.getAbstractPlanMapping().getMappedHop(pn.getID()); + if(hop != null) + { + hop.resetVisitStatus(); + } + } + } + + + + /* + * This will validate candidate's consumer list. + * + * @param pn: OpNode of parfor loop + * @param uipCandHopHM: Hashmap of UIPCandidateHop with name as a key. + * @throws DMLRuntimeException + */ + + private void rValidateUIPConsumerList(OptNode pn, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(!pn.isLeaf()) + { + if(pn.getNodeType() == OptNode.NodeType.FUNCCALL) + { + Hop hop = (Hop) OptTreeConverter.getAbstractPlanMapping().getMappedHop(pn.getID()); + rValidateUIPConsumerList(hop.getInput(), uipCandHopHM); + return; + } + + ProgramBlock pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID())[1]; + + VariableSet varRead = pb.getStatementBlock().variablesRead(); + boolean bUIPCandHopRead = false; + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + ArrayList<Hop> consumerHops = uipCandHop.getConsumerHops(); + if(consumerHops != null) + { + // If any of consumer's input (or any parent in hierachy of input) matches candiate's name, then + // remove candidate from the list. 
+ for (Hop consumerHop: consumerHops) + { + if (varRead.containsVariable(consumerHop.getName())) + { + bUIPCandHopRead = true; + break; + } + } + } + } + } + } + // As none of the UIP candidates updated in this DAG, no need for further processing within this DAG + if(!bUIPCandHopRead) + return; + + for (OptNode optNode: pn.getChilds()) + rValidateUIPConsumerList(optNode, uipCandHopHM); + } + else + { + OptTreePlanMappingAbstract map = OptTreeConverter.getAbstractPlanMapping(); + long ppid = map.getMappedParentID(map.getMappedParentID(pn.getID())); + Object[] o = map.getMappedProg(ppid); + ProgramBlock pb = (ProgramBlock) o[1]; + + if(pb instanceof IfProgramBlock || pb instanceof WhileProgramBlock || + (pb instanceof ForProgramBlock && !(pb instanceof ParForProgramBlock))) //TODO + rValidateUIPConsumerList(pb, uipCandHopHM); + + long pid = map.getMappedParentID(pn.getID()); + o = map.getMappedProg(pid); + pb = (ProgramBlock) o[1]; + Hop hop = map.getMappedHop(pn.getID()); + rValidateUIPConsumerList(hop, uipCandHopHM, pb.getStatementBlock().variablesRead()); + } + } + + private void rValidateUIPConsumerList(ProgramBlock pb, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + ArrayList<ProgramBlock> childBlocks = null; + if (pb instanceof WhileProgramBlock) + childBlocks = ((WhileProgramBlock)pb).getChildBlocks(); + else if (pb instanceof ForProgramBlock) + childBlocks = ((ForProgramBlock)pb).getChildBlocks(); + else if (pb instanceof IfProgramBlock) + { + childBlocks = ((IfProgramBlock)pb).getChildBlocksIfBody(); + ArrayList<ProgramBlock> elseBlocks = ((IfProgramBlock)pb).getChildBlocksElseBody(); + if(childBlocks != null && elseBlocks != null) + childBlocks.addAll(elseBlocks); + else if (childBlocks == null) + childBlocks = elseBlocks; + } + + if(childBlocks != null) + for (ProgramBlock childBlock: childBlocks) + { + rValidateUIPConsumerList(childBlock, uipCandHopHM); + try + { + 
rValidateUIPConsumerList(childBlock.getStatementBlock(), uipCandHopHM); + } + catch (Exception e) { + throw new DMLRuntimeException(e); + } + } + } + + + private void rValidateUIPConsumerList(StatementBlock sb, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + VariableSet readVariables = sb.variablesRead(); + + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + ArrayList<Hop> consumerHops = uipCandHop.getConsumerHops(); + if(consumerHops != null) + { + // If consumer has read then remove candidate from the list (set flag to false). + for (Hop consumerHop: consumerHops) + if(readVariables.containsVariable(consumerHop.getName())) + { + uipCandHop.setUpdateInPlace(false); + break; + } + } + } + } + } + } + + + private void rValidateUIPConsumerList(ArrayList<Hop> hops, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(hops != null) + for (Hop hop: hops) + rValidateUIPConsumerList(hop, uipCandHopHM); + } + + + private void rValidateUIPConsumerList(Hop hop, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM) + throws DMLRuntimeException + { + if(hop.getVisited() != Hop.VisitStatus.DONE) + { + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + ArrayList<Hop> consumerHops = uipCandHop.getConsumerHops(); + if(consumerHops != null) + { + // If consumer has read then remove candidate from the list (set flag to false). 
+ for (Hop consumerHop: consumerHops) + if(hop.getName().equals(consumerHop.getName())) + { + uipCandHop.setUpdateInPlace(false); + break; + } + } + } + } + } + hop.setVisited(Hop.VisitStatus.DONE); + } + } + + private void rValidateUIPConsumerList(Hop hop, HashMap <String, ArrayList<UIPCandidateHop>> uipCandHopHM, VariableSet readVariables) + throws DMLRuntimeException + { + if(hop.getVisited() != Hop.VisitStatus.DONE) + { + for(Entry<String, ArrayList <UIPCandidateHop>> entry: uipCandHopHM.entrySet()) + { + ArrayList <UIPCandidateHop> uipCandHopList = entry.getValue(); + if (uipCandHopList != null) + { + for (UIPCandidateHop uipCandHop: uipCandHopList) + { + ArrayList<Hop> consumerHops = uipCandHop.getConsumerHops(); + if(consumerHops != null) + { + // If consumer has read then remove candidate from the list (set flag to false). + for (Hop consumerHop: consumerHops) + if(readVariables.containsVariable(consumerHop.getName())) + { + uipCandHop.setUpdateInPlace(false); + break; + } + } + } + } + } + hop.setVisited(Hop.VisitStatus.DONE); + } + } + + + public static List<String> getUIPList() + { + return listUIPRes.get(); + } + /** * * @param n @@ -1993,7 +2671,7 @@ public class OptimizerRuleBased extends Optimizer ParForProgramBlock pfpb = (ParForProgramBlock) OptTreeConverter .getAbstractPlanMapping().getMappedProg(pn.getID())[1]; - double M_sumInterm = rComputeSumMemoryIntermediates(pn, inplaceResultVars); + double M_sumInterm = rComputeSumMemoryIntermediates(pn, inplaceResultVars, new HashMap <String, ArrayList <UIPCandidateHop>>()); boolean apply = false; if( (pfpb.getExecMode() == PExecMode.REMOTE_MR_DP || pfpb.getExecMode() == PExecMode.REMOTE_MR) @@ -2013,7 +2691,8 @@ public class OptimizerRuleBased extends Optimizer * @return * @throws DMLRuntimeException */ - protected double rComputeSumMemoryIntermediates( OptNode n, HashSet<String> inplaceResultVars ) + protected double rComputeSumMemoryIntermediates( OptNode n, HashSet<String> inplaceResultVars, + HashMap 
<String, ArrayList <UIPCandidateHop>> uipCandidateHM ) throws DMLRuntimeException { double sum = 0; @@ -2021,11 +2700,37 @@ public class OptimizerRuleBased extends Optimizer if( !n.isLeaf() ) { for( OptNode cn : n.getChilds() ) - sum += rComputeSumMemoryIntermediates( cn, inplaceResultVars ); + sum += rComputeSumMemoryIntermediates( cn, inplaceResultVars, uipCandidateHM ); } else if( n.getNodeType()== NodeType.HOP ) { Hop h = OptTreeConverter.getAbstractPlanMapping().getMappedHop(n.getID()); + if (h.getDataType() == Expression.DataType.MATRIX && h instanceof LeftIndexingOp && + h.getInput().get(0).getParent().size() == 1) + { + long pid = OptTreeConverter.getAbstractPlanMapping().getMappedParentID(n.getID()); + ProgramBlock pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pid)[1]; + + while(!(pb instanceof WhileProgramBlock || pb instanceof ForProgramBlock)) + { + pid = OptTreeConverter.getAbstractPlanMapping().getMappedParentID(pid); + pb = (ProgramBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(pid)[1]; + } + + String uipCandiateID = new String(h.getName()); + ArrayList <UIPCandidateHop> uipCandiHopList = uipCandidateHM.get(uipCandiateID); + if(uipCandiHopList == null) + uipCandiHopList = new ArrayList<UIPCandidateHop>(); + uipCandiHopList.add(new UIPCandidateHop(h, pb)); + uipCandidateHM.put(uipCandiateID, uipCandiHopList); + + StatementBlock sb = (StatementBlock) OptTreeConverter.getAbstractPlanMapping().getMappedProg(OptTreeConverter.getAbstractPlanMapping().getMappedParentID(n.getID()))[0]; + if(LOG.isDebugEnabled()) + LOG.debug("Candidate Hop:" + h.getName() + "<" + h.getBeginLine() + "," + h.getEndLine() + ">,<" + + h.getBeginColumn() + "," + h.getEndColumn() + "> PB:" + "<" + pb.getBeginLine() + "," + pb.getEndLine() + ">,<" + + pb.getBeginColumn() + "," + pb.getEndColumn() + "> SB:" + "<" + sb.getBeginLine() + "," + sb.getEndLine() + ">,<" + + sb.getBeginColumn() + "," + sb.getEndColumn() + ">"); + } if( 
n.getParam(ParamType.OPSTRING).equals(IndexingOp.OPSTRING) && n.getParam(ParamType.DATA_PARTITION_FORMAT) != null ) @@ -3183,5 +3888,74 @@ public class OptimizerRuleBased extends Optimizer return OptimizerUtils.toMB(inB) + "MB"; } + /* + * This class stores information for the candidate hop, such as hop itself, program block. + * When it gets evaluated if Matrix can be marked for "UpdateInPlace", additional properties such + * as location, flag to indicate if its in loop (for, parfor, while), flag to indicate if hop can be marked as "UpdateInPlace". + */ + class UIPCandidateHop { + Hop hop; + int iLocation = -1; + ProgramBlock pb; + Boolean bIsLoopApplicable = false, bUpdateInPlace = true; + ArrayList<Hop> consumerHops = null; + + + UIPCandidateHop(Hop hop, ProgramBlock pb) + { + this.hop = hop; + this.pb = pb; + } + + Hop getHop() + { + return hop; + } + + ProgramBlock getProgramBlock() + { + return pb; + } + + int getLocation() + { + return this.iLocation; + } + + void setLocation(int iLocation) + { + this.iLocation = iLocation; + } + + boolean isLoopApplicable() + { + return(bIsLoopApplicable); + } + + void setIsLoopApplicable(boolean bInWhileLoop) + { + this.bIsLoopApplicable = bInWhileLoop; + } + boolean isUpdateInPlace() + { + return(bUpdateInPlace); + } + + void setUpdateInPlace(boolean bUpdateInPlace) + { + this.bUpdateInPlace = bUpdateInPlace; + } + + ArrayList<Hop> getConsumerHops() + { + return this.consumerHops; + } + + void setConsumerHops(ArrayList<Hop> consumerHops) + { + this.consumerHops = consumerHops; + } + + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java index 5c1bbd7..bf8d792 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java @@ -22,7 +22,6 @@ package org.apache.sysml.runtime.instructions.cp; import java.io.IOException; import org.apache.commons.lang.StringUtils; - import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.parser.Expression.DataType; @@ -47,6 +46,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.runtime.util.UtilFunctions; +import org.apache.sysml.utils.Statistics; public class VariableCPInstruction extends CPInstruction @@ -97,6 +97,7 @@ public class VariableCPInstruction extends CPInstruction private CPOperand input3; private CPOperand output; private MetaData metadata; + private boolean updateInPlace; // CSV related members (used only in createvar instructions) private FileFormatProperties formatProperties; @@ -183,17 +184,19 @@ public class VariableCPInstruction extends CPInstruction } // This version of the constructor is used only in case of CreateVariable - public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, int _arity, String sopcode, String istr) + public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, boolean updateInPlace, int _arity, String sopcode, String istr) { this(op, in1, in2, in3, (CPOperand)null, _arity, sopcode, istr); metadata = md; + this.updateInPlace = updateInPlace; } // This version of the constructor is used only in case of CreateVariable - public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, int _arity, FileFormatProperties formatProperties, String sopcode, String istr) + public VariableCPInstruction 
(VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, boolean updateInPlace, int _arity, FileFormatProperties formatProperties, String sopcode, String istr) { this(op, in1, in2, in3, (CPOperand)null, _arity, sopcode, istr); metadata = md; + this.updateInPlace = updateInPlace; this.formatProperties = formatProperties; } @@ -284,11 +287,11 @@ public class VariableCPInstruction extends CPInstruction * 13 inputs: createvar corresponding to WRITE -- includes properties hasHeader, delim, and sparse * 14 inputs: createvar corresponding to READ -- includes properties hasHeader, delim, fill, and fillValue */ - if ( parts.length < 13 || parts.length > 15 ) + if ( parts.length < 14 || parts.length > 16 ) throw new DMLRuntimeException("Invalid number of operands in createvar instruction: " + str); } else { - if ( parts.length != 5 && parts.length != 10 ) + if ( parts.length != 5 && parts.length != 11 ) throw new DMLRuntimeException("Invalid number of operands in createvar instruction: " + str); } OutputInfo oi = OutputInfo.stringToOutputInfo(fmt); @@ -309,6 +312,9 @@ public class VariableCPInstruction extends CPInstruction throw new DMLRuntimeException("Invalid number of operands in createvar instruction: " + str); } MatrixFormatMetaData iimd = new MatrixFormatMetaData(mc, oi, ii); + boolean updateInPlace = false; + if ( parts.length >= 11 ) + updateInPlace = Boolean.parseBoolean(parts[10]); if ( fmt.equalsIgnoreCase("csv") ) { /* @@ -317,26 +323,26 @@ public class VariableCPInstruction extends CPInstruction * 14 inputs: createvar corresponding to READ -- includes properties hasHeader, delim, fill, and fillValue */ FileFormatProperties fmtProperties = null; - if ( parts.length == 13 ) { - boolean hasHeader = Boolean.parseBoolean(parts[10]); - String delim = parts[11]; - boolean sparse = Boolean.parseBoolean(parts[12]); + if ( parts.length == 14 ) { + boolean hasHeader = Boolean.parseBoolean(parts[11]); + String delim = parts[12]; + boolean sparse 
= Boolean.parseBoolean(parts[13]); fmtProperties = new CSVFileFormatProperties(hasHeader, delim, sparse) ; } else { - boolean hasHeader = Boolean.parseBoolean(parts[10]); - String delim = parts[11]; - boolean fill = Boolean.parseBoolean(parts[12]); - double fillValue = UtilFunctions.parseToDouble(parts[13]); + boolean hasHeader = Boolean.parseBoolean(parts[11]); + String delim = parts[12]; + boolean fill = Boolean.parseBoolean(parts[13]); + double fillValue = UtilFunctions.parseToDouble(parts[14]); String naStrings = null; - if ( parts.length == 15 ) - naStrings = parts[14]; + if ( parts.length == 16 ) + naStrings = parts[15]; fmtProperties = new CSVFileFormatProperties(hasHeader, delim, fill, fillValue, naStrings) ; } - return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, parts.length, fmtProperties, opcode, str); + return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateInPlace, parts.length, fmtProperties, opcode, str); } else { - return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, parts.length, opcode, str); + return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateInPlace, parts.length, opcode, str); } case AssignVariable: in1 = new CPOperand(parts[1]); @@ -438,8 +444,12 @@ public class VariableCPInstruction extends CPInstruction //is potential for hidden side effects between variables. 
mobj.setMetaData((MetaData)metadata.clone()); mobj.setFileFormatProperties(formatProperties); - + mobj.enableUpdateInPlace(updateInPlace); ec.setVariable(input1.getName(), mobj); + if(updateInPlace) + Statistics.incrementTotUpdateInPlace(); + else + Statistics.incrementTotNonUpdateInPlace(); } else if ( input1.getDataType() == DataType.SCALAR ){ ScalarObject sobj = null; @@ -932,7 +942,7 @@ public class VariableCPInstruction extends CPInstruction return parseInstruction(str); } - public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, String format, MatrixCharacteristics mc, boolean hasHeader, String delim, boolean sparse) throws DMLRuntimeException, DMLUnsupportedOperationException { + public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, String format, MatrixCharacteristics mc, boolean updateInPlace) throws DMLRuntimeException, DMLUnsupportedOperationException { StringBuilder sb = new StringBuilder(); sb.append(getBasicCreateVarString(varName, fileName, fNameOverride, format)); @@ -946,6 +956,30 @@ public class VariableCPInstruction extends CPInstruction sb.append(mc.getColsPerBlock()); sb.append(Lop.OPERAND_DELIMITOR); sb.append(mc.getNonZeros()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(updateInPlace); + + String str = sb.toString(); + + return parseInstruction(str); + } + + public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, String format, MatrixCharacteristics mc, boolean updateInPlace, boolean hasHeader, String delim, boolean sparse) throws DMLRuntimeException, DMLUnsupportedOperationException { + StringBuilder sb = new StringBuilder(); + sb.append(getBasicCreateVarString(varName, fileName, fNameOverride, format)); + + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(mc.getRows()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(mc.getCols()); + 
sb.append(Lop.OPERAND_DELIMITOR); + sb.append(mc.getRowsPerBlock()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(mc.getColsPerBlock()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(mc.getNonZeros()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(updateInPlace); sb.append(Lop.OPERAND_DELIMITOR); sb.append(hasHeader); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/utils/Explain.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java index 16d8c12..1986265 100644 --- a/src/main/java/org/apache/sysml/utils/Explain.java +++ b/src/main/java/org/apache/sysml/utils/Explain.java @@ -28,7 +28,9 @@ import java.util.Map.Entry; import org.apache.sysml.api.DMLException; import org.apache.sysml.hops.FunctionOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.VisitStatus; +import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.HopsException; import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.hops.OptimizerUtils; @@ -682,7 +684,12 @@ public class Explain + hop.getDim2() + "," + hop.getRowsInBlock() + "," + hop.getColsInBlock() + "," - + hop.getNnz() + "]"); + + hop.getNnz()); + + if (hop instanceof DataOp && ((DataOp)hop).getDataOpType() == DataOpTypes.TRANSIENTREAD && hop.getUpdateInPlace()) + sb.append("," + hop.getUpdateInPlace()); + + sb.append("]"); //memory estimates sb.append(" [" + showMem(hop.getInputMemEstimate(), false) + "," http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/main/java/org/apache/sysml/utils/Statistics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java index 4f757d6..499b19c 100644 --- 
a/src/main/java/org/apache/sysml/utils/Statistics.java +++ b/src/main/java/org/apache/sysml/utils/Statistics.java @@ -88,6 +88,9 @@ public class Statistics private static HashMap<String,Long> _cpInstTime = new HashMap<String, Long>(); private static HashMap<String,Long> _cpInstCounts = new HashMap<String, Long>(); + private static AtomicLong lTotUpdateInPlace = new AtomicLong(0); + private static AtomicLong lTotNonUpdateInPlace = new AtomicLong(0); + public static synchronized void setNoOfExecutedMRJobs(int iNoOfExecutedMRJobs) { Statistics.iNoOfExecutedMRJobs = iNoOfExecutedMRJobs; } @@ -144,6 +147,14 @@ public class Statistics iNoOfCompiledSPInst ++; } + public static void incrementTotUpdateInPlace() { + lTotUpdateInPlace.incrementAndGet(); + } + + public static void incrementTotNonUpdateInPlace() { + lTotNonUpdateInPlace.incrementAndGet(); + } + /** * * @param count @@ -539,6 +550,7 @@ public class Statistics sb.append("ParFor optimize time:\t\t" + String.format("%.3f", ((double)getParforOptTime())/1000) + " sec.\n"); sb.append("ParFor initialize time:\t\t" + String.format("%.3f", ((double)getParforInitTime())/1000) + " sec.\n"); sb.append("ParFor result merge time:\t" + String.format("%.3f", ((double)getParforMergeTime())/1000) + " sec.\n"); + sb.append("ParFor total update in-place:\t" + lTotUpdateInPlace + "/" + (lTotUpdateInPlace.get()+lTotNonUpdateInPlace.get()) + "\n"); } sb.append("Total JIT compile time:\t\t" + ((double)getJITCompileTime())/1000 + " sec.\n"); sb.append("Total JVM GC count:\t\t" + getJVMgcCount() + ".\n"); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 92e4acad..076e5d0 100644 --- 
a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -39,7 +39,6 @@ import org.apache.wink.json4j.JSONObject; import org.junit.After; import org.junit.Assert; import org.junit.Before; - import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.MLContext; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/java/org/apache/sysml/test/integration/functions/updateinplace/UpdateInPlaceTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/updateinplace/UpdateInPlaceTest.java b/src/test/java/org/apache/sysml/test/integration/functions/updateinplace/UpdateInPlaceTest.java new file mode 100644 index 0000000..2732820 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/updateinplace/UpdateInPlaceTest.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.updateinplace; + +import java.util.Arrays; +import java.util.List; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizerRuleBased; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class UpdateInPlaceTest extends AutomatedTestBase +{ + + private final static String TEST_DIR = "functions/updateinplace/"; + private final static String TEST_NAME = "updateinplace"; + private final static String TEST_CLASS_DIR = TEST_DIR + UpdateInPlaceTest.class.getSimpleName() + "/"; + + /* Test cases to test following scenarios + * + * Test scenarios Test case + * ------------------------------------------------------------------------------------ + * + * Positive case:: + * =============== + * + * Candidate UIP applicable testUIP + * + * Interleave Operalap:: + * ===================== + * + * Various loop types:: + * -------------------- + * + * Overlap for Consumer within while loop testUIPNAConsUsed + * Overlap for Consumer outside loop testUIPNAConsUsedOutsideDAG + * Overlap for Consumer within loop(not used) testUIPNAConsUsed + * Overlap for Consumer within for loop testUIPNAConsUsedForLoop + * Overlap for Consumer within inner parfor loop testUIPNAParFor + * Overlap for Consumer inside loop testUIPNAConsUsedInsideDAG + * + * Complex Statement:: + * ------------------- + * + * Overlap for Consumer within complex statement testUIPNAComplexConsUsed + * (Consumer in complex statement) + * Overlap for Consumer within complex statement testUIPNAComplexCandUsed + * (Candidate in complex statement) + * + * Else and Predicate case:: + * ------------------------- + * + * Overlap for Consumer within else clause testUIPNAConsUsedElse + * Overlap with consumer in predicate 
testUIPNACandInPredicate + * + * Multiple LIX for same object with interleave:: + * ---------------------------------------------- + * + * Overlap for Consumer with multiple lix testUIPNAMultiLIX + * + * + * Function Calls:: + * ================ + * + * Overlap for candidate used in function call testUIPNACandInFuncCall + * Overlap for consumer used in function call testUIPNAConsInFuncCall + * Function call without consumer/candidate testUIPFuncCall + * + */ + + + + //Note: In order to run these tests against ParFor loop, parfor's DEBUG flag needs to be set in the script. + + @Override + public void setUp() + { + TestUtils.clearAssertionInformation(); + OptimizerUtils.ALLOW_DYN_RECOMPILATION = true; + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, null)); + } + + //public void testUIPOverlapStatement(1) + @Test + public void testUIP() + { + List<String> listUIPRes = Arrays.asList("A"); + + runUpdateInPlaceTest(TEST_NAME, 1, listUIPRes); + } + + @Test + public void testUIPNAConsUsed() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 2, listUIPRes); + } + + @Test + public void testUIPNAConsUsedOutsideDAG() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 3, listUIPRes); + } + + @Test + public void testUIPConsNotUsed() + { + List<String> listUIPRes = Arrays.asList("A"); + + runUpdateInPlaceTest(TEST_NAME, 4, listUIPRes); + } + + @Test + public void testUIPNAConsUsedForLoop() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 5, listUIPRes); + } + + @Test + public void testUIPNAComplexConsUsed() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 6, listUIPRes); + } + + @Test + public void testUIPNAComplexCandUsed() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 7, listUIPRes); + } + + @Test + public void testUIPNAConsUsedElse() + { + List<String> listUIPRes = 
Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 8, listUIPRes); + } + + @Test + public void testUIPNACandInPredicate() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 9, listUIPRes); + } + + @Test + public void testUIPNAMultiLIX() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 10, listUIPRes); + } + + @Test + public void testUIPNAParFor() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 11, listUIPRes); + } + + @Test + public void testUIPNACandInFuncCall() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 12, listUIPRes); + } + + @Test + public void testUIPNAConsInFuncCall() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 13, listUIPRes); + } + + @Test + public void testUIPFuncCall() + { + List<String> listUIPRes = Arrays.asList("A"); + + runUpdateInPlaceTest(TEST_NAME, 14, listUIPRes); + } + + @Test + public void testUIPNAConsUsedInsideDAG() + { + List<String> listUIPRes = Arrays.asList(); + + runUpdateInPlaceTest(TEST_NAME, 15, listUIPRes); + } + + + /** + * + * @param TEST_NAME + * @param iTestNumber + * @param listUIPRes + */ + private void runUpdateInPlaceTest( String TEST_NAME, int iTestNumber, List<String> listUIPExp ) + { + try + { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + + // This is for running the junit test the new way, i.e., construct the arguments directly + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + iTestNumber + ".dml"; + programArgs = new String[]{}; //new String[]{"-args", input("A"), output("B") }; + + runTest(true, false, null, -1); + + List<String> listUIPRes = OptimizerRuleBased.getUIPList(); + int iUIPResCount = 0; + + // If UpdateInPlace list specified in the argument, verify the list. 
+ if (listUIPExp != null) + { + if(listUIPRes != null) + { + for (String strUIPMatName: listUIPExp) + Assert.assertTrue("Expected UpdateInPlace matrix " + strUIPMatName + + " does not exist in the result UpdateInPlace matrix list.", + listUIPRes.contains(strUIPMatName)); + + iUIPResCount = listUIPRes.size(); + } + + Assert.assertTrue("Expected # of UpdateInPlace matrix object/s " + listUIPExp.size() + + " does not match with the # of matrix objects " + iUIPResCount + " from optimization result.", + (iUIPResCount == listUIPExp.size())); + } + else + { + Assert.assertTrue("Expected # of UpdateInPlace matrix object/s " + "0" + + " does not match with the # of matrix objects " + "0" + " from optimization result.", + (listUIPRes == null || listUIPRes.size() == 0)); + } + } + finally{ + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace1.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace1.dml b/src/test/scripts/functions/updateinplace/updateinplace1.dml new file mode 100644 index 0000000..6d19ab2 --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace1.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=1); + i = 1 + while (i <= n){ + if (1 == 1) + print("i = " + i + " j = " + j + " Sum(A) = " + sum(A)); + A[i,1] = j*7+i; + i = i + 1 + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace10.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace10.dml b/src/test/scripts/functions/updateinplace/updateinplace10.dml new file mode 100644 index 0000000..936bbfe --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace10.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=3); + B = matrix(2, rows=n, cols=3); + for (i in 1:n){ + print("i = " + i + " j = " + j + " Sum(B) = " + sum(B)); + A[i,1] = j*2+i; + if(1 == 1) + B = A + A[i,2] = j*3+i; + A[i,3] = j*4+i; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace11.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace11.dml b/src/test/scripts/functions/updateinplace/updateinplace11.dml new file mode 100644 index 0000000..708ba2a --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace11.dml @@ -0,0 +1,41 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + parfor (k in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=1); + B = matrix(2, rows=n, cols=1); + C = matrix(4, rows=n, cols=1); + i = 1 + while (i <= n){ + print("i = " + i + " j = " + j + " Sum(B) = " + sum(B)); + if(i < 5) + C = A + else + B = A + A[i,1] = j*7+i; + i = i + 1 + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace12.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace12.dml b/src/test/scripts/functions/updateinplace/updateinplace12.dml new file mode 100644 index 0000000..3fb9aaf --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace12.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=1); + i = 1 + while (i <= n){ + if (1 == 1) + test = testUIP(A) + A[i,1] = j*7+i; + i = i + 1 + } + + print("Sum from testUIP = " + test); +} + + +testUIP = function (matrix[double] in_m_data) + return (double sumInData) { + if(min(in_m_data) > 20) + X = testUIP(in_m_data); + sumInData = sum(in_m_data); +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace13.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace13.dml b/src/test/scripts/functions/updateinplace/updateinplace13.dml new file mode 100644 index 0000000..137ee51 --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace13.dml @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=1); + B = matrix(3, rows=n, cols=1); + i = 1 + while (i <= n){ + test = testUIP(B) + if (1 == 1) + B = A + A[i,1] = j*7+i; + i = i + 1 + } + + print("Sum from testUIP = " + test); +} + + +testUIP = function (matrix[double] in_m_data) + return (double sumInData) { + if(1 == 1) {} + sumInData = sum(in_m_data); +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f8f423c3/src/test/scripts/functions/updateinplace/updateinplace14.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/updateinplace/updateinplace14.dml b/src/test/scripts/functions/updateinplace/updateinplace14.dml new file mode 100644 index 0000000..58395e1 --- /dev/null +++ b/src/test/scripts/functions/updateinplace/updateinplace14.dml @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +n = 2; +m = 2; + +parfor (j in 1:m, log=DEBUG){ + A = matrix(3, rows=n, cols=1); + B = matrix(2, rows=n, cols=1); + i = 1 + while (i <= n){ + test = testUIP(B) + A[i,1] = j*7+i; + B[i,1] = j*3+i; + i = i + 1 + } + + print("Sum from testUIP = " + test); + print("Sum(B) = " + sum(B)); +} + + +testUIP = function (matrix[double] in_m_data) + return (double sumInData) { + if(1 == 1) {} + sumInData = sum(in_m_data); +}
