Repository: systemml Updated Branches: refs/heads/master afbedf3bf -> 1e1210b9e
[SYSTEMML-2277,224,423] New rewrite for hoisting loop-invariant ops This patch introduces a new optional (still disabled) rewrite for code motion, i.e., hoisting loop-invariant operations from while, for, or parfor loops. These loop-invariant operations are defined as reads of variables used read-only in the loop, and operations that have only loop-invariant inputs (modulo some special cases such as rand without seed). Furthermore, this also includes a cleanup of various rewrites that deal with the creation of transient reads and writes. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1e1210b9 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1e1210b9 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1e1210b9 Branch: refs/heads/master Commit: 1e1210b9ebdb68e76ad20ee08f132ef32483f829 Parents: afbedf3 Author: Matthias Boehm <[email protected]> Authored: Mon Apr 23 19:13:12 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Apr 23 19:13:12 2018 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/OptimizerUtils.java | 6 + .../sysml/hops/rewrite/HopRewriteUtils.java | 39 +++- .../sysml/hops/rewrite/ProgramRewriter.java | 2 + .../RewriteHoistLoopInvariantOperations.java | 209 +++++++++++++++++++ .../RewriteInjectSparkLoopCheckpointing.java | 5 +- .../RewriteSplitDagDataDependentOperators.java | 41 +--- .../rewrite/RewriteSplitDagUnknownCSVRead.java | 20 +- .../hops/rewrite/StatementBlockRewriteRule.java | 16 +- .../org/apache/sysml/parser/StatementBlock.java | 4 +- .../RewriteHoistingLoopInvariantOpsTest.java | 127 +++++++++++ .../functions/misc/RewriteCodeMotionFor.R | 37 ++++ .../functions/misc/RewriteCodeMotionFor.dml | 31 +++ .../functions/misc/RewriteCodeMotionWhile.R | 39 ++++ .../functions/misc/RewriteCodeMotionWhile.dml | 33 +++ .../functions/misc/ZPackageSuite.java | 1 + 15 files changed, 541 
insertions(+), 69 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index 2d76759..e9af001 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -182,6 +182,12 @@ public class OptimizerUtils */ public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true; + /** + * Enables a specific rewrite for code motion, i.e., hoisting loop invariant code + * out of while, for, and parfor loops. + */ + public static boolean ALLOW_CODE_MOTION = false; + /** * Specifies a multiplier computing the degree of parallelism of parallel http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index c6c42ae..8abe90b 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -499,7 +499,31 @@ public class HopRewriteUtils public static Hop getDataGenOpConstantValue(Hop hop) { return ((DataGenOp) hop).getConstantValue(); - } + } + + public static DataOp createTransientRead(String name, Hop h) { + //note: different constructor necessary for formattype + DataOp tread = new DataOp(name, h.getDataType(), h.getValueType(), + DataOpTypes.TRANSIENTREAD, null, h.getDim1(), h.getDim2(), h.getNnz(), + h.getUpdateType(), h.getRowsInBlock(), h.getColsInBlock()); + tread.setVisited(); + copyLineNumbers(h, tread); + 
return tread; + } + + public static DataOp createTransientWrite(String name, Hop in) { + return createDataOp(name, in, DataOpTypes.TRANSIENTWRITE); + } + + public static DataOp createDataOp(String name, Hop in, DataOpTypes type) { + DataOp dop = new DataOp(name, in.getDataType(), + in.getValueType(), in, type, null); + dop.setVisited(); + dop.setOutputParams(in.getDim1(), in.getDim2(), in.getNnz(), + in.getUpdateType(), in.getRowsInBlock(), in.getColsInBlock()); + copyLineNumbers(in, dop); + return dop; + } public static ReorgOp createTranspose(Hop input) { return createReorg(input, ReOrgOp.TRANS); @@ -684,14 +708,6 @@ public class HopRewriteUtils return ternOp; } - public static DataOp createDataOp(String name, Hop input, DataOpTypes type) { - DataOp dop = new DataOp(name, input.getDataType(), input.getValueType(), input, type, null); - dop.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock()); - copyLineNumbers(input, dop); - dop.refreshSizeInformation(); - return dop; - } - public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) { hop.setDim1( rlen ); hop.setDim2( clen ); @@ -1295,6 +1311,11 @@ public class HopRewriteUtils || sb instanceof ForStatementBlock); //incl parfor } + public static boolean isLoopStatementBlock(StatementBlock sb) { + return sb instanceof WhileStatementBlock + || sb instanceof ForStatementBlock; //incl parfor + } + public static long getMaxNrowInput(Hop hop) { return getMaxInputDim(hop, true); } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index eb7d23c..2963e9d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -112,6 +112,8 @@ public class ProgramRewriter if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _sbRuleSet.add( new RewriteForLoopVectorization() ); //dependency: reblock (reblockop) _sbRuleSet.add( new RewriteInjectSparkLoopCheckpointing(true) ); //dependency: reblock (blocksizes) + if( OptimizerUtils.ALLOW_CODE_MOTION ) + _sbRuleSet.add( new RewriteHoistLoopInvariantOperations() ); //dependency: vectorize, but before inplace if( OptimizerUtils.ALLOW_LOOP_UPDATE_IN_PLACE ) _sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() ); } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java new file mode 100644 index 0000000..3e77486 --- /dev/null +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteHoistLoopInvariantOperations.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.hops.rewrite; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.sysml.hops.DataOp; +import org.apache.sysml.hops.FunctionOp; +import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.DataGenMethod; +import org.apache.sysml.hops.Hop.DataOpTypes; +import org.apache.sysml.hops.LiteralOp; +import org.apache.sysml.hops.recompile.Recompiler; +import org.apache.sysml.parser.DataIdentifier; +import org.apache.sysml.parser.ForStatement; +import org.apache.sysml.parser.ForStatementBlock; +import org.apache.sysml.parser.IfStatementBlock; +import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.parser.VariableSet; +import org.apache.sysml.parser.WhileStatement; +import org.apache.sysml.parser.WhileStatementBlock; + +/** + * Rule: Simplify program structure by hoisting loop-invariant operations + * out of while, for, or parfor loops. 
+ */ +public class RewriteHoistLoopInvariantOperations extends StatementBlockRewriteRule +{ + private final boolean _sideEffectFreeFuns; + + public RewriteHoistLoopInvariantOperations() { + this(false); + } + + public RewriteHoistLoopInvariantOperations(boolean noSideEffects) { + _sideEffectFreeFuns = noSideEffects; + } + + @Override + public boolean createsSplitDag() { + return true; + } + + @Override + public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) { + //early abort if possible + if( sb == null || !HopRewriteUtils.isLoopStatementBlock(sb) ) + return Arrays.asList(sb); //rewrite only applies to loops + + //step 1: determine read-only variables + Set<String> candInputs = sb.variablesRead().getVariableNames().stream() + .filter(v -> !sb.variablesUpdated().containsVariable(v)) + .collect(Collectors.toSet()); + + //step 2: collect loop-invariant operations along with their tmp names + Map<String, Hop> invariantOps = new HashMap<>(); + collectOperations(sb, candInputs, invariantOps); + + //step 3: create new statement block for all temporary intermediates + return invariantOps.isEmpty() ? 
Arrays.asList(sb) : + Arrays.asList(createStatementBlock(sb, invariantOps), sb); + } + + @Override + public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) { + return sbs; + } + + private void collectOperations(StatementBlock sb, Set<String> candInputs, Map<String, Hop> invariantOps) { + + if( sb instanceof WhileStatementBlock ) { + WhileStatement wstmt = (WhileStatement) sb.getStatement(0); + for( StatementBlock csb : wstmt.getBody() ) + collectOperations(csb, candInputs, invariantOps); + } + else if( sb instanceof ForStatementBlock ) { + ForStatement fstmt = (ForStatement) sb.getStatement(0); + for( StatementBlock csb : fstmt.getBody() ) + collectOperations(csb, candInputs, invariantOps); + } + else if( sb instanceof IfStatementBlock ) { + //note: for now we do not pull loop-invariant code out of + //if statement blocks because these operations are conditionally + //executed, so unconditional execution might be counter productive + } + else if( sb.getHops() != null ) { + //step a: bottom-up flagging of loop-invariant operations + //(these are defined operations whose inputs are read only + //variables or other loop-invariant operations) + Hop.resetVisitStatus(sb.getHops()); + HashSet<Long> memo = new HashSet<>(); + for( Hop hop : sb.getHops() ) + rTagLoopInvariantOperations(hop, candInputs, memo); + + //step b: copy hop sub dag and replace it via tread + Hop.resetVisitStatus(sb.getHops()); + for( Hop hop : sb.getHops() ) + rCollectAndReplaceOperations(hop, candInputs, memo, invariantOps); + + if( !memo.isEmpty() ) { + LOG.debug("Applied hoistLoopInvariantOperations (lines " + +sb.getBeginLine()+"-"+sb.getEndLine()+"): "+memo.size()+"."); + } + } + } + + private void rTagLoopInvariantOperations(Hop hop, Set<String> candInputs, Set<Long> memo) { + if( hop.isVisited() ) + return; + + //process inputs first (depth first) + for( Hop c : hop.getInput() ) + rTagLoopInvariantOperations(c, candInputs, memo); + + //flag 
operation if all inputs are loop invariant + boolean invariant = !HopRewriteUtils.isDataGenOp(hop, DataGenMethod.RAND) + && (!(hop instanceof FunctionOp) || _sideEffectFreeFuns) + && !HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD) + && !HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTWRITE); + for( Hop c : hop.getInput() ) { + invariant &= (candInputs.contains(c.getName()) + || memo.contains(c.getHopID()) || c instanceof LiteralOp); + } + if( invariant ) + memo.add(hop.getHopID()); + + hop.setVisited(); + } + + private void rCollectAndReplaceOperations(Hop hop, Set<String> candInputs, Set<Long> memo, Map<String, Hop> invariantOps) { + if( hop.isVisited() ) + return; + + //replace amenable inputs or process recursively + //(without iterators due to parent-child modifications) + for( int i=0; i<hop.getInput().size(); i++ ) { + Hop c = hop.getInput().get(i); + if( memo.contains(c.getHopID()) ) { + String tmpName = createCutVarName(false); + Hop tmp = Recompiler.deepCopyHopsDag(c); + tmp.getParent().clear(); + invariantOps.put(tmpName, tmp); + + //create read and replace all parent references + DataOp tread = HopRewriteUtils.createTransientRead(tmpName, c); + List<Hop> parents = new ArrayList<>(c.getParent()); + for( Hop p : parents ) + HopRewriteUtils.replaceChildReference(p, c, tread); + } + else { + rCollectAndReplaceOperations(c, candInputs, memo, invariantOps); + } + } + + hop.setVisited(); + } + + private StatementBlock createStatementBlock(StatementBlock sb, Map<String, Hop> invariantOps) { + //create empty last-level statement block + StatementBlock ret = new StatementBlock(); + ret.setDMLProg(sb.getDMLProg()); + ret.setParseInfo(sb); + ret.setLiveIn(new VariableSet(sb.liveIn())); + ret.setLiveOut(new VariableSet(sb.liveIn())); + + //append hops with custom + ArrayList<Hop> hops = new ArrayList<>(); + for( Entry<String, Hop> e : invariantOps.entrySet() ) { + Hop h = e.getValue(); + DataOp twrite = HopRewriteUtils.createTransientWrite(e.getKey(), h); + 
hops.add(twrite); + //update live variable analysis + DataIdentifier diVar = new DataIdentifier(e.getKey()); + diVar.setDimensions(h.getDim1(), h.getDim2()); + diVar.setBlockDimensions(h.getRowsInBlock(), h.getColsInBlock()); + diVar.setDataType(h.getDataType()); + diVar.setValueType(h.getValueType()); + ret.liveOut().addVariable(e.getKey(), diVar); + sb.liveIn().addVariable(e.getKey(), diVar); + } + ret.setHops(hops); + return ret; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java index 6c3ad76..853a02d 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteInjectSparkLoopCheckpointing.java @@ -102,10 +102,9 @@ public class RewriteInjectSparkLoopCheckpointing extends StatementBlockRewriteRu long dim1 = (dat instanceof IndexedIdentifier) ? ((IndexedIdentifier)dat).getOrigDim1() : dat.getDim1(); long dim2 = (dat instanceof IndexedIdentifier) ? 
((IndexedIdentifier)dat).getOrigDim2() : dat.getDim2(); DataOp tread = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, DataOpTypes.TRANSIENTREAD, - dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize); + dat.getFilename(), dim1, dim2, dat.getNnz(), blocksize, blocksize); tread.setRequiresCheckpoint(true); - DataOp twrite = new DataOp(var, DataType.MATRIX, ValueType.DOUBLE, tread, DataOpTypes.TRANSIENTWRITE, null); - HopRewriteUtils.setOutputParameters(twrite, dim1, dim2, blocksize, blocksize, dat.getNnz()); + DataOp twrite = HopRewriteUtils.createTransientWrite(var, tread); hops.add(twrite); livein.addVariable(var, read.getVariable(var)); liveout.addVariable(var, read.getVariable(var)); http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java index a758ee0..afbf483 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagDataDependentOperators.java @@ -45,8 +45,6 @@ import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; -import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysml.runtime.matrix.data.Pair; /** @@ -68,10 +66,6 @@ import org.apache.sysml.runtime.matrix.data.Pair; */ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewriteRule { - private static final String SB_CUT_PREFIX = "_sbcvar"; - private static final String 
FUN_CUT_PREFIX = "_funvar"; - private static IDSequence _seq = new IDSequence(); - @Override public boolean createsSplitDag() { return true; @@ -123,8 +117,6 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite String varname = null; long rlen = c.getDim1(); long clen = c.getDim2(); - long nnz = c.getNnz(); - UpdateType update = c.getUpdateType(); int brlen = c.getRowsInBlock(); int bclen = c.getColsInBlock(); @@ -134,10 +126,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite varname = twrite.getName(); //create new transient read - DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); - tread.setVisited(); - HopRewriteUtils.copyLineNumbers(c, tread); + DataOp tread = HopRewriteUtils.createTransientRead(varname, c); //replace data-dependent operator with transient read ArrayList<Hop> parents = new ArrayList<>(c.getParent()); @@ -160,10 +149,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite varname = createCutVarName(false); //create new transient read - DataOp tread = new DataOp(varname, c.getDataType(), c.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); - tread.setVisited(); - HopRewriteUtils.copyLineNumbers(c, tread); + DataOp tread = HopRewriteUtils.createTransientRead(varname, c); //replace data-dependent operator with transient read ArrayList<Hop> parents = new ArrayList<>(c.getParent()); @@ -175,11 +161,7 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite } //add data-dependent operator sub dag to first statement block - DataOp twrite = new DataOp(varname, c.getDataType(), - c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); - twrite.setVisited(); - twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen); - HopRewriteUtils.copyLineNumbers(c, twrite); + DataOp twrite = 
HopRewriteUtils.createTransientWrite(varname, c); sb1hops.add(twrite); } @@ -364,16 +346,10 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite if( tread == null ) { String varname = createCutVarName(false); - tread = new DataOp(varname, c.getDataType(), c.getValueType(), DataOpTypes.TRANSIENTREAD, null, - c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); - tread.setVisited(); - HopRewriteUtils.copyLineNumbers(c, tread); + tread = HopRewriteUtils.createTransientRead(varname, c); reuseTRead.put(c.getHopID(), tread); - DataOp twrite = new DataOp(varname, c.getDataType(), c.getValueType(), c, DataOpTypes.TRANSIENTWRITE, null); - twrite.setVisited(); - twrite.setOutputParams(c.getDim1(), c.getDim2(), c.getNnz(), c.getUpdateType(), c.getRowsInBlock(), c.getColsInBlock()); - HopRewriteUtils.copyLineNumbers(c, twrite); + DataOp twrite = HopRewriteUtils.createTransientWrite(varname, c); //update live in and out of new statement block (for piggybacking) DataIdentifier diVar = new DataIdentifier(varname); @@ -484,11 +460,4 @@ public class RewriteSplitDagDataDependentOperators extends StatementBlockRewrite public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) { return sbs; } - - public static String createCutVarName(boolean fun) { - return fun ? 
- FUN_CUT_PREFIX + _seq.getNextID() : - SB_CUT_PREFIX + _seq.getNextID(); - - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java index 30631d6..a4c31d9 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteSplitDagUnknownCSVRead.java @@ -34,7 +34,6 @@ import org.apache.sysml.hops.LiteralOp; import org.apache.sysml.parser.DataIdentifier; import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.VariableSet; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType; /** * Rule: Split Hop DAG after CSV reads with unknown size. 
This is @@ -81,13 +80,6 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule ArrayList<Hop> sb1hops = new ArrayList<>(); for( Hop reblock : cand ) { - long rlen = reblock.getDim1(); - long clen = reblock.getDim2(); - long nnz = reblock.getNnz(); - UpdateType update = reblock.getUpdateType(); - int brlen = reblock.getRowsInBlock(); - int bclen = reblock.getColsInBlock(); - //replace reblock inputs to avoid dangling references across dags //(otherwise, for instance, literal ops are shared across dags) for( int i=0; i<reblock.getInput().size(); i++ ) @@ -96,9 +88,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule new LiteralOp((LiteralOp)reblock.getInput().get(i))); //create new transient read - DataOp tread = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), - DataOpTypes.TRANSIENTREAD, null, rlen, clen, nnz, update, brlen, bclen); - HopRewriteUtils.copyLineNumbers(reblock, tread); + DataOp tread = HopRewriteUtils.createTransientRead(reblock.getName(), reblock); //replace reblock with transient read ArrayList<Hop> parents = new ArrayList<>(reblock.getParent()); @@ -108,10 +98,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule } //add reblock sub dag to first statement block - DataOp twrite = new DataOp(reblock.getName(), reblock.getDataType(), reblock.getValueType(), - reblock, DataOpTypes.TRANSIENTWRITE, null); - twrite.setOutputParams(rlen, clen, nnz, update, brlen, bclen); - HopRewriteUtils.copyLineNumbers(reblock, twrite); + DataOp twrite = HopRewriteUtils.createTransientWrite(reblock.getName(), reblock); sb1hops.add(twrite); //update live in and out of new statement block (for piggybacking) @@ -128,8 +115,7 @@ public class RewriteSplitDagUnknownCSVRead extends StatementBlockRewriteRule ret.add(sb); //statement block with remaining hops sb.setSplitDag(true); //avoid later merge by other rewrites } - catch(Exception ex) - { + catch(Exception ex) { 
throw new HopsException("Failed to split hops dag for csv read with unknown size.", ex); } LOG.debug("Applied splitDagUnknownCSVRead."); http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java index fe8d111..9f4b619 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/StatementBlockRewriteRule.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.parser.StatementBlock; +import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; /** * Base class for all hop rewrites in order to enable generic @@ -34,6 +35,17 @@ import org.apache.sysml.parser.StatementBlock; public abstract class StatementBlockRewriteRule { protected static final Log LOG = LogFactory.getLog(StatementBlockRewriteRule.class.getName()); + + private static final String SB_CUT_PREFIX = "_sbcvar"; + private static final String FUN_CUT_PREFIX = "_funvar"; + private static IDSequence _seq = new IDSequence(); + + public static String createCutVarName(boolean fun) { + return fun ? + FUN_CUT_PREFIX + _seq.getNextID() : + SB_CUT_PREFIX + _seq.getNextID(); + + } /** * Indicates if the rewrite potentially splits dags, which is used @@ -52,7 +64,7 @@ public abstract class StatementBlockRewriteRule * @param sate program rewrite status * @return list of statement blocks */ - public abstract List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus sate); + public abstract List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state); /** * Handle a list of statement blocks. 
Specific type constraints have to be ensured @@ -63,5 +75,5 @@ public abstract class StatementBlockRewriteRule * @param sate program rewrite status * @return list of statement blocks */ - public abstract List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate); + public abstract List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state); } http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index 2957482..190a481 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -30,7 +30,7 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.recompile.Recompiler; -import org.apache.sysml.hops.rewrite.RewriteSplitDagDataDependentOperators; +import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule; import org.apache.sysml.lops.Lop; import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.FormatType; @@ -537,7 +537,7 @@ public class StatementBlock extends LiveVariableAnalysis implements ParseInfo for( ParameterExpression pexpr : fexpr.getParamExprs() ) pexpr.setExpr(rHoistFunctionCallsFromExpressions(pexpr.getExpr(), false, tmp)); if( !root ) { //core hoisting - String varname = RewriteSplitDagDataDependentOperators.createCutVarName(true); + String varname = StatementBlockRewriteRule.createCutVarName(true); DataIdentifier di = new DataIdentifier(varname); di.setDataType(fexpr.getDataType()); di.setValueType(fexpr.getValueType()); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java new file mode 100644 index 0000000..2a28ae7 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteHoistingLoopInvariantOpsTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +public class RewriteHoistingLoopInvariantOpsTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteCodeMotionFor"; + private static final String TEST_NAME2 = "RewriteCodeMotionWhile"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteHoistingLoopInvariantOpsTest.class.getSimpleName() + "/"; + + private static final int rows = 265; + private static final int cols = 132; + private static final int iters = 10; + private static final double sparsity = 0.1; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) ); + } + + @Test + public void testCodeMotionForCP() { + testRewriteCodeMotion(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testCodeMotionForRewriteCP() { + testRewriteCodeMotion(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testCodeMotionWhileCP() { + testRewriteCodeMotion(TEST_NAME2, false, ExecType.CP); + } + + @Test + public void testCodeMotionWhileRewriteCP() { + testRewriteCodeMotion(TEST_NAME2, true, ExecType.CP); + } 
+ + private void testRewriteCodeMotion(String testname, boolean rewrites, ExecType et) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_CODE_MOTION; + OptimizerUtils.ALLOW_CODE_MOTION = rewrites; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", + input("X"), String.valueOf(iters), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(iters), expectedDir()); + + double[][] X = getRandomMatrix(rows, cols, -1, 1, sparsity, 7); + writeInputMatrixWithMTD("X", X, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check applied code motion rewrites (moved sum and - from 10 to 1) + Assert.assertEquals(rewrites?1:10, Statistics.getCPHeavyHitterCount("uak+")); + Assert.assertEquals(rewrites?1:10, Statistics.getCPHeavyHitterCount("-")); + } + finally { + OptimizerUtils.ALLOW_CODE_MOTION = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionFor.R 
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionFor.R b/src/test/scripts/functions/misc/RewriteCodeMotionFor.R
new file mode 100644
index 0000000..5d21bd1
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionFor.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# R counterpart of RewriteCodeMotionFor.dml: same computation with a for loop.
+args<-commandArgs(TRUE) # args[1]: input directory, args[2]: number of iterations, args[3]: output directory
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) # input matrix X.mtx from the input directory
+
+R = matrix(0, 1, 1); # 1x1 accumulator
+for(i in 1:as.integer(args[2])) {
+    t1 = X - sum(X); # operands do not depend on i
+    t2 = X + max(X/i); # depends on the loop variable i
+    R = R + min(t1 * t2); # accumulated across iterations
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); # result file "R" in the output directory
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml b/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
new file mode 100644
index 0000000..e1acd85
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionFor.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# For-loop script for the code-motion test; $1: input matrix, $2: number of iterations, $3: output.
+X = read($1);
+
+R = matrix(0, 1, 1); # 1x1 accumulator
+for(i in 1:$2) {
+    t1 = X - sum(X); # operands do not depend on i (X is only read inside the loop)
+    t2 = X + max(X/i); # depends on the loop variable i
+    R = R + min(t1 * t2); # accumulated across iterations
+}
+
+write(R, $3);
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
new file mode 100644
index 0000000..1cfe05d
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# R counterpart of RewriteCodeMotionWhile.dml: same computation with a while loop.
+args<-commandArgs(TRUE) # args[1]: input directory, args[2]: number of iterations, args[3]: output directory
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) # input matrix X.mtx from the input directory
+
+R = matrix(0, 1, 1); # 1x1 accumulator
+i = 1;
+while( i <= as.integer(args[2]) ) {
+    t1 = X - sum(X); # operands do not depend on i
+    t2 = X + max(X/i); # depends on the loop counter i
+    R = R + min(t1 * t2); # accumulated across iterations
+    i = i + 1;
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); # result file "R" in the output directory
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
new file mode 100644
index 0000000..2e3f349
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCodeMotionWhile.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# While-loop script for the code-motion test; $1: input matrix, $2: number of iterations, $3: output.
+X = read($1);
+
+R = matrix(0, 1, 1); # 1x1 accumulator
+i = 1;
+while( i <= $2 ) {
+    t1 = X - sum(X); # operands do not depend on i (X is only read inside the loop)
+    t2 = X + max(X/i); # depends on the loop counter i
+    R = R + min(t1 * t2); # accumulated across iterations
+    i += 1;
+}
+
+write(R, $3);
http://git-wip-us.apache.org/repos/asf/systemml/blob/1e1210b9/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index b75b07a..6166e3d 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -61,6 +61,7 @@ import org.junit.runners.Suite;
     RewriteFoldRCBindTest.class,
     RewriteFuseBinaryOpChainTest.class,
     RewriteFusedRandTest.class,
+    RewriteHoistingLoopInvariantOpsTest.class,
     RewriteIndexingVectorizationTest.class,
     RewriteLoopVectorization.class,
     RewriteMatrixMultChainOptTest.class,
