[SYSTEMML-753] New loop rewrite/runtime update-in-place left indexing Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/78e161c0 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/78e161c0 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/78e161c0
Branch: refs/heads/master Commit: 78e161c03f740b16ba13585dd4dc65ba4b4243df Parents: 0dfa71d Author: Matthias Boehm <[email protected]> Authored: Sat Jun 4 00:31:34 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jun 4 17:50:25 2016 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/api/DMLScript.java | 6 - .../java/org/apache/sysml/hops/DataGenOp.java | 2 +- src/main/java/org/apache/sysml/hops/DataOp.java | 11 +- src/main/java/org/apache/sysml/hops/Hop.java | 17 ++- .../org/apache/sysml/hops/OptimizerUtils.java | 6 + .../sysml/hops/rewrite/ProgramRewriter.java | 2 + .../RewriteMarkLoopVariablesUpdateInPlace.java | 150 +++++++++++++++++++ .../RewriteSplitDagDataDependentOperators.java | 13 +- .../rewrite/RewriteSplitDagUnknownCSVRead.java | 7 +- src/main/java/org/apache/sysml/lops/Data.java | 4 +- .../org/apache/sysml/lops/OutputParameters.java | 17 ++- .../java/org/apache/sysml/lops/compile/Dag.java | 14 +- .../org/apache/sysml/parser/DMLTranslator.java | 10 +- .../org/apache/sysml/parser/StatementBlock.java | 20 ++- .../runtime/controlprogram/ForProgramBlock.java | 77 +++------- .../controlprogram/FunctionProgramBlock.java | 7 - .../runtime/controlprogram/IfProgramBlock.java | 54 +++---- .../sysml/runtime/controlprogram/Program.java | 8 - .../runtime/controlprogram/ProgramBlock.java | 70 +++++++-- .../controlprogram/WhileProgramBlock.java | 100 ++++++------- .../controlprogram/caching/MatrixObject.java | 23 ++- .../context/ExecutionContext.java | 9 +- .../controlprogram/parfor/ProgramConverter.java | 12 +- .../parfor/opt/OptimizerRuleBased.java | 7 +- .../cp/MatrixIndexingCPInstruction.java | 11 +- .../instructions/cp/VariableCPInstruction.java | 31 ++-- .../spark/AppendGSPInstruction.java | 15 +- .../spark/MatrixIndexingSPInstruction.java | 3 +- .../spark/utils/FrameRDDConverterUtils.java | 3 +- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 3 +- .../runtime/matrix/data/LibMatrixReorg.java | 5 +- .../sysml/runtime/matrix/data/MatrixBlock.java | 33 +++- .../matrix/data/OperationsOnMatrixValues.java | 3 +- .../java/org/apache/sysml/utils/Explain.java | 33 ++-- .../functions/frame/FrameIndexingTest.java | 3 +- .../updateinplace/UpdateInPlaceTest.java | 9 +- 36 files changed, 495 insertions(+), 303 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/api/DMLScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index d952d78..ec71af3 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -624,12 +624,6 @@ public class DMLScript //Step 7: generate runtime program Program rtprog = prog.getRuntimeProgram(dmlconf); - if (LOG.isDebugEnabled()) { - LOG.info("********************** Instructions *******************"); - rtprog.printMe(); - LOG.info("*******************************************************"); - } - //Step 8: [optional global data flow optimization] if(OptimizerUtils.isOptLevel(OptimizationLevel.O4_GLOBAL_TIME_MEMORY) ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/DataGenOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java index 4c81f77..fc1fd2d 100644 --- a/src/main/java/org/apache/sysml/hops/DataGenOp.java +++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java @@ -167,7 +167,7 @@ public class DataGenOp extends Hop implements MultiThreadedHop (getColsInBlock()>0)?getColsInBlock():ConfigurationManager.getBlocksize(), //actual rand nnz might differ (in cp/mr they are corrected after execution) (_op==DataGenMethod.RAND && et==ExecType.SPARK && getNnz()!=0) ? -1 : getNnz(), - getUpdateInPlace()); + getUpdateType()); setLineNumbers(rnd); setLops(rnd); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/DataOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/DataOp.java b/src/main/java/org/apache/sysml/hops/DataOp.java index d83659c..d8e5519 100644 --- a/src/main/java/org/apache/sysml/hops/DataOp.java +++ b/src/main/java/org/apache/sysml/hops/DataOp.java @@ -30,6 +30,7 @@ import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.LopsException; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.util.LocalFileUtils; @@ -82,9 +83,9 @@ public class DataOp extends Hop } public DataOp(String l, DataType dt, ValueType vt, DataOpTypes dop, - String fname, long dim1, long dim2, long nnz, boolean updateInPlace, long rowsPerBlock, long colsPerBlock) { + String fname, long dim1, long dim2, long nnz, UpdateType update, long rowsPerBlock, long colsPerBlock) { this(l, dt, vt, dop, fname, dim1, dim2, nnz, rowsPerBlock, colsPerBlock); - setUpdateInPlace(updateInPlace); + setUpdateType(update); } /** @@ -186,11 +187,11 @@ public class DataOp extends Hop _dataop = type; } - public void setOutputParams(long dim1, long dim2, long nnz, boolean updateInPlace, long rowsPerBlock, long colsPerBlock) { + public void setOutputParams(long dim1, long dim2, long nnz, UpdateType update, long rowsPerBlock, long colsPerBlock) { setDim1(dim1); setDim2(dim2); setNnz(nnz); - setUpdateInPlace(updateInPlace); + setUpdateType(update); setRowsInBlock(rowsPerBlock); setColsInBlock(colsPerBlock); } @@ -238,7 +239,7 @@ public class DataOp extends Hop case PERSISTENTREAD: l = new Data(HopsData2Lops.get(_dataop), null, inputLops, getName(), null, getDataType(), getValueType(), false, getInputFormatType()); - l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inRowsInBlock, _inColsInBlock, getNnz(), getUpdateInPlace()); + l.getOutputParameters().setDimensions(getDim1(), getDim2(), _inRowsInBlock, _inColsInBlock, getNnz(), getUpdateType()); break; case PERSISTENTWRITE: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index e0a06e6..8c4999e 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -38,6 +38,7 @@ import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; @@ -82,7 +83,7 @@ public abstract class Hop protected long _rows_in_block = -1; protected long _cols_in_block = -1; protected long _nnz = -1; - protected boolean _updateInPlace = false; + protected UpdateType _updateType = UpdateType.COPY; protected ArrayList<Hop> _parent = new ArrayList<Hop>(); protected ArrayList<Hop> _input = new ArrayList<Hop>(); @@ -841,12 +842,12 @@ public abstract class Hop return _nnz; } - public void setUpdateInPlace(boolean updateInPlace){ - _updateInPlace = updateInPlace; + public void setUpdateType(UpdateType update){ + _updateType = update; } - public boolean getUpdateInPlace(){ - return _updateInPlace; + public UpdateType getUpdateType(){ + return _updateType; } public abstract Lop constructLops() @@ -968,7 +969,7 @@ public abstract class Hop s.append(h.getHopID() + "; "); } - s.append("\n dims [" + _dim1 + "," + _dim2 + "] blk [" + _rows_in_block + "," + _cols_in_block + "] nnz: " + _nnz + " UpdateInPlace: " + _updateInPlace); + s.append("\n dims [" + _dim1 + "," + _dim2 + "] blk [" + _rows_in_block + "," + _cols_in_block + "] nnz: " + _nnz + " UpdateInPlace: " + _updateType); s.append(" MemEstimate = Out " + (_outputMemEstimate/1024/1024) + " MB, In&Out " + (_memEstimate/1024/1024) + " MB" ); LOG.debug(s.toString()); } @@ -994,7 +995,7 @@ public abstract class Hop throws HopsException { lop.getOutputParameters().setDimensions( - getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz(), getUpdateInPlace()); + getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz(), getUpdateType()); } public Lop getLops() { @@ -1863,7 +1864,7 @@ public abstract class Hop _rows_in_block = that._rows_in_block; _cols_in_block = that._cols_in_block; _nnz = that._nnz; - _updateInPlace = that._updateInPlace; + _updateType = that._updateType; //no copy of lops (regenerated) _parent = new ArrayList<Hop>(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index 48a4a9b..833f5ea 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -170,6 +170,11 @@ public class OptimizerUtils */ public static boolean ALLOW_SPLIT_HOP_DAGS = true; + /** + * Enables a specific rewrite that enables update in place for loop variables that are + * only read/updated via cp leftindexing. + */ + public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true; /** @@ -302,6 +307,7 @@ public class OptimizerUtils ALLOW_INTER_PROCEDURAL_ANALYSIS = false; ALLOW_BRANCH_REMOVAL = false; ALLOW_SUM_PRODUCT_REWRITES = false; + ALLOW_LOOP_UPDATE_IN_PLACE = false; break; // opt level 2: memory-based (all advanced rewrites) case 2: http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 9837a62..49aa0db 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -115,6 +115,8 @@ public class ProgramRewriter if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop) _sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes) + if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE ) + _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); } // DYNAMIC REWRITES (which do require size information) http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java new file mode 100644 index 0000000..ce8c255 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.LeftIndexingOp; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.IfStatement; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.VariableSet; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; +import org.apache.sysml.parser.Expression.DataType; + +/** + * Rule: Mark loop variables that are only read/updated through cp left indexing + * for update in-place. + * + */ +public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewriteRule +{ + @Override + public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) + throws HopsException + { + ArrayList<StatementBlock> ret = new ArrayList<StatementBlock>(); + + if( DMLScript.rtplatform == RUNTIME_PLATFORM.HADOOP + || DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK ) + { + ret.add(sb); // nothing to do here + return ret; //return original statement block + } + + if( sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock ) //incl parfor + { + ArrayList<String> candidates = new ArrayList<String>(); + VariableSet updated = sb.variablesUpdated(); + + for( String varname : updated.getVariableNames() ) { + if( updated.getVariable(varname).getDataType()==DataType.MATRIX) { + if( sb instanceof WhileStatementBlock ) { + WhileStatement wstmt = (WhileStatement) sb.getStatement(0); + if( rIsApplicableForUpdateInPlace(wstmt.getBody(), varname) ) + candidates.add(varname); + } + else if( sb instanceof ForStatementBlock ) { + ForStatement wstmt = (ForStatement) sb.getStatement(0); + if( rIsApplicableForUpdateInPlace(wstmt.getBody(), varname) ) + candidates.add(varname); + } + } + } + + sb.setUpdateInPlaceVars(candidates); + } + + //return modified statement block + ret.add(sb); + return ret; + } + + /** + * + * @param sbs + * @param varname + * @return + * @throws HopsException + */ + private boolean rIsApplicableForUpdateInPlace( ArrayList<StatementBlock> sbs, String varname ) + throws HopsException + { + //NOTE: no function statement blocks / predicates considered because function call would + //render variable as not applicable and predicates don't allow assignments; further reuse + //of loop candidates as child blocks already processed + + //recursive invocation + boolean ret = true; + for( StatementBlock sb : sbs ) { + if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock ) + { + ret &= sb.getUpdateInPlaceVars() + .contains(varname); + } + else if (sb instanceof IfStatementBlock) + { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement)isb.getStatement(0); + ret &= rIsApplicableForUpdateInPlace(istmt.getIfBody(), varname); + if( ret && istmt.getElseBody() != null ) + ret &= rIsApplicableForUpdateInPlace(istmt.getElseBody(), varname); + } + else { + if( sb.get_hops() != null ) + for( Hop hop : sb.get_hops() ) + ret &= isApplicableForUpdateInPlace(hop, varname); + } + + //early abort if not applicable + if( !ret ) break; + } + + return ret; + } + + /** + * + * @param hop + * @param varname + * @return + */ + private boolean isApplicableForUpdateInPlace( Hop hop, String varname ) + { + if( !hop.getName().equals(varname) ) + return true; + + //valid if read/updated by leftindexing + //CP exec type not evaluated here as no lops generated yet + return hop instanceof DataOp + && hop.getInput().get(0) instanceof LeftIndexingOp + && hop.getInput().get(0).getInput().get(0) instanceof DataOp + && hop.getInput().get(0).getInput().get(0).getName().equals(varname) + && hop.getInput().get(0).getInput().get(0).getParent().size()==1; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index 6f9ab63..85e7431 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -42,6 +42,7 @@ import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.matrix.data.Pair; @@ -111,7 +112,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite long rlen = c.getDim1(); long clen = c.getDim2(); long nnz = c.getNnz(); - boolean updateInPlace = c.getUpdateInPlace(); + UpdateType update = c.getUpdateType(); long brlen = c.getRowsInBlock(); long bclen = c.getColsInBlock(); @@ -122,7 +123,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //create new transient read DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, updateInPlace, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); @@ -152,7 +153,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite //create new transient read DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, updateInPlace, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); @@ -173,7 +174,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); twrite.setVisited(VisitStatus.DONE); - twrite.setOutputParams(rlen, clen, nnz, updateInPlace, brlen, bclen); + twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen); HopRewriteUtils.copyLineNumbers(c, twrite); sb1hops.add(twrite); } @@ -386,13 +387,13 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite Hop c = p.getValue(); DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, - null, c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateInPlace(), c.getRowsInBlock(), c.getColsInBlock()); + null, c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); tread.setVisited(VisitStatus.DONE); HopRewriteUtils.copyLineNumbers(c, tread); DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); twrite.setVisited(VisitStatus.DONE); - twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateInPlace(), c.getRowsInBlock(), c.getColsInBlock()); + twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); HopRewriteUtils.copyLineNumbers(c, twrite); //create additional cut by rewriting both hop dags http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index 78d0de7..4ffee5a 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -30,6 +30,7 @@ import org.apache.sysml.hops.HopsException; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; /** * Rule: Split Hop DAG after CSV reads with unknown size. This is @@ -72,13 +73,13 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule long rlen = reblock.getDim1(); long clen = reblock.getDim2(); long nnz = reblock.getNnz(); - boolean updateInPlace = c.getUpdateInPlace(); + UpdateType update = c.getUpdateType(); long brlen = reblock.getRowsInBlock(); long bclen = reblock.getColsInBlock(); //create new transient read DataOp tread = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, updateInPlace, brlen, bclen); + DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); HopRewriteUtils.copyLineNumbers(reblock, tread); //replace reblock with transient read @@ -94,7 +95,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule //add reblock sub dag to first statement block DataOp twrite = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), reblock, DataOpTypes.TRANSIENTWRITE, null); - twrite.setOutputParams(rlen, clen, nnz, updateInPlace, brlen, bclen); + twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen); HopRewriteUtils.copyLineNumbers(reblock, twrite); sb1hops.add(twrite); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/lops/Data.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Data.java b/src/main/java/org/apache/sysml/lops/Data.java index 3b936d3..f21dd03 100644 --- a/src/main/java/org/apache/sysml/lops/Data.java +++ b/src/main/java/org/apache/sysml/lops/Data.java @@ -221,7 +221,7 @@ public class Data extends Lop return getID() + ":" + "File_Name: " + this.getOutputParameters().getFile_name() + " " + "Label: " + this.getOutputParameters().getLabel() + " " + "Operation: = " + operation + " " + "Format: " + this.outParams.getFormat() + " Datatype: " + getDataType() + " Valuetype: " + getValueType() + " num_rows = " + this.getOutputParameters().getNumRows() + " num_cols = " + - this.getOutputParameters().getNumCols() + " UpdateInPlace: " + this.getOutputParameters().getUpdateInPlace(); + this.getOutputParameters().getNumCols() + " UpdateInPlace: " + this.getOutputParameters().getUpdateType(); } /** @@ -558,7 +558,7 @@ public class Data extends Lop sb.append( OPERAND_DELIMITOR ); sb.append( oparams.getNnz() ); sb.append( OPERAND_DELIMITOR ); - sb.append( oparams.getUpdateInPlace() ); + sb.append( oparams.getUpdateType().toString().toLowerCase() ); // Format-specific properties if ( oparams.getFormat() == Format.CSV ) { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/lops/OutputParameters.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/OutputParameters.java b/src/main/java/org/apache/sysml/lops/OutputParameters.java index 974a603..86866c1 100644 --- a/src/main/java/org/apache/sysml/lops/OutputParameters.java +++ b/src/main/java/org/apache/sysml/lops/OutputParameters.java @@ -20,6 +20,7 @@ package org.apache.sysml.lops; import org.apache.sysml.hops.HopsException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; /** * class to maintain output parameters for a lop. @@ -37,7 +38,7 @@ public class OutputParameters private long _num_rows = -1; private long _num_cols = -1; private long _nnz = -1; - private boolean _updateInPlace = false; + private UpdateType _updateType = UpdateType.COPY; private long _num_rows_in_block = -1; private long _num_cols_in_block = -1; private String _file_name = null; @@ -82,8 +83,8 @@ public class OutputParameters } } - public void setDimensions(long rows, long cols, long rows_per_block, long cols_per_block, long nnz, boolean updateInPlace) throws HopsException { - _updateInPlace = updateInPlace; + public void setDimensions(long rows, long cols, long rows_per_block, long cols_per_block, long nnz, UpdateType update) throws HopsException { + _updateType = update; setDimensions(rows, cols, rows_per_block, cols_per_block, nnz); } @@ -133,13 +134,13 @@ public class OutputParameters _nnz = nnz; } - public boolean getUpdateInPlace() { - return _updateInPlace; + public UpdateType getUpdateType() { + return _updateType; } - public void setUpdateInPlace(boolean updateInPlace) + public void setUpdateType(UpdateType update) { - _updateInPlace = updateInPlace; + _updateType = update; } public long getRowsInBlock() { @@ -164,7 +165,7 @@ public class OutputParameters sb.append("rows=" + getNumRows() + Lop.VALUETYPE_PREFIX); sb.append("cols=" + getNumCols() + Lop.VALUETYPE_PREFIX); sb.append("nnz=" + getNnz() + Lop.VALUETYPE_PREFIX); - sb.append("updateInPlace=" + getUpdateInPlace() + Lop.VALUETYPE_PREFIX); + sb.append("updateInPlace=" + getUpdateType().toString() + Lop.VALUETYPE_PREFIX); sb.append("rowsInBlock=" + getRowsInBlock() + Lop.VALUETYPE_PREFIX); sb.append("colsInBlock=" + getColsInBlock() + Lop.VALUETYPE_PREFIX); sb.append("isBlockedRepresentation=" + isBlocked() + Lop.VALUETYPE_PREFIX); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/lops/compile/Dag.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java index 2d24dad..7831751 100644 --- a/src/main/java/org/apache/sysml/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java @@ -2283,7 +2283,7 @@ public class Dag<N extends Lop> // TODO: ideally, this should be done by having a member variable in Lop // which stores the outputInfo. try { - oparams.setDimensions(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz(), oparams.getUpdateInPlace()); + oparams.setDimensions(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz(), oparams.getUpdateType()); } catch(HopsException e) { throw new LopsException(node.printErrorLocation() + "error in getOutputInfo in Dag ", e); } @@ -2398,7 +2398,7 @@ public class Dag<N extends Lop> Instruction createvarInst = VariableCPInstruction.prepareCreateVariableInstruction( oparams.getLabel(), oparams.getFile_name(), true, DataType.MATRIX, OutputInfo.outputInfoToString(OutputInfo.CSVOutputInfo), - new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz()), oparams.getUpdateInPlace(), + new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), -1, -1, oparams.getNnz()), oparams.getUpdateType(), false, delimLop.getStringValue(), true ); @@ -2440,7 +2440,7 @@ public class Dag<N extends Lop> true, node.getDataType(), OutputInfo.outputInfoToString(getOutputInfo(node, false)), new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), - oparams.getUpdateInPlace() + oparams.getUpdateType() ); createvarInst.setLocation(node); @@ -2472,7 +2472,7 @@ public class Dag<N extends Lop> true, fnOut.getDataType(), OutputInfo.outputInfoToString(getOutputInfo(fnOut, false)), new MatrixCharacteristics(fnOutParams.getNumRows(), fnOutParams.getNumCols(), (int)fnOutParams.getRowsInBlock(), (int)fnOutParams.getColsInBlock(), fnOutParams.getNnz()), - oparams.getUpdateInPlace() + oparams.getUpdateType() ); if (node._beginLine != 0) @@ -2588,7 +2588,7 @@ public class Dag<N extends Lop> true, node.getDataType(), OutputInfo.outputInfoToString(out.getOutInfo()), new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), - oparams.getUpdateInPlace() + oparams.getUpdateType() ); createvarInst.setLocation(node); @@ -2698,7 +2698,7 @@ public class Dag<N extends Lop> false, node.getDataType(), OutputInfo.outputInfoToString(getOutputInfo(node, false)), new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), - oparams.getUpdateInPlace() + oparams.getUpdateType() ); //NOTE: no instruction patching because final write from cp instruction @@ -2726,7 +2726,7 @@ public class Dag<N extends Lop> false, node.getDataType(), OutputInfo.outputInfoToString(getOutputInfo(node, false)), new MatrixCharacteristics(oparams.getNumRows(), oparams.getNumCols(), rpb, cpb, oparams.getNnz()), - oparams.getUpdateInPlace() + oparams.getUpdateType() ); // remove the variable CPInstruction currInstr = CPInstructionParser.parseSingleInstruction( http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 61ba4c8..a3b8b52 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -997,7 +997,7 @@ public class DMLTranslator ae.setInputFormatType(Expression.convertFormatType(formatName)); if (ae.getDataType() == DataType.SCALAR ) { - ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), -1, -1); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1); } else { switch(ae.getInputFormatType()) { @@ -1005,12 +1005,12 @@ public class DMLTranslator case MM: case CSV: // write output in textcell format - ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), -1, -1); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1, -1); break; case BINARY: // write output in binary block format - ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); + ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ConfigurationManager.getBlocksize(), ConfigurationManager.getBlocksize()); break; default: @@ -1063,7 +1063,7 @@ public class DMLTranslator Integer statementId = liveOutToTemp.get(target.getName()); if ((statementId != null) && (statementId.intValue() == i)) { DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); - transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateInPlace(), ae.getRowsInBlock(), ae.getColsInBlock()); + transientwrite.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock()); transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndLine()); updatedLiveOut.addVariable(target.getName(), target); output.add(transientwrite); @@ -1094,7 +1094,7 @@ public class DMLTranslator Integer statementId = liveOutToTemp.get(target.getName()); if ((statementId != null) && (statementId.intValue() == i)) { DataOp transientwrite = new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, DataOpTypes.TRANSIENTWRITE, null); - transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateInPlace(), ae.getRowsInBlock(), ae.getColsInBlock()); + transientwrite.setOutputParams(origDim1, origDim2, ae.getNnz(), ae.getUpdateType(), ae.getRowsInBlock(), ae.getColsInBlock()); transientwrite.setAllPositions(target.getBeginLine(), target.getBeginColumn(), target.getEndLine(), target.getEndColumn()); updatedLiveOut.addVariable(target.getName(), target); output.add(transientwrite); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index c67fbc3..bc5ab4e 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -53,9 +53,10 @@ public class StatementBlock extends LiveVariableAnalysis HashMap<String,ConstIdentifier> _constVarsIn; HashMap<String,ConstIdentifier> _constVarsOut; + private ArrayList<String> _updateInPlaceVars = null; private boolean _requiresRecompile = false; - public StatementBlock(){ + public StatementBlock() { _dmlProg = null; _statements = new ArrayList<Statement>(); _read = new VariableSet(); @@ -66,6 +67,7 @@ public class StatementBlock extends LiveVariableAnalysis _initialized = true; _constVarsIn = new HashMap<String,ConstIdentifier>(); _constVarsOut = new HashMap<String,ConstIdentifier>(); + _updateInPlaceVars = new ArrayList<String>(); } public void setDMLProg(DMLProgram dmlProg){ @@ -1048,20 +1050,24 @@ public class StatementBlock extends LiveVariableAnalysis ///////// - // materialized hops recompilation flags + // materialized hops recompilation / updateinplace flags //// - public void updateRecompilationFlag() - throws HopsException - { + public void updateRecompilationFlag() throws HopsException { _requiresRecompile = ConfigurationManager.isDynamicRecompilation() && Recompiler.requiresRecompilation(get_hops()); } - public boolean requiresRecompilation() - { + public boolean requiresRecompilation() { return _requiresRecompile; } + public ArrayList<String> getUpdateInPlaceVars() { + return _updateInPlaceVars; + } + + public void setUpdateInPlaceVars( ArrayList<String> vars ) { + _updateInPlaceVars = vars; + } } // end class http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/ForProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ForProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ForProgramBlock.java index 535f318..ec2f298 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ForProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ForProgramBlock.java @@ -28,6 +28,7 @@ import org.apache.sysml.parser.ForStatementBlock; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.cp.Data; @@ -37,8 +38,7 @@ import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.yarn.DMLAppMasterUtils; public class ForProgramBlock extends ProgramBlock -{ - +{ protected ArrayList<Instruction> _fromInstructions; protected ArrayList<Instruction> _toInstructions; protected ArrayList<Instruction> _incrementInstructions; @@ -47,34 +47,6 @@ public class ForProgramBlock extends ProgramBlock protected ArrayList<ProgramBlock> _childBlocks; protected String[] _iterablePredicateVars; //from,to,where constants/internal vars not captured via instructions - - public void printMe() - { - LOG.debug("***** current for block predicate inst: *****"); - LOG.debug("FROM:"); - for (Instruction cp : _fromInstructions){ - cp.printMe(); - } - LOG.debug("TO:"); - for (Instruction cp : _toInstructions){ - cp.printMe(); - } - LOG.debug("INCREMENT:"); - for (Instruction cp : _incrementInstructions){ - cp.printMe(); - } - - LOG.debug("***** children block inst: *****"); - for (ProgramBlock pb : this._childBlocks){ - pb.printMe(); - } - - LOG.debug("***** current block inst exit: *****"); - for (Instruction i : this._exitInstructions) { - i.printMe(); - } - - } public ForProgramBlock(Program prog, String[] iterPredVars) throws DMLRuntimeException @@ -86,70 +58,59 @@ public class ForProgramBlock extends ProgramBlock _iterablePredicateVars = iterPredVars; } - public ArrayList<Instruction> getFromInstructions() - { + public ArrayList<Instruction> getFromInstructions() { return _fromInstructions; } - public void setFromInstructions(ArrayList<Instruction> instructions) - { + public void setFromInstructions(ArrayList<Instruction> instructions) { _fromInstructions = instructions; } - public ArrayList<Instruction> getToInstructions() - { + public ArrayList<Instruction> getToInstructions() { return _toInstructions; } - public void setToInstructions(ArrayList<Instruction> instructions) - { + public void setToInstructions(ArrayList<Instruction> instructions) { _toInstructions = instructions; } - public ArrayList<Instruction> getIncrementInstructions() - { + public ArrayList<Instruction> getIncrementInstructions() { return _incrementInstructions; } - public void setIncrementInstructions(ArrayList<Instruction> instructions) - { + public void setIncrementInstructions(ArrayList<Instruction> instructions) { _incrementInstructions = instructions; } - public void addExitInstruction(Instruction inst){ + public void addExitInstruction(Instruction inst) { _exitInstructions.add(inst); } - public ArrayList<Instruction> getExitInstructions(){ + public ArrayList<Instruction> getExitInstructions() { return _exitInstructions; } - public void setExitInstructions(ArrayList<Instruction> inst){ + public void setExitInstructions(ArrayList<Instruction> inst) { _exitInstructions = inst; } - public void addProgramBlock(ProgramBlock childBlock) { _childBlocks.add(childBlock); } - public ArrayList<ProgramBlock> getChildBlocks() - { + public ArrayList<ProgramBlock> getChildBlocks() { return _childBlocks; } - public void setChildBlocks(ArrayList<ProgramBlock> pbs) - { + public void setChildBlocks(ArrayList<ProgramBlock> pbs) { _childBlocks = pbs; } - public String[] getIterablePredicateVars() - { + public String[] getIterablePredicateVars() { return _iterablePredicateVars; } - public void setIterablePredicateVars(String[] iterPredVars) - { + public void setIterablePredicateVars(String[] iterPredVars) { _iterablePredicateVars = iterPredVars; } @@ -173,6 +134,9 @@ public class ForProgramBlock extends ProgramBlock // execute for loop try { + // prepare update in-place variables + UpdateType[] flags = prepareUpdateInPlaceVariables(ec); + // run for loop body for each instance of predicate sequence SequenceIterator seqIter = new SequenceIterator(iterVarName, from, to, incr); for( IntObject iterVar : seqIter ) @@ -186,6 +150,9 @@ public class ForProgramBlock extends ProgramBlock _childBlocks.get(i).execute(ec); } } + + // reset update-in-place variables + resetUpdateInPlaceVariableFlags(ec, flags); } catch (DMLScriptException e) { //propagate stop call @@ -275,7 +242,7 @@ public class ForProgramBlock extends ProgramBlock return ret; } - + public String printBlockErrorLocation(){ return "ERROR: Runtime error in for program block generated from for statement block between lines " + _beginLine + " and " + _endLine + " -- "; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java index cf45313..728aeff 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/FunctionProgramBlock.java @@ -153,13 +153,6 @@ public class FunctionProgramBlock extends ProgramBlock return _recompileOnce; } - public void printMe() { - - for (ProgramBlock pb : this._childBlocks){ - pb.printMe(); - } - } - public String printBlockErrorLocation(){ return "ERROR: Runtime error in function program block generated from function statement block between lines " + _beginLine + " and " + _endLine + " -- "; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/IfProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/IfProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/IfProgramBlock.java index 27b285c..0d77364 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/IfProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/IfProgramBlock.java @@ -61,48 +61,30 @@ public class IfProgramBlock extends ProgramBlock _exitInstructions = new ArrayList<Instruction>(); } - public ArrayList<ProgramBlock> getChildBlocksIfBody() - { return _childBlocksIfBody; } + public ArrayList<ProgramBlock> getChildBlocksIfBody() { + return _childBlocksIfBody; + } - public void setChildBlocksIfBody(ArrayList<ProgramBlock> blocks) - { _childBlocksIfBody = blocks; } + public void setChildBlocksIfBody(ArrayList<ProgramBlock> blocks) { + _childBlocksIfBody = blocks; + } - public void addProgramBlockIfBody(ProgramBlock pb) - { _childBlocksIfBody.add(pb); } + public void addProgramBlockIfBody(ProgramBlock pb) { + _childBlocksIfBody.add(pb); + } - public ArrayList<ProgramBlock> getChildBlocksElseBody() - { return _childBlocksElseBody; } + public ArrayList<ProgramBlock> getChildBlocksElseBody() { + return _childBlocksElseBody; + } - public void setChildBlocksElseBody(ArrayList<ProgramBlock> blocks) - { _childBlocksElseBody = blocks; } - - public void addProgramBlockElseBody(ProgramBlock pb) - { _childBlocksElseBody.add(pb); } - - public void printMe() { - - LOG.debug("***** if current block predicate inst: *****"); - for (Instruction cp : _predicate){ - cp.printMe(); - } - - LOG.debug("***** children block inst --- if body : *****"); - for (ProgramBlock pb : this._childBlocksIfBody){ - pb.printMe(); - } + public void setChildBlocksElseBody(ArrayList<ProgramBlock> blocks) { + _childBlocksElseBody = blocks; + } - LOG.debug("***** children block inst --- else body : *****"); - for (ProgramBlock pb: this._childBlocksElseBody){ - pb.printMe(); - } - - LOG.debug("***** current block inst exit: *****"); - for (Instruction i : this._exitInstructions) { - i.printMe(); - } + public void addProgramBlockElseBody(ProgramBlock pb) { + _childBlocksElseBody.add(pb); } - - + public void setExitInstructions2(ArrayList<Instruction> exitInstructions){ _exitInstructions = exitInstructions; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/Program.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/Program.java b/src/main/java/org/apache/sysml/runtime/controlprogram/Program.java index 24ccaef..7a35ec7 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/Program.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/Program.java @@ -154,12 +154,4 @@ public class Program ec.clearDebugProgramCounters(); } - - - public void printMe() { - - for (ProgramBlock pb : this._programBlocks) { - pb.printMe(); - } - } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java index a8d324e..40be5b1 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java @@ -27,12 +27,14 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.cp.BooleanObject; @@ -49,8 +51,7 @@ import org.apache.sysml.yarn.DMLAppMasterUtils; public class ProgramBlock -{ - +{ protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName()); private static final boolean CHECK_MATRIX_SPARSITY = false; @@ -338,6 +339,64 @@ public class ProgramBlock } } + + /** + * + * @param ec + * @return + * @throws DMLRuntimeException + */ + protected UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec) + throws DMLRuntimeException + { + if( _sb == null || _sb.getUpdateInPlaceVars().isEmpty() ) + return null; + + ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); + UpdateType[] flags = new UpdateType[varnames.size()]; + for( int i=0; i<flags.length; i++ ) + if( ec.getVariable(varnames.get(i)) != null ) { + String varname = varnames.get(i); + MatrixObject mo = ec.getMatrixObject(varname); + flags[i] = mo.getUpdateType(); + //create deep copy if required and if it fits in thread-local mem budget + if( flags[i]==UpdateType.COPY && OptimizerUtils.getLocalMemBudget()/2 > + OptimizerUtils.estimateSizeExactSparsity(mo.getMatrixCharacteristics())) { + MatrixObject moNew = new MatrixObject(mo); + MatrixBlock mbVar = mo.acquireRead(); + moNew.acquireModify( !mbVar.isInSparseFormat() ? new MatrixBlock(mbVar) : + new MatrixBlock(mbVar, MatrixBlock.DEFAULT_INPLACE_SPARSEBLOCK, true) ); + mo.release(); + moNew.release(); + moNew.setUpdateType(UpdateType.INPLACE); + ec.setVariable(varname, moNew); + } + } + + return flags; + } + + /** + * + * @param ec + * @param flags + * @throws DMLRuntimeException + */ + protected void resetUpdateInPlaceVariableFlags(ExecutionContext ec, UpdateType[] flags) + throws DMLRuntimeException + { + if( flags == null ) + return; + + //reset update-in-place flag to pre-loop status + ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); + for( int i=0; i<varnames.size(); i++ ) + if( ec.getVariable(varnames.get(i)) != null && flags[i] !=null ) { + MatrixObject mo = ec.getMatrixObject(varnames.get(i)); + mo.setUpdateType(flags[i]); + } + } + /** * * @param inst @@ -348,13 +407,6 @@ public class ProgramBlock return ( inst instanceof VariableCPInstruction && ((VariableCPInstruction)inst).isRemoveVariable() ); } - public void printMe() { - //System.out.println("***** INSTRUCTION BLOCK *****"); - for (Instruction i : this._inst) { - i.printMe(); - } - } - /** * * @param lastInst http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/WhileProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/WhileProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/WhileProgramBlock.java index c84a7c3..de4684f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/WhileProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/WhileProgramBlock.java @@ -27,6 +27,7 @@ import org.apache.sysml.parser.WhileStatementBlock; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.Instruction.INSTRUCTION_TYPE; @@ -57,44 +58,27 @@ public class WhileProgramBlock extends ProgramBlock _childBlocks = new ArrayList<ProgramBlock>(); } - public void printMe() { - - LOG.debug("***** while current block predicate inst: *****"); - for (Instruction cp : _predicate){ - cp.printMe(); - } - - for (ProgramBlock pb : this._childBlocks){ - pb.printMe(); - } - - LOG.debug("***** current block inst exit: *****"); - for (Instruction i : this._exitInstructions) { - i.printMe(); - } - } - - - - public void addProgramBlock(ProgramBlock childBlock) { _childBlocks.add(childBlock); } - public void setExitInstructions2(ArrayList<Instruction> exitInstructions) - { _exitInstructions = exitInstructions; } + public void setExitInstructions2(ArrayList<Instruction> exitInstructions) { + _exitInstructions = exitInstructions; + } - public void setExitInstructions1(ArrayList<Instruction> predicate) - { _predicate = predicate; } + public void setExitInstructions1(ArrayList<Instruction> predicate) { + _predicate = predicate; + } - public void addExitInstruction(Instruction inst) - { _exitInstructions.add(inst); } + public void addExitInstruction(Instruction inst) { + _exitInstructions.add(inst); + } - public ArrayList<Instruction> getPredicate() - { return _predicate; } + public ArrayList<Instruction> getPredicate() { + return _predicate; + } - public void setPredicate( ArrayList<Instruction> predicate ) - { + public void setPredicate( ArrayList<Instruction> predicate ) { _predicate = predicate; //update result var if non-empty predicate (otherwise, @@ -103,14 +87,17 @@ public class WhileProgramBlock extends ProgramBlock _predicateResultVar = findPredicateResultVar(); } - public String getPredicateResultVar() - { return _predicateResultVar; } + public String getPredicateResultVar() { + return _predicateResultVar; + } - public void setPredicateResultVar(String resultVar) - { _predicateResultVar = resultVar; } + public void setPredicateResultVar(String resultVar) { + _predicateResultVar = resultVar; + } - public ArrayList<Instruction> getExitInstructions() - { return _exitInstructions; } + public ArrayList<Instruction> getExitInstructions() { + return _exitInstructions; + } private BooleanObject executePredicate(ExecutionContext ec) throws DMLRuntimeException @@ -157,51 +144,49 @@ public class WhileProgramBlock extends ProgramBlock result = new BooleanObject( scalarResult.getBooleanValue() ); //auto casting } } - catch(Exception ex) - { - LOG.trace("\nWhile predicate variables: "+ ec.getVariables().toString()); - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Failed to evaluate the WHILE predicate.", ex); + catch(Exception ex) { + throw new DMLRuntimeException(this.printBlockErrorLocation() + "Failed to evaluate the while predicate.", ex); } //(guaranteed to be non-null, see executePredicate/getScalarInput) return result; } - public void execute(ExecutionContext ec) throws DMLRuntimeException { - - BooleanObject predResult = executePredicate(ec); - + public void execute(ExecutionContext ec) throws DMLRuntimeException + { //execute while loop try { - while(predResult.getBooleanValue()) + // prepare update in-place variables + UpdateType[] flags = prepareUpdateInPlaceVariables(ec); + + //run loop body until predicate becomes false + while( executePredicate(ec).getBooleanValue() ) { //execute all child blocks for (int i=0 ; i < _childBlocks.size() ; i++) { ec.updateDebugState(i); _childBlocks.get(i).execute(ec); } - - predResult = executePredicate(ec); } + + // reset update-in-place variables + resetUpdateInPlaceVariableFlags(ec, flags); } - catch(DMLScriptException e) - { + catch (DMLScriptException e) { + //propagate stop call throw e; } - catch(Exception e) - { - LOG.trace("\nWhile predicate variables: "+ ec.getVariables().toString()); - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating while program block.", e); + catch (Exception e) { + throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating while program block", e); } //execute exit instructions try { executeInstructions(_exitInstructions, ec); } - catch(Exception e) - { - throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error executing while exit instructions.", e); + catch(Exception e) { + throw new DMLRuntimeException(printBlockErrorLocation() + "Error executing while exit instructions.", e); } } @@ -209,8 +194,7 @@ public class WhileProgramBlock extends ProgramBlock return _childBlocks; } - public void setChildBlocks(ArrayList<ProgramBlock> childs) - { + public void setChildBlocks(ArrayList<ProgramBlock> childs) { _childBlocks = childs; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java index ca6f1c7..90d1f3f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/MatrixObject.java @@ -63,8 +63,17 @@ public class MatrixObject extends CacheableData<MatrixBlock> { private static final long serialVersionUID = 6374712373206495637L; + public enum UpdateType { + COPY, + INPLACE, + INPLACE_PINNED; + public boolean isInPlace() { + return (this != COPY); + } + } + //additional matrix-specific flags - private boolean _updateInPlaceFlag = false; //flag if in-place update + private UpdateType _updateType = UpdateType.COPY; //information relevant to partitioned matrices. private boolean _partitioned = false; //indicates if obj partitioned @@ -105,7 +114,7 @@ public class MatrixObject extends CacheableData<MatrixBlock> _metaData = new MatrixFormatMetaData(new MatrixCharacteristics(metaOld.getMatrixCharacteristics()), metaOld.getOutputInfo(), metaOld.getInputInfo()); - _updateInPlaceFlag = mo._updateInPlaceFlag; + _updateType = mo._updateType; _partitioned = mo._partitioned; _partitionFormat = mo._partitionFormat; _partitionSize = mo._partitionSize; @@ -117,16 +126,16 @@ public class MatrixObject extends CacheableData<MatrixBlock> * * @param flag */ - public void enableUpdateInPlace(boolean flag) { - _updateInPlaceFlag = flag; + public void setUpdateType(UpdateType flag) { + _updateType = flag; } /** * * @return */ - public boolean isUpdateInPlaceEnabled() { - return _updateInPlaceFlag; + public UpdateType getUpdateType() { + return _updateType; } @Override @@ -531,7 +540,7 @@ public class MatrixObject extends CacheableData<MatrixBlock> @Override protected boolean isBelowCachingThreshold() { return super.isBelowCachingThreshold() - || isUpdateInPlaceEnabled(); //pinned result variable + || getUpdateType() == UpdateType.INPLACE_PINNED; } @Override http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index b31541e..3536cd8 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -36,6 +36,7 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.controlprogram.caching.CacheableData; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.instructions.cp.BooleanObject; import org.apache.sysml.runtime.instructions.cp.Data; @@ -315,13 +316,13 @@ public class ExecutionContext * @param inplace * @throws DMLRuntimeException */ - public void setMatrixOutput(String varName, MatrixBlock outputData, boolean inplace) + public void setMatrixOutput(String varName, MatrixBlock outputData, UpdateType flag) throws DMLRuntimeException { - if( inplace ) //modify metadata to prevent output serialization - { + if( flag.isInPlace() ) { + //modify metadata to carry update status MatrixObject sores = (MatrixObject) this.getVariable (varName); - sores.enableUpdateInPlace( true ); + sores.setUpdateType( flag ); } //default case http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java index 4b78589..adf0144 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/ProgramConverter.java @@ -58,6 +58,7 @@ import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PExecMode; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; @@ -170,7 +171,7 @@ public class ProgramConverter //(each worker requires its own copy of the empty matrix object) for( String var : cpec.getVariables().keySet() ) { Data dat = cpec.getVariables().get(var); - if( dat instanceof MatrixObject && ((MatrixObject)dat).isUpdateInPlaceEnabled() ) { + if( dat instanceof MatrixObject && ((MatrixObject)dat).getUpdateType().isInPlace() ) { MatrixObject mo = (MatrixObject)dat; MatrixObject moNew = new MatrixObject(mo); if( mo.getNnz() != 0 ){ @@ -721,6 +722,7 @@ public class ProgramConverter ret.setLiveOut( sb.liveOut() ); ret.setUpdatedVariables( sb.variablesUpdated() ); ret.setReadVariables( sb.variablesRead() ); + ret.setUpdateInPlaceVars( sb.getUpdateInPlaceVars() ); //shallow copy child statements ret.setStatements( sb.getStatements() ); @@ -774,6 +776,7 @@ public class ProgramConverter ret.setLiveOut( sb.liveOut() ); ret.setUpdatedVariables( sb.variablesUpdated() ); ret.setReadVariables( sb.variablesRead() ); + ret.setUpdateInPlaceVars( sb.getUpdateInPlaceVars() ); //shallow copy child statements ret.setStatements( sb.getStatements() ); @@ -1002,7 +1005,6 @@ public class ProgramConverter MatrixCharacteristics mc = md.getMatrixCharacteristics(); value = mo.getFileName(); PDataPartitionFormat partFormat = (mo.getPartitionFormat()!=null) ? mo.getPartitionFormat() : PDataPartitionFormat.NONE; - boolean inplace = mo.isUpdateInPlaceEnabled(); matrixMetaData = new String[9]; matrixMetaData[0] = String.valueOf( mc.getRows() ); matrixMetaData[1] = String.valueOf( mc.getCols() ); @@ -1012,7 +1014,7 @@ public class ProgramConverter matrixMetaData[5] = InputInfo.inputInfoToString( md.getInputInfo() ); matrixMetaData[6] = OutputInfo.outputInfoToString( md.getOutputInfo() ); matrixMetaData[7] = String.valueOf( partFormat ); - matrixMetaData[8] = String.valueOf( inplace ); + matrixMetaData[8] = String.valueOf( mo.getUpdateType() ); break; default: throw new DMLRuntimeException("Unable to serialize datatype "+datatype); @@ -2161,14 +2163,14 @@ public class ProgramConverter InputInfo iin = InputInfo.stringToInputInfo( st.nextToken() ); OutputInfo oin = OutputInfo.stringToOutputInfo( st.nextToken() ); PDataPartitionFormat partFormat = PDataPartitionFormat.valueOf( st.nextToken() ); - boolean inplace = Boolean.parseBoolean( st.nextToken() ); + UpdateType inplace = UpdateType.valueOf( st.nextToken() ); MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, brows, bcols, nnz); MatrixFormatMetaData md = new MatrixFormatMetaData( mc, oin, iin ); mo.setMetaData( md ); mo.setVarName( name ); if( partFormat!=PDataPartitionFormat.NONE ) mo.setPartitioned( partFormat, -1 ); //TODO once we support BLOCKWISE_N we should support it here as well - mo.enableUpdateInPlace(inplace); + mo.setUpdateType(inplace); dat = mo; break; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java index 4ac54ce..2821d8f 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java @@ -84,6 +84,7 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PResultMerge; import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PTaskPartitioner; import org.apache.sysml.runtime.controlprogram.WhileProgramBlock; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; @@ -1919,7 +1920,7 @@ public class OptimizerRuleBased extends Optimizer for( String var : retVars ){ Data dat = vars.get(var); if( dat instanceof MatrixObject ) - ((MatrixObject)dat).enableUpdateInPlace(true); + ((MatrixObject)dat).setUpdateType(UpdateType.INPLACE_PINNED); } inPlaceResultVars.addAll(retVars); @@ -1936,7 +1937,7 @@ public class OptimizerRuleBased extends Optimizer for (UIPCandidateHop uipCandHop: uipCandHopList) if(uipCandHop.isIntermediate() && uipCandHop.isLoopApplicable() && uipCandHop.isUpdateInPlace()) { - uipCandHop.getHop().setUpdateInPlace(true); + uipCandHop.getHop().setUpdateType(UpdateType.INPLACE_PINNED); bAnyUIPApplicable = true; if(LOG.isDebugEnabled()) @@ -1971,7 +1972,7 @@ public class OptimizerRuleBased extends Optimizer if(uipCandHop.getHop() != null) { LOG.trace("Matrix Object: Name: " + uipCandHop.getHop().getName() + "<" + uipCandHop.getHop().getBeginLine() + "," + uipCandHop.getHop().getEndLine()+ ">, InLoop:" - + uipCandHop.isLoopApplicable() + ", UIPApplicable:" + uipCandHop.isUpdateInPlace() + ", HopUIPApplicable:" + uipCandHop.getHop().getUpdateInPlace()); + + uipCandHop.isLoopApplicable() + ", UIPApplicable:" + uipCandHop.isUpdateInPlace() + ", HopUIPApplicable:" + uipCandHop.getHop().getUpdateType()); LOG.trace("Explain Candidate HOP after recompile"); LOG.trace(Explain.explain(uipCandHop.getHop())); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java index efa704c..0d21626 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java @@ -24,6 +24,7 @@ import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.operators.Operator; @@ -77,10 +78,10 @@ public final class MatrixIndexingCPInstruction extends IndexingCPInstruction //left indexing else if ( opcode.equalsIgnoreCase("leftIndex")) { - boolean inplace = mo.isUpdateInPlaceEnabled(); + UpdateType updateType = mo.getUpdateType(); if(DMLScript.STATISTICS) { - if(inplace) + if( updateType.isInPlace() ) Statistics.incrementTotalLixUIP(); Statistics.incrementTotalLix(); } @@ -91,7 +92,7 @@ public final class MatrixIndexingCPInstruction extends IndexingCPInstruction if(input2.getDataType() == DataType.MATRIX) //MATRIX<-MATRIX { MatrixBlock rhsMatBlock = ec.getMatrixInput(input2.getName()); - resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ixrange, new MatrixBlock(), inplace); + resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ixrange, new MatrixBlock(), updateType); ec.releaseMatrixInput(input2.getName()); } else //MATRIX<-SCALAR @@ -100,7 +101,7 @@ public final class MatrixIndexingCPInstruction extends IndexingCPInstruction throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ixrange.toString()+"." ); ScalarObject scalar = ec.getScalarInput(input2.getName(), ValueType.DOUBLE, input2.isLiteral()); resultBlock = (MatrixBlock) matBlock.leftIndexingOperations(scalar, - (int)ixrange.rowStart, (int)ixrange.colStart, new MatrixBlock(), inplace); + (int)ixrange.rowStart, (int)ixrange.colStart, new MatrixBlock(), updateType); } //unpin lhs input @@ -111,7 +112,7 @@ public final class MatrixIndexingCPInstruction extends IndexingCPInstruction resultBlock.examSparsity(); //unpin output - ec.setMatrixOutput(output.getName(), resultBlock, inplace); + ec.setMatrixOutput(output.getName(), resultBlock, updateType); } else throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingCPInstruction."); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java index 8b71307..0fc2706 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/VariableCPInstruction.java @@ -30,6 +30,7 @@ import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; @@ -101,7 +102,7 @@ public class VariableCPInstruction extends CPInstruction private CPOperand input3; private CPOperand output; private MetaData metadata; - private boolean _updateInPlace; + private UpdateType _updateType; // Frame related members private String _schema; @@ -195,20 +196,20 @@ public class VariableCPInstruction extends CPInstruction } // This version of the constructor is used only in case of CreateVariable - public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, boolean updateInPlace, int _arity, String schema, String sopcode, String istr) + public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, UpdateType updateType, int _arity, String schema, String sopcode, String istr) { this(op, in1, in2, in3, (CPOperand)null, _arity, sopcode, istr); metadata = md; - _updateInPlace = updateInPlace; + _updateType = updateType; _schema = schema; } // This version of the constructor is used only in case of CreateVariable - public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, boolean updateInPlace, int _arity, FileFormatProperties formatProperties, String schema, String sopcode, String istr) + public VariableCPInstruction (VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, MetaData md, UpdateType updateType, int _arity, FileFormatProperties formatProperties, String schema, String sopcode, String istr) { this(op, in1, in2, in3, (CPOperand)null, _arity, sopcode, istr); metadata = md; - _updateInPlace = updateInPlace; + _updateType = updateType; _formatProperties = formatProperties; _schema = schema; } @@ -325,9 +326,9 @@ public class VariableCPInstruction extends CPInstruction throw new DMLRuntimeException("Invalid number of operands in createvar instruction: " + str); } MatrixFormatMetaData iimd = new MatrixFormatMetaData(mc, oi, ii); - boolean updateInPlace = false; + UpdateType updateType = UpdateType.COPY; if ( parts.length >= 12 ) - updateInPlace = Boolean.parseBoolean(parts[11]); + updateType = UpdateType.valueOf(parts[11].toUpperCase()); //handle frame schema String schema = (dt==DataType.FRAME && parts.length>=13) ? parts[parts.length-1] : null; @@ -353,10 +354,10 @@ public class VariableCPInstruction extends CPInstruction naStrings = parts[16]; fmtProperties = new CSVFileFormatProperties(hasHeader, delim, fill, fillValue, naStrings) ; } - return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateInPlace, parts.length, fmtProperties, schema, opcode, str); + return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateType, parts.length, fmtProperties, schema, opcode, str); } else { - return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateInPlace, parts.length, schema, opcode, str); + return new VariableCPInstruction(VariableOperationCode.CreateVariable, in1, in2, in3, iimd, updateType, parts.length, schema, opcode, str); } case AssignVariable: in1 = new CPOperand(parts[1]); @@ -457,9 +458,9 @@ public class VariableCPInstruction extends CPInstruction //is potential for hidden side effects between variables. mobj.setMetaData((MetaData)metadata.clone()); mobj.setFileFormatProperties(_formatProperties); - mobj.enableUpdateInPlace(_updateInPlace); + mobj.setUpdateType(_updateType); ec.setVariable(input1.getName(), mobj); - if(DMLScript.STATISTICS && _updateInPlace) + if(DMLScript.STATISTICS && _updateType.isInPlace()) Statistics.incrementTotalUIPVar(); } else if( input1.getDataType() == DataType.FRAME ) { @@ -990,7 +991,7 @@ public class VariableCPInstruction extends CPInstruction return parseInstruction(str); } - public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, DataType dt, String format, MatrixCharacteristics mc, boolean updateInPlace) throws DMLRuntimeException { + public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, DataType dt, String format, MatrixCharacteristics mc, UpdateType update) throws DMLRuntimeException { StringBuilder sb = new StringBuilder(); sb.append(getBasicCreateVarString(varName, fileName, fNameOverride, dt, format)); @@ -1005,14 +1006,14 @@ public class VariableCPInstruction extends CPInstruction sb.append(Lop.OPERAND_DELIMITOR); sb.append(mc.getNonZeros()); sb.append(Lop.OPERAND_DELIMITOR); - sb.append(updateInPlace); + sb.append(update.toString().toLowerCase()); String str = sb.toString(); return parseInstruction(str); } - public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, DataType dt, String format, MatrixCharacteristics mc, boolean updateInPlace, boolean hasHeader, String delim, boolean sparse) throws DMLRuntimeException { + public static Instruction prepareCreateVariableInstruction(String varName, String fileName, boolean fNameOverride, DataType dt, String format, MatrixCharacteristics mc, UpdateType update, boolean hasHeader, String delim, boolean sparse) throws DMLRuntimeException { StringBuilder sb = new StringBuilder(); sb.append(getBasicCreateVarString(varName, fileName, fNameOverride, dt, format)); @@ -1027,7 +1028,7 @@ public class VariableCPInstruction extends CPInstruction sb.append(Lop.OPERAND_DELIMITOR); sb.append(mc.getNonZeros()); sb.append(Lop.OPERAND_DELIMITOR); - sb.append(updateInPlace); + sb.append(update.toString().toLowerCase()); sb.append(Lop.OPERAND_DELIMITOR); sb.append(hasHeader); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java index 4badf24..3da1a79 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java @@ -29,6 +29,7 @@ import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.functionobjects.OffsetColumnIndex; @@ -147,7 +148,7 @@ public class AppendGSPInstruction extends BinarySPInstruction { // This case occurs for last block of LHS matrix MatrixBlock tmp = new MatrixBlock(secondBlk.getNumRows(), secondBlk.getNumColumns(), true); - firstBlk = tmp.leftIndexingOperations(firstBlk, 0, firstBlk.getNumRows()-1, 0, firstBlk.getNumColumns()-1, new MatrixBlock(), true); + firstBlk = tmp.leftIndexingOperations(firstBlk, 0, firstBlk.getNumRows()-1, 0, firstBlk.getNumColumns()-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); } //merge with sort since blocks might be in any order @@ -199,19 +200,19 @@ public class AppendGSPInstruction extends BinarySPInstruction if(cutAt >= in.getNumColumns()) { // The block is too small to be cut MatrixBlock firstBlk = new MatrixBlock(in.getNumRows(), lblen1, true); - firstBlk = firstBlk.leftIndexingOperations(in, 0, in.getNumRows()-1, lblen1-in.getNumColumns(), lblen1-1, new MatrixBlock(), true); + firstBlk = firstBlk.leftIndexingOperations(in, 0, in.getNumRows()-1, lblen1-in.getNumColumns(), lblen1-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(firstIndex, firstBlk)); } else { // Since merge requires the dimensions matching, shifting = slicing + left indexing MatrixBlock firstSlicedBlk = in.sliceOperations(0, in.getNumRows()-1, 0, cutAt-1, new MatrixBlock()); MatrixBlock firstBlk = new MatrixBlock(in.getNumRows(), lblen1, true); - firstBlk = firstBlk.leftIndexingOperations(firstSlicedBlk, 0, in.getNumRows()-1, _shiftBy, _blen-1, new MatrixBlock(), true); + firstBlk = firstBlk.leftIndexingOperations(firstSlicedBlk, 0, in.getNumRows()-1, _shiftBy, _blen-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); MatrixBlock secondSlicedBlk = in.sliceOperations(0, in.getNumRows()-1, cutAt, in.getNumColumns()-1, new MatrixBlock()); int llen2 = UtilFunctions.computeBlockSize(_outlen, secondIndex.getColumnIndex(), _blen); MatrixBlock secondBlk = new MatrixBlock(in.getNumRows(), llen2, true); - secondBlk = secondBlk.leftIndexingOperations(secondSlicedBlk, 0, in.getNumRows()-1, 0, secondSlicedBlk.getNumColumns()-1, new MatrixBlock(), true); + secondBlk = secondBlk.leftIndexingOperations(secondSlicedBlk, 0, in.getNumRows()-1, 0, secondSlicedBlk.getNumColumns()-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(firstIndex, firstBlk)); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(secondIndex, secondBlk)); @@ -226,19 +227,19 @@ public class AppendGSPInstruction extends BinarySPInstruction if(cutAt >= in.getNumRows()) { // The block is too small to be cut MatrixBlock firstBlk = new MatrixBlock(lblen1, in.getNumColumns(), true); - firstBlk = firstBlk.leftIndexingOperations(in, lblen1-in.getNumRows(), lblen1-1, 0, in.getNumColumns()-1, new MatrixBlock(), true); + firstBlk = firstBlk.leftIndexingOperations(in, lblen1-in.getNumRows(), lblen1-1, 0, in.getNumColumns()-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(firstIndex, firstBlk)); } else { // Since merge requires the dimensions matching, shifting = slicing + left indexing MatrixBlock firstSlicedBlk = in.sliceOperations(0, cutAt-1, 0, in.getNumColumns()-1, new MatrixBlock()); MatrixBlock firstBlk = new MatrixBlock(lblen1, in.getNumColumns(), true); - firstBlk = firstBlk.leftIndexingOperations(firstSlicedBlk, _shiftBy, _blen-1, 0, in.getNumColumns()-1, new MatrixBlock(), true); + firstBlk = firstBlk.leftIndexingOperations(firstSlicedBlk, _shiftBy, _blen-1, 0, in.getNumColumns()-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); MatrixBlock secondSlicedBlk = in.sliceOperations(cutAt, in.getNumRows()-1, 0, in.getNumColumns()-1, new MatrixBlock()); int lblen2 = UtilFunctions.computeBlockSize(_outlen, secondIndex.getRowIndex(), _blen); MatrixBlock secondBlk = new MatrixBlock(lblen2, in.getNumColumns(), true); - secondBlk = secondBlk.leftIndexingOperations(secondSlicedBlk, 0, secondSlicedBlk.getNumRows()-1, 0, in.getNumColumns()-1, new MatrixBlock(), true); + secondBlk = secondBlk.leftIndexingOperations(secondSlicedBlk, 0, secondSlicedBlk.getNumRows()-1, 0, in.getNumColumns()-1, new MatrixBlock(), UpdateType.INPLACE_PINNED); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(firstIndex, firstBlk)); retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(secondIndex, secondBlk)); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixIndexingSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixIndexingSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixIndexingSPInstruction.java index 9aa553a..4248999 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixIndexingSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixIndexingSPInstruction.java @@ -30,6 +30,7 @@ import scala.Tuple2; import org.apache.sysml.hops.AggBinaryOp.SparkAggType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.InstructionUtils; @@ -390,7 +391,7 @@ public class MatrixIndexingSPInstruction extends UnarySPInstruction int lhs_lru = UtilFunctions.computeCellInBlock(lhs_ru, _brlen); int lhs_lcl = UtilFunctions.computeCellInBlock(lhs_cl, _bclen); int lhs_lcu = UtilFunctions.computeCellInBlock(lhs_cu, _bclen); - MatrixBlock ret = arg._2.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl, lhs_lru, lhs_lcl, lhs_lcu, new MatrixBlock(), false); + MatrixBlock ret = arg._2.leftIndexingOperations(slicedRHSMatBlock, lhs_lrl, lhs_lru, lhs_lcl, lhs_lcu, new MatrixBlock(), UpdateType.COPY); return new Tuple2<MatrixIndexes, MatrixBlock>(arg._1, ret); } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/78e161c0/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java index 5f319b9..3805164 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/FrameRDDConverterUtils.java @@ -39,6 +39,7 @@ import scala.Tuple2; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysml.runtime.instructions.spark.data.SerLongWritable; import org.apache.sysml.runtime.instructions.spark.data.SerText; import org.apache.sysml.runtime.instructions.spark.functions.ConvertFrameBlockToIJVLines; @@ -769,7 +770,7 @@ public class FrameRDDConverterUtils MatrixIndexes matrixIndexes = new MatrixIndexes(UtilFunctions.computeBlockIndex(begRow+1, _brlenMatrix),UtilFunctions.computeBlockIndex(lColId+1, _bclenMatrix)); MatrixBlock matrixBlocktmp = DataConverter.convertToMatrixBlock(tmpFrame); - MatrixBlock matrixBlock = matrixBlocktmp.leftIndexingOperations(matrixBlocktmp, (int)begRowMat, (int)endRowMat, (int)begColMat, (int)endColMat, new MatrixBlock(), true); + MatrixBlock matrixBlock = matrixBlocktmp.leftIndexingOperations(matrixBlocktmp, (int)begRowMat, (int)endRowMat, (int)begColMat, (int)endColMat, new MatrixBlock(), UpdateType.INPLACE_PINNED); ret.add(new Tuple2<MatrixIndexes, MatrixBlock>(matrixIndexes, matrixBlock)); lColId = endCol+1;
