This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 7de36573fc [SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
7de36573fc is described below
commit 7de36573fc6e6e22957145cbb40cfc402d5978f8
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 24 19:54:48 2024 +0200
[SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
This patch resolves a FIXME left over from the recent rewrite code-coverage
improvements: it fixes the test expressions and the other rewrite
configurations so that the test actually triggers the existing rewrite.
---
.../java/org/apache/sysds/hops/OptimizerUtils.java | 1 +
.../apache/sysds/hops/rewrite/ProgramRewriter.java | 3 +-
.../RewriteAlgebraicSimplificationStatic.java | 4 +--
.../RewriteSimplifyBushyBinaryOperationTest.java | 38 ++++++++++++----------
4 files changed, 25 insertions(+), 21 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index de8e7809ca..6338ff7a70 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -195,6 +195,7 @@ public class OptimizerUtils
* all sum-product related rewrites.
*/
public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;
+ public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true;
/**
* Enables additional mmchain optimizations. in the future, this might
be merged with
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index cd440a6bcf..03633d06a8 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -126,7 +126,8 @@ public class ProgramRewriter{
}
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
_dagRuleSet.add( new
RewriteMatrixMultChainOptimization() ); //dependency: cse
- _dagRuleSet.add( new
RewriteElementwiseMultChainOptimization() ); //dependency: cse
+ if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
+ _dagRuleSet.add( new
RewriteElementwiseMultChainOptimization()); //dependency: cse
}
if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
_dagRuleSet.add( new
RewriteMatrixMultChainOptimizationTranspose() ); //dependency: cse
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 76691d6480..a18a2b7466 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -855,8 +855,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
/**
- * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
- * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
+ * t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
+ * t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
*
* Note: Restriction ba() at leaf and root instead of data at leaf to
not reorganize too
* eagerly, which would lose additional rewrite potential. This
rewrite has two goals
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
index 105dfa8cbc..fb1bcc3630 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
import org.junit.Test;
import java.util.HashMap;
@@ -37,7 +38,7 @@ public class RewriteSimplifyBushyBinaryOperationTest extends
AutomatedTestBase {
TEST_DIR +
RewriteSimplifyBushyBinaryOperationTest.class.getSimpleName() + "/";
private static final int rows = 500;
- private static final int cols = 500;
+ private static final int cols = 100;
private static final double eps = Math.pow(10, -10);
@Override
@@ -46,28 +47,28 @@ public class RewriteSimplifyBushyBinaryOperationTest
extends AutomatedTestBase {
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
}
+ //pattern: t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
@Test
public void testBushyBinaryOperationMultNoRewrite() {
testSimplifyBushyBinaryOperation(1, false);
}
@Test
- public void testBushyBinaryOperationMultRewrite() { //pattern:
(X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
+ public void testBushyBinaryOperationMultRewrite() {
testSimplifyBushyBinaryOperation(1, true);
}
+ //pattern: t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
@Test
public void testBushyBinaryOperationAddNoRewrite() {
testSimplifyBushyBinaryOperation(2, false);
}
@Test
- public void testBushyBinaryOperationAddtRewrite() { //pattern:
(X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
+ public void testBushyBinaryOperationAddtRewrite() {
testSimplifyBushyBinaryOperation(2, true);
}
-
-
private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites)
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
@@ -76,19 +77,21 @@ public class RewriteSimplifyBushyBinaryOperationTest
extends AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "-args",
input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID),
output("R")};
+ programArgs = new String[] {"-stats", "-explain",
"-args",
+ input("X"), input("Y"), input("Z"), input("v"),
String.valueOf(ID), output("R")};
fullRScriptName = HOME + TEST_NAME + ".R";
rCmd = getRCmd(inputDir(), String.valueOf(ID),
expectedDir());
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
- //OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
- //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = false;
//disable nary mult
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = false; //disable
emult reordering
+ //TODO improved phase ordering
+
//create matrices
- double[][] X = getRandomMatrix(rows, cols, -1, 1,
0.60d, 3);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1,
0.60d, 5);
+ double[][] X = getRandomMatrix(rows, 1, -1, 1, 0.60d,
3);
+ double[][] Y = getRandomMatrix(rows, 1, -1, 1, 0.60d,
5);
double[][] Z = getRandomMatrix(rows, cols, -1, 1,
0.60d, 6);
- double[][] v = getRandomMatrix(rows, cols, -1, 1,
0.60d, 8);
+ double[][] v = getRandomMatrix(cols, 1, -1, 1, 0.60d,
8);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
writeInputMatrixWithMTD("Z", Z, true);
@@ -101,15 +104,14 @@ public class RewriteSimplifyBushyBinaryOperationTest
extends AutomatedTestBase {
HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
-
- /**
- * The rewrite in RewriteAlgebraicSimplificationStatic
is not entered. Hence, we fail
- * the assertions for this rewrite so that we can
revisit this issue later.
- */
- //FIXME
+
+ if( ID == 1 && rewrites ) //check mmchain, enabled by
bushy join
+
Assert.assertTrue(heavyHittersContainsString("mmchain"));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = true;
Recompiler.reinitRecompiler();
}
}