Repository: systemml
Updated Branches:
  refs/heads/master 0c4a3611c -> 7d007e7b2


[SYSTEMML-2490] Improved rewrite for update-in-place in for/while loops

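This patch generalizes the existing update-in-place loop rewrite to
also allow update-in-place for cases where correct access to the updated
matrix is forced by existing data dependencies: a statement block now
qualifies if its DAG contains a single left-indexing update of the loop
variable and none of the other DAG roots reads that variable or performs
left indexing. The previous per-root check remains as the fallback.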


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d007e7b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d007e7b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d007e7b

Branch: refs/heads/master
Commit: 7d007e7b216b4b161fa385b460f90f2d1845b4db
Parents: 0c4a361
Author: Matthias Boehm <[email protected]>
Authored: Fri Sep 28 16:50:21 2018 +0200
Committer: Matthias Boehm <[email protected]>
Committed: Fri Sep 28 16:50:21 2018 +0200

----------------------------------------------------------------------
 .../RewriteMarkLoopVariablesUpdateInPlace.java  | 58 +++++++++++++++++---
 1 file changed, 49 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7d007e7b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
index 4032358..ba37e06 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMarkLoopVariablesUpdateInPlace.java
@@ -27,6 +27,7 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.UnaryOp;
@@ -117,8 +118,9 @@ public class RewriteMarkLoopVariablesUpdateInPlace extends 
StatementBlockRewrite
                        }
                        else {
                                if( sb.getHops() != null )
-                                       for( Hop hop : sb.getHops() ) 
-                                               ret &= 
isApplicableForUpdateInPlace(hop, varname);
+                                       if( 
!isApplicableForUpdateInPlace(sb.getHops(), varname) )
+                                               for( Hop hop : sb.getHops() ) 
+                                                       ret &= 
isApplicableForUpdateInPlace(hop, varname);
                        }
                        
                        //early abort if not applicable
@@ -128,18 +130,14 @@ public class RewriteMarkLoopVariablesUpdateInPlace 
extends StatementBlockRewrite
                return ret;
        }
        
-       private static boolean isApplicableForUpdateInPlace( Hop hop, String 
varname )
-       {
+       private static boolean isApplicableForUpdateInPlace(Hop hop, String 
varname) {
+               //NOTE: single-root-level validity check
                if( !hop.getName().equals(varname) )
                        return true;
        
                //valid if read/updated by leftindexing 
                //CP exec type not evaluated here as no lops generated yet 
-               boolean validLix = hop instanceof DataOp 
-                       && hop.isMatrix() && hop.getInput().get(0).isMatrix()
-                       && hop.getInput().get(0) instanceof LeftIndexingOp
-                       && hop.getInput().get(0).getInput().get(0) instanceof 
DataOp
-                       && 
hop.getInput().get(0).getInput().get(0).getName().equals(varname);
+               boolean validLix = probeLixRoot(hop, varname);
                
                //valid if only safe consumers of left indexing input
                if( validLix ) {
@@ -153,6 +151,48 @@ public class RewriteMarkLoopVariablesUpdateInPlace extends 
StatementBlockRewrite
                return validLix;
        }
        
+       private static boolean isApplicableForUpdateInPlace(ArrayList<Hop> 
hops, String varname) {
+               //NOTE: additional DAG-level validity check over all given roots; returns false if more than one root matches probeLixRoot, otherwise true iff every remaining root passes rProbeOtherRoot
+               
+               // check single LIX update which is direct root-child to 
varname assignment
+               Hop bLix = null; //first input of the matching root; stays null if no root matches
+               for( Hop hop : hops ) {
+                       if( probeLixRoot(hop, varname) ) {
+                               if( bLix != null ) return false; //invalid: a second root matches probeLixRoot
+                               bLix = hop.getInput().get(0);
+                       }
+               }
+               
+               // check all other roots independent of varname
+               boolean valid = true;
+               Hop.resetVisitStatus(hops); //rProbeOtherRoot reads and sets the visited flag; presumably this clears it beforehand
+               for( Hop hop : hops )
+                       if( hop.getInput().get(0) != bLix ) //reference comparison skips the root found above; NOTE(review): get(0) assumes every root has at least one input - confirm
+                               valid &= rProbeOtherRoot(hop, varname);
+               Hop.resetVisitStatus(hops); //presumably clears the flags set by rProbeOtherRoot again
+               
+               return valid;
+       }
+       
+       private static boolean probeLixRoot(Hop root, String varname) { //true iff root is a matrix DataOp whose first input is a matrix LeftIndexingOp whose own first input is a DataOp named varname
+               return root instanceof DataOp 
+                       && root.isMatrix() && root.getInput().get(0).isMatrix()
+                       && root.getInput().get(0) instanceof LeftIndexingOp
+                       && root.getInput().get(0).getInput().get(0) instanceof 
DataOp
+                       && 
root.getInput().get(0).getInput().get(0).getName().equals(varname);
+       }
+       
+       private static boolean rProbeOtherRoot(Hop hop, String varname) { //recursive probe: true only if neither hop nor any hop reached through its inputs is a LeftIndexingOp or a TRANSIENTREAD data hop named varname
+               if( hop.isVisited() )
+                       return false; //an already-visited hop counts as invalid; NOTE(review): confirm this conservative result for hops reachable via several paths is intended
+               boolean valid = !(hop instanceof LeftIndexingOp)
+                       && !(HopRewriteUtils.isData(hop, 
DataOpTypes.TRANSIENTREAD) && hop.getName().equals(varname));
+               for( Hop c : hop.getInput() ) //no short-circuit: &= always evaluates the call, so all inputs are traversed even once valid is false
+                       valid &= rProbeOtherRoot(c, varname);
+               hop.setVisited(); //marked only after all inputs have been processed
+               return valid;
+       }
+       
        @Override
        public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> 
sbs, ProgramRewriteStatus sate) {
                return sbs;

Reply via email to