This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cef8a6f271 [SYSTEMDS-3798,3807] Improved loop vectorization rewrite, code coverage
cef8a6f271 is described below

commit cef8a6f271ea2b5e9f815a657c4dc13b5e780290
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 15:07:31 2024 +0100

    [SYSTEMDS-3798,3807] Improved loop vectorization rewrite, code coverage
---
 .github/workflows/javaTests.yml                    |  3 +-
 .../hops/rewrite/RewriteForLoopVectorization.java  | 77 ++++++++++++++++++++++
 .../test/functions/rewrite/RewriteIfElseTest.java  |  1 -
 .../rewrite/RewriteLoopVectorization.java          |  2 -
 4 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index c2cab87c22..b4341c544c 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -86,7 +86,8 @@ jobs:
           "**.functions.transform.**","**.functions.unique.**",
           
"**.functions.unary.matrix.**,**.functions.linearization.**,**.functions.jmlc.**"
         ]
-        java: [11]
+        java: ['11']
+        javadist: ['adopt']
     name: ${{ matrix.tests }}
     steps:
     - name: Checkout Repository
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
index 1d2223dcf9..0c09c2efb4 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
@@ -88,6 +88,9 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                                        //e.g., for(i in a:b){s = s + 
as.scalar(X[i,2])} -> s = sum(X[a:b,2])
                                        sb = vectorizeScalarAggregate(sb, csb, 
from, to, incr, iterVar);
                                        
+                                       //e.g., for(i in a:b){s = s + X[i,2]} 
-> s = sum(X[a:b,2])
+                                       sb = vectorizeScalarAggregate2(sb, csb, 
from, to, incr, iterVar);
+                                       
                                        //e.g., for(i in a:b){X[i,2] = Y[i,1] + 
Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3];
                                        sb = vectorizeElementwiseBinary(sb, 
csb, from, to, incr, iterVar);
                                        
@@ -205,6 +208,80 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
                return ret;
        }
        
+       /**
+        * Vectorizes a for loop whose body is a single scalar aggregate over a
+        * scalar-typed right indexing expression (i.e., without an explicit as.scalar),
+        * e.g., for(i in a:b){s = s + X[i,2]} -> s = s + sum(X[a:b,2]).
+        *
+        * The rewrite is skipped if an explicit loop increment other than the literal 1
+        * is given, or if not exactly one of the row/column index directly references
+        * the loop iteration variable, because aggregating over the contiguous range
+        * from:to would then not be equivalent to the loop.
+        *
+        * @param sb the for statement block
+        * @param csb the single child statement block of the for loop body
+        * @param from hop of the loop's from expression
+        * @param to hop of the loop's to expression
+        * @param increment hop of the loop's increment expression
+        * @param itervar name of the loop iteration variable
+        * @return the modified child block if the rewrite was applied, otherwise sb
+        */
+       private static StatementBlock vectorizeScalarAggregate2( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) 
+       {
+               StatementBlock ret = sb;
+               
+               //check supported increment values: a strided loop (e.g., 'by 2') must not
+               //be rewritten into an aggregate over the contiguous range from:to
+               //(NOTE: a missing increment, i.e., no explicit 'by', is accepted as before)
+               if( increment != null && !HopRewriteUtils.isLiteralOfValue(increment, 1) )
+                       return ret;
+               
+               //check for applicability
+               boolean leftScalar = false;
+               boolean rightScalar = false;
+               boolean rowIx = false; //true: vectorize over rows, false: over columns
+               
+               if( csb.getHops()!=null && csb.getHops().size()==1 ) {
+                       Hop root = csb.getHops().get(0);
+                       
+                       if( root.getDataType()==DataType.SCALAR && root.getInput(0) instanceof BinaryOp ) {
+                               BinaryOp bop = (BinaryOp) root.getInput(0);
+                               Hop left = bop.getInput(0);
+                               Hop right = bop.getInput(1);
+                               boolean validOp = HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
+                               
+                               //check for left scalar plus, e.g., s = s + X[i,2]
+                               boolean leftAgg = validOp
+                                       && left instanceof DataOp && left.getDataType() == DataType.SCALAR
+                                       && root.getName().equals(left.getName())
+                                       && right instanceof IndexingOp && right.isScalar();
+                               //check for right scalar plus, e.g., s = X[i,2] + s
+                               boolean rightAgg = !leftAgg && validOp
+                                       && right instanceof DataOp && right.getDataType() == DataType.SCALAR
+                                       && root.getName().equals(right.getName())
+                                       && left instanceof IndexingOp && left.isScalar();
+                               
+                               if( leftAgg || rightAgg ) {
+                                       //exactly one of the row/col index must be the iteration variable;
+                                       //it determines the dimension to vectorize over
+                                       //(NOTE: only direct references to the iteration variable are detected)
+                                       Hop scalarIx = leftAgg ? right : left;
+                                       Hop rowIndex = scalarIx.getInput(1);
+                                       Hop colIndex = scalarIx.getInput(3);
+                                       boolean rowIter = rowIndex instanceof DataOp && itervar.equals(rowIndex.getName());
+                                       boolean colIter = colIndex instanceof DataOp && itervar.equals(colIndex.getName());
+                                       if( rowIter != colIter ) {
+                                               leftScalar = leftAgg;
+                                               rightScalar = rightAgg;
+                                               rowIx = rowIter;
+                                       }
+                               }
+                       }
+               }
+               
+               //apply rewrite if possible
+               if( leftScalar || rightScalar ) {
+                       Hop root = csb.getHops().get(0);
+                       BinaryOp bop = (BinaryOp) root.getInput(0);
+                       Hop ix = bop.getInput().get( leftScalar?1:0 );
+                       int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
+                       AggOp aggOp = MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos];
+                       
+                       //replace the scalar indexing operand with a full aggregate over it
+                       AggUnaryOp newSum = HopRewriteUtils.createAggUnaryOp(ix, aggOp, Direction.RowCol);
+                       HopRewriteUtils.removeChildReference(bop, ix);
+                       HopRewriteUtils.addChildReference(bop, newSum, leftScalar?1:0 );
+                       
+                       //modify indexing expression according to loop predicate from-to
+                       //NOTE: any redundant index operations are removed via dynamic algebraic simplification rewrites
+                       int index1 = rowIx ? 1 : 3;
+                       int index2 = rowIx ? 2 : 4;
+                       HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index1), from, index1);
+                       HopRewriteUtils.replaceChildReference(ix, ix.getInput().get(index2), to, index2);
+                       
+                       //update indexing size information (the indexing now yields a matrix)
+                       if( rowIx )
+                               ((IndexingOp)ix).setRowLowerEqualsUpper(false);
+                       else
+                               ((IndexingOp)ix).setColLowerEqualsUpper(false);
+                       ix.setDataType(DataType.MATRIX);
+                       ix.refreshSizeInformation();
+                       Hop.resetVisitStatus(csb.getHops(), true);
+                       
+                       ret = csb;
+                       LOG.debug("Applied vectorizeScalarSumForLoop2.");
+               }
+               
+               return ret;
+       }
+       
        private static StatementBlock vectorizeElementwiseBinary( 
StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String 
itervar ) 
        {
                StatementBlock ret = sb;
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
index 087fc49d98..1e7abfb03b 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
 
 import org.junit.Assert;
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.common.Types.ExecType;
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
index 927b0fd666..d9358fef30 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.rewrite;
 import java.util.HashMap;
 
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -58,7 +57,6 @@ public class RewriteLoopVectorization extends AutomatedTestBase
        }
        
        @Test
-       @Ignore //FIXME: extend loop vectorization rewrite
        public void testLoopVectorizationSumRewrite() {
                testRewriteLoopVectorizationSum( TEST_NAME1, true );
        }

Reply via email to