[SYSTEMML-1369] New loop vectorization rewrite for indexed copies

As explained in https://issues.apache.org/jira/browse/SYSTEMML-1369,
this patch introduces an auto loop vectorization rewrite for indexed
copies. For example, we now rewrite the following loop

parfor (i in 1:ncol(labels))
   topics[id, i] = labels[1, i];
   
into a single left-indexing assignment topics[id, 1:ncol(labels)] =
labels[1, 1:ncol(labels)]. This applies to both for and parfor loops.
Furthermore, this patch fixes size-update issues in the existing loop
vectorization rewrites vectorizeElementwiseBinary and vectorizeElementwiseUnary.

For a scenario with a 1K x 1K dense topics matrix, an outer loop over id,
and regular for loops, the runtime improved from 33s (3438s without
update-in-place) to 0.2s. Note that for large out-of-core matrices, the
improvements are even larger because update-in-place cannot be applied
there.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ed3a1588
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ed3a1588
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ed3a1588

Branch: refs/heads/master
Commit: ed3a1588287e22ab3e8a3e9e971b505a9157ba59
Parents: cd6685a
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Mar 2 23:01:28 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Mar 2 23:01:28 2017 -0800

----------------------------------------------------------------------
 .../rewrite/RewriteForLoopVectorization.java    | 177 ++++++++++++-------
 1 file changed, 114 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ed3a1588/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
index 991dedd..273436e 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java
@@ -76,10 +76,20 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                          || csb instanceof IfStatementBlock 
                                          || csb instanceof ForStatementBlock ) 
)
                                {
-                                       //auto vectorization pattern
-                                       sb = vectorizeScalarAggregate(sb, csb, 
from, to, incr, iterVar);           //e.g., for(i){s = s + as.scalar(X[i,2])}
+                                       //AUTO VECTORIZATION PATTERNS
+                                       //Note: unnecessary row or column 
indexing then later removed via hop rewrites
+                                       
+                                       //e.g., for(i in a:b){s = s + 
as.scalar(X[i,2])} -> s = sum(X[a:b,2])
+                                       sb = vectorizeScalarAggregate(sb, csb, 
from, to, incr, iterVar);  
+                                       
+                                       //e.g., for(i in a:b){X[i,2] = Y[i,1] + 
Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
                                        sb = vectorizeElementwiseBinary(sb, 
csb, from, to, incr, iterVar);
+                                       
+                                       //e.g., for(i in a:b){X[i,2] = 
abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]);
                                        sb = vectorizeElementwiseUnary(sb, csb, 
from, to, incr, iterVar);
+                               
+                                       //e.g., for(i in a:b){X[7,i] = Y[1,i]} 
-> X[7,a:b] = Y[1,a:b];
+                                       sb = vectorizeIndexedCopy(sb, csb, 
from, to, incr, iterVar);
                                }       
                        }       
                }       
@@ -91,19 +101,6 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                return ret;
        }
        
-       /**
-        * Note: unnecessary row or column indexing then later removed via
-        * dynamic rewrites
-        * 
-        * @param sb statement block?
-        * @param csb statement boock?
-        * @param from high-level operator?
-        * @param to high-level operator?
-        * @param increment high-level operator?
-        * @param itervar ?
-        * @return statement block
-        * @throws HopsException if HopsException occurs
-        */
        private StatementBlock vectorizeScalarAggregate( StatementBlock sb, 
StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) 
                throws HopsException
        {
@@ -206,19 +203,6 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                return ret;
        }
        
-       /**
-        * Note: unnecessary row or column indexing then later removed via
-        * dynamic rewrites
-        * 
-        * @param sb ?
-        * @param csb ?
-        * @param from ?
-        * @param to ?
-        * @param increment ?
-        * @param itervar ?
-        * @return statement block
-        * @throws HopsException if HopsException occurs
-        */
        private StatementBlock vectorizeElementwiseBinary( StatementBlock sb, 
StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) 
                throws HopsException
        {
@@ -291,8 +275,9 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                        HopRewriteUtils.replaceChildReference(rix0, 
rix0.getInput().get(index2-1), to, index2-1);
                        HopRewriteUtils.replaceChildReference(rix1, 
rix1.getInput().get(index1-1), from, index1-1);
                        HopRewriteUtils.replaceChildReference(rix1, 
rix1.getInput().get(index2-1), to, index2-1);
+                       updateLeftAndRightIndexingSizes(rowIx, lix, rix0, rix1);
                        bop.refreshSizeInformation();
-                       lix.refreshSizeInformation();
+                       lix.refreshSizeInformation(); //after bop update
                        
                        ret = csb;
                        //ret.liveIn().removeVariable(itervar);
@@ -302,19 +287,6 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                return ret;
        }
        
-       /**
-        * Note: unnecessary row or column indexing then later removed via
-        * dynamic rewrites
-        * 
-        * @param sb ?
-        * @param csb ?
-        * @param from ?
-        * @param to ?
-        * @param increment ?
-        * @param itervar ?
-        * @return statement block
-        * @throws HopsException if HopsException occurs
-        */
        private StatementBlock vectorizeElementwiseUnary( StatementBlock sb, 
StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
                throws HopsException
        {
@@ -342,30 +314,16 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                                        && lixrhs.getInput().get(0) instanceof 
IndexingOp
                                        && 
lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp )
                                {
-                                       IndexingOp rix = (IndexingOp) 
lixrhs.getInput().get(0);
-                                       //check for rowwise
-                                       if(    lix.getRowLowerEqualsUpper() && 
rix.getRowLowerEqualsUpper() 
-                                               && 
lix.getInput().get(2).getName().equals(itervar)
-                                               && 
rix.getInput().get(1).getName().equals(itervar) )
-                                       {
-                                               apply = true;
-                                               rowIx = true;
-                                       }
-                                       //check for colwise
-                                       if(    lix.getColLowerEqualsUpper() && 
rix.getColLowerEqualsUpper() 
-                                               && 
lix.getInput().get(4).getName().equals(itervar)
-                                               && 
rix.getInput().get(3).getName().equals(itervar) )
-                                       {
-                                               apply = true;
-                                               rowIx = false;
-                                       }
+                                       boolean[] tmp = 
checkLeftAndRightIndexing(lix, 
+                                                       (IndexingOp) 
lixrhs.getInput().get(0), itervar);
+                                       apply = tmp[0];
+                                       rowIx = tmp[1];
                                }
                        }
                }       
                
                //apply rewrite if possible
-               if( apply ) 
-               {
+               if( apply ) {
                        Hop root = csb.get_hops().get(0);
                        LeftIndexingOp lix = (LeftIndexingOp) 
root.getInput().get(0);
                        UnaryOp uop = (UnaryOp) lix.getInput().get(1);
@@ -378,14 +336,107 @@ public class RewriteForLoopVectorization extends 
StatementBlockRewriteRule
                        //modify right indexing
                        HopRewriteUtils.replaceChildReference(rix, 
rix.getInput().get(index1-1), from, index1-1);
                        HopRewriteUtils.replaceChildReference(rix, 
rix.getInput().get(index2-1), to, index2-1);
+                       updateLeftAndRightIndexingSizes(rowIx, lix, rix);
                        uop.refreshSizeInformation();
-                       lix.refreshSizeInformation();
+                       lix.refreshSizeInformation(); //after uop update
                        
                        ret = csb;
-                       //ret.liveIn().removeVariable(itervar);
                        LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
                }
                
                return ret;
        }
+       
+       /**
+        * Vectorizes a loop body consisting of a single indexed copy, e.g.,
+        * for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b], by replacing the
+        * bounds of the indexed dimension with the loop range [from, to].
+        * 
+        * @param sb the for/parfor statement block (returned if not applicable)
+        * @param csb the single child statement block holding the loop body
+        * @param from loop lower bound
+        * @param to loop upper bound
+        * @param increment loop increment (only a literal 1 is supported)
+        * @param itervar name of the loop iteration variable
+        * @return the modified child statement block if the rewrite applied, otherwise sb
+        * @throws HopsException if reading the increment literal fails
+        */
+       private StatementBlock vectorizeIndexedCopy( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar )
+               throws HopsException
+       {
+               //only a unit increment maps to a contiguous index range
+               boolean unitIncrement = increment instanceof LiteralOp
+                       && ((LiteralOp)increment).getDoubleValue()==1.0;
+               if( !unitIncrement )
+                       return sb;
+               
+               //the loop body must be a single matrix statement rooted at a left indexing
+               if( csb.get_hops()==null || csb.get_hops().size()!=1 )
+                       return sb;
+               Hop root = csb.get_hops().get(0);
+               if( root.getDataType()!=DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp) )
+                       return sb;
+               
+               //target and source must be plain variables, the source read via right indexing
+               LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
+               Hop target = lix.getInput().get(0);
+               Hop source = lix.getInput().get(1);
+               if( !(target instanceof DataOp) || !(source instanceof IndexingOp)
+                       || !(source.getInput().get(0) instanceof DataOp) )
+                       return sb;
+               IndexingOp rix = (IndexingOp) source;
+               
+               boolean[] verdict = checkLeftAndRightIndexing(lix, rix, itervar);
+               if( !verdict[0] )
+                       return sb;
+               boolean rowIx = verdict[1];
+               
+               //lower/upper bounds: rows at lix inputs 2/3 (rix 1/2), cols at lix 4/5 (rix 3/4)
+               int lowerPos = rowIx ? 2 : 4;
+               int upperPos = lowerPos + 1;
+               HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(lowerPos), from, lowerPos);
+               HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(upperPos), to, upperPos);
+               HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(lowerPos-1), from, lowerPos-1);
+               HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(upperPos-1), to, upperPos-1);
+               updateLeftAndRightIndexingSizes(rowIx, lix, rix);
+               
+               LOG.debug("Applied vectorizeIndexedCopy.");
+               return csb;
+       }
+       
+       /**
+        * Checks whether a single-cell copy between a left indexing op and a right
+        * indexing op can be vectorized over the loop variable.
+        * 
+        * The row-wise case requires both ops to index a single row given by
+        * itervar; the column-wise case requires the same for a single column.
+        * In addition, the other (non-vectorized) dimension of both ops must not
+        * reference itervar: the rewrite only replaces the bounds of the vectorized
+        * dimension and drops the loop, so, e.g., X[i,i] = Y[i,i] or
+        * X[i,2] = Y[i,i] would otherwise still reference the removed loop variable.
+        * 
+        * @param lix left indexing op (target of the copy)
+        * @param rix right indexing op (source of the copy)
+        * @param itervar name of the loop iteration variable
+        * @return two flags: [0] whether the rewrite applies, [1] whether it is row-wise
+        */
+       private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, IndexingOp rix, String itervar) {
+               boolean rowwise = lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper()
+                       && itervar.equals(lix.getInput().get(2).getName())
+                       && itervar.equals(rix.getInput().get(1).getName());
+               boolean colwise = lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper()
+                       && itervar.equals(lix.getInput().get(4).getName())
+                       && itervar.equals(rix.getInput().get(3).getName());
+               
+               //if both match, both dimensions depend on itervar and both checks below reject
+               if( rowwise && !otherDimReferencesIterVar(true, lix, rix, itervar) )
+                       return new boolean[]{true, true};
+               if( colwise && !otherDimReferencesIterVar(false, lix, rix, itervar) )
+                       return new boolean[]{true, false};
+               return new boolean[]{false, false};
+       }
+       
+       /** Returns true if the bounds of the non-vectorized dimension of lix or rix reference itervar. */
+       private static boolean otherDimReferencesIterVar(boolean rowIx, LeftIndexingOp lix, IndexingOp rix, String itervar) {
+               int pos = rowIx ? 4 : 2; //lix position of the other dimension's lower bound (rix: pos-1)
+               return referencesIterVar(lix.getInput().get(pos), itervar)
+                       || referencesIterVar(lix.getInput().get(pos+1), itervar)
+                       || referencesIterVar(rix.getInput().get(pos-1), itervar)
+                       || referencesIterVar(rix.getInput().get(pos), itervar);
+       }
+       
+       /** Returns true if the given hop or any of its (transitive) inputs is named itervar. */
+       private static boolean referencesIterVar(Hop hop, String itervar) {
+               if( itervar.equals(hop.getName()) )
+                       return true;
+               for( Hop in : hop.getInput() )
+                       if( referencesIterVar(in, itervar) )
+                               return true;
+               return false;
+       }
+       
+       /**
+        * Clears the lower==upper flag of the vectorized dimension on the left
+        * indexing op and on every given right indexing op. It then refreshes
+        * their size information: right indexing ops first, left indexing op last.
+        * 
+        * @param rowIx true if the row dimension was vectorized, false for columns
+        * @param lix left indexing op whose bounds were replaced
+        * @param rix right indexing ops whose bounds were replaced
+        */
+       private static void updateLeftAndRightIndexingSizes(boolean rowIx, LeftIndexingOp lix, IndexingOp... rix) {
+               //the vectorized dimension now has a range as bounds, so lower==upper no longer holds
+               if( rowIx )
+                       lix.setRowLowerEqualsUpper(false);
+               else
+                       lix.setColLowerEqualsUpper(false);
+               for( IndexingOp op : rix ) {
+                       if( rowIx )
+                               op.setRowLowerEqualsUpper(false);
+                       else
+                               op.setColLowerEqualsUpper(false);
+               }
+               
+               //refresh the right indexing ops before the left indexing op
+               for( int k = 0; k < rix.length; k++ )
+                       rix[k].refreshSizeInformation();
+               lix.refreshSizeInformation();
+       }
 }

Reply via email to