This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 091144d669 [SYSTEMDS-3798] Fix generality of loop vectorization rewrite
091144d669 is described below
commit 091144d669a803cf03ca953f0a866f5bc967246c
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 18:18:59 2024 +0100
[SYSTEMDS-3798] Fix generality of loop vectorization rewrite
---
.../hops/rewrite/RewriteForLoopVectorization.java | 38 +++++++++++++++-------
.../test/functions/vect/AutoVectorizationTest.java | 24 +++++---------
2 files changed, 35 insertions(+), 27 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
index 0c09c2efb4..ad06ac2359 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.java
@@ -138,13 +138,11 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
&& right.getInput(0) instanceof
IndexingOp )
{
IndexingOp ix =
(IndexingOp)right.getInput(0);
- if( ix.isRowLowerEqualsUpper() &&
ix.getInput(1) instanceof DataOp
- &&
ix.getInput(1).getName().equals(itervar) ){
+ if( checkItervarIndexing(ix, itervar,
true) ){
leftScalar = true;
rowIx = true;
}
- else if( ix.isColLowerEqualsUpper() &&
ix.getInput(3) instanceof DataOp
- &&
ix.getInput(3).getName().equals(itervar) ){
+ else if( checkItervarIndexing(ix,
itervar, false) ){
leftScalar = true;
rowIx = false;
}
@@ -157,13 +155,11 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
&& left.getInput(0) instanceof
IndexingOp )
{
IndexingOp ix =
(IndexingOp)left.getInput(0);
- if( ix.isRowLowerEqualsUpper() &&
ix.getInput(1) instanceof DataOp
- &&
ix.getInput(1).getName().equals(itervar) ){
+ if( checkItervarIndexing(ix, itervar,
true) ){
rightScalar = true;
rowIx = true;
}
- else if( ix.isColLowerEqualsUpper() &&
ix.getInput(3) instanceof DataOp
- &&
ix.getInput(3).getName().equals(itervar) ){
+ else if( checkItervarIndexing(ix,
itervar, false) ){
rightScalar = true;
rowIx = false;
}
@@ -231,8 +227,14 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
&& root.getName().equals(left.getName())
&& right instanceof IndexingOp &&
right.isScalar())
{
- leftScalar = true;
- rowIx = true; //row and col
+ if(
checkItervarIndexing((IndexingOp)right, itervar, true) ){
+ leftScalar = true;
+ rowIx = true;
+ }
+ else if(
checkItervarIndexing((IndexingOp)right, itervar, false) ){
+ leftScalar = true;
+ rowIx = false;
+ }
}
//check for right scalar plus
else if( HopRewriteUtils.isValidOp(bop.getOp(),
MAP_SCALAR_AGGREGATE_SOURCE_OPS)
@@ -240,8 +242,14 @@ public class RewriteForLoopVectorization extends
StatementBlockRewriteRule
&&
root.getName().equals(right.getName())
&& left instanceof IndexingOp &&
left.isScalar())
{
- rightScalar = true;
- rowIx = true; //row and col
+ if(
checkItervarIndexing((IndexingOp)left, itervar, true) ){
+ rightScalar = true;
+ rowIx = true;
+ }
+ else if(
checkItervarIndexing((IndexingOp)left, itervar, false) ){
+ rightScalar = true;
+ rowIx = false;
+ }
}
}
}
@@ -461,6 +469,16 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule
return ret;
}
+ /**
+ * Checks whether the given indexing op selects a single row (row=true)
+ * or a single column (row=false) whose lower index is the loop variable.
+ * The bound check must match the dimension: row lower/upper for input 1,
+ * column lower/upper for input 3 (as in the checks this helper replaces).
+ */
+ private static boolean checkItervarIndexing(IndexingOp ix, String itervar, boolean row) {
+ return (row ? ix.isRowLowerEqualsUpper() : ix.isColLowerEqualsUpper())
+ && ix.getInput(row?1:3) instanceof DataOp
+ && ix.getInput(row?1:3).getName().equals(itervar);
+ }
+
private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix,
IndexingOp rix, String itervar) {
boolean[] ret = new boolean[2]; //apply, rowIx
diff --git a/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java b/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
index 0b0b301224..7771c96a8a 100644
--- a/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/vect/AutoVectorizationTest.java
@@ -213,42 +213,36 @@ public class AutoVectorizationTest extends
AutomatedTestBase
runVectorizationTest( TEST_NAME24 );
}
- /**
- *
- * @param cfc
- * @param vt
- */
private void runVectorizationTest( String testName )
{
String TEST_NAME = testName;
try
- {
+ {
TestConfiguration config =
getTestConfiguration(TEST_NAME);
loadTestConfiguration(config);
- String HOME = SCRIPT_DIR + TEST_DIR;
+ String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{"-explain","-args",
input("A"), output("R") };
fullRScriptName = HOME + TEST_NAME + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
+ rCmd = getRCmd(inputDir(), expectedDir());
//generate input
double[][] A = getRandomMatrix(rows, cols, 0, 1, 1.0,
7);
writeInputMatrixWithMTD("A", A, true);
//run tests
- runTest(true, false, null, -1);
- runRScript(true);
-
- //compare results
- HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare results
+ HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, 1e-14, "DML",
"R");
}
- catch(Exception ex)
- {
+ catch(Exception ex) {
throw new RuntimeException(ex);
}
}