Repository: systemml Updated Branches: refs/heads/master 0c4a3611c -> 7d007e7b2
[SYSTEMML-2490] Improved rewrite for update-in-place in for/while loops This patch generalizes the existing update-in-place loop rewrite to allow update-in-place for cases where correct access to the updated matrix is forced by existing data dependencies. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d007e7b Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d007e7b Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d007e7b Branch: refs/heads/master Commit: 7d007e7b216b4b161fa385b460f90f2d1845b4db Parents: 0c4a361 Author: Matthias Boehm <[email protected]> Authored: Fri Sep 28 16:50:21 2018 +0200 Committer: Matthias Boehm <[email protected]> Committed: Fri Sep 28 16:50:21 2018 +0200 ---------------------------------------------------------------------- .../RewriteMarkLoopVariablesUpdateInPlace.java | 58 +++++++++++++++++--- 1 file changed, 49 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7d007e7b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java index 4032358..ba37e06 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java @@ -27,6 +27,7 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.DataOp; import org.apache.sysml.hops.Hop; +import org.apache.sysml.hops.Hop.DataOpTypes; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.LeftIndexingOp; import org.apache.sysml.hops.UnaryOp; @@ -117,8 +118,9 @@ public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewrite } else { if( sb.getHops() != null ) - for( Hop hop : sb.getHops() ) - ret &= isApplicableForUpdateInPlace(hop, varname); + if( !isApplicableForUpdateInPlace(sb.getHops(), varname) ) + for( Hop hop : sb.getHops() ) + ret &= isApplicableForUpdateInPlace(hop, varname); } //early abort if not applicable @@ -128,18 +130,14 @@ public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewrite return ret; } - private static boolean isApplicableForUpdateInPlace( Hop hop, String varname ) - { + private static boolean isApplicableForUpdateInPlace(Hop hop, String varname) { + //NOTE: single-root-level validity check if( !hop.getName().equals(varname) ) return true; //valid if read/updated by leftindexing //CP exec type not evaluated here as no lops generated yet - boolean validLix = hop instanceof DataOp - && hop.isMatrix() && hop.getInput().get(0).isMatrix() - && hop.getInput().get(0) instanceof LeftIndexingOp - && hop.getInput().get(0).getInput().get(0) instanceof DataOp - && hop.getInput().get(0).getInput().get(0).getName().equals(varname); + boolean validLix = probeLixRoot(hop, varname); //valid if only safe consumers of left indexing input if( validLix ) { @@ -153,6 +151,48 @@ public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewrite return validLix; } + private static boolean isApplicableForUpdateInPlace(ArrayList<Hop> hops, String varname) { + //NOTE: additional DAG-level validity check + + // check single LIX update which is direct root-child to varname assignment + Hop bLix = null; + for( Hop hop : hops ) { + if( probeLixRoot(hop, varname) ) { + if( bLix != null ) return false; //invalid + bLix = hop.getInput().get(0); + } + } + + // check all other roots independent of varname + boolean valid = true; + Hop.resetVisitStatus(hops); + for( Hop hop : hops ) + if( hop.getInput().get(0) != bLix ) + valid &= rProbeOtherRoot(hop, varname); + Hop.resetVisitStatus(hops); + + return valid; + } + + private static boolean probeLixRoot(Hop root, String varname) { + return root instanceof DataOp + && root.isMatrix() && root.getInput().get(0).isMatrix() + && root.getInput().get(0) instanceof LeftIndexingOp + && root.getInput().get(0).getInput().get(0) instanceof DataOp + && root.getInput().get(0).getInput().get(0).getName().equals(varname); + } + + private static boolean rProbeOtherRoot(Hop hop, String varname) { + if( hop.isVisited() ) + return false; + boolean valid = !(hop instanceof LeftIndexingOp) + && !(HopRewriteUtils.isData(hop, DataOpTypes.TRANSIENTREAD) && hop.getName().equals(varname)); + for( Hop c : hop.getInput() ) + valid &= rProbeOtherRoot(c, varname); + hop.setVisited(); + return valid; + } + @Override public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) { return sbs;
