This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 12d8cd70af [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases
12d8cd70af is described below

commit 12d8cd70afa2156bda74c0b8e5d6d11d27e75c2a
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Oct 23 15:11:01 2024 +0200

    [SYSTEMDS-3784] Fix weighted unary-mm rewrite test cases
---
 .../RewriteAlgebraicSimplificationDynamic.java     |   7 +-
 .../RewriteSimplifyWeightedUnaryMMTest.java        | 174 +++------------------
 .../rewrite/RewriteSimplifyWeightedUnaryMM.dml     |  72 +--------
 3 files changed, 32 insertions(+), 221 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 5d894170df..396c40d114 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -85,8 +85,11 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
        private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new 
OpOp2[]{OpOp2.MULT, OpOp2.DIV}; 
        
        //valid unary and binary operators for wumm
-       private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{OpOp1.ABS, 
OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, OpOp1.LOG, OpOp1.SQRT,  
OpOp1.SIGMOID, OpOp1.SPROP}; 
-       private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new 
OpOp2[]{OpOp2.MULT, OpOp2.POW}; 
+       private static OpOp1[] LOOKUP_VALID_WUMM_UNARY = new OpOp1[]{
+               OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.EXP, 
OpOp1.LOG,
+               OpOp1.SQRT, OpOp1.SIN, OpOp1.COS, OpOp1.SIGMOID, OpOp1.SPROP}; 
+       private static OpOp2[] LOOKUP_VALID_WUMM_BINARY = new OpOp2[]{
+               OpOp2.MULT, OpOp2.POW}; 
        
        @Override
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) {
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
index aab2970913..84f7ebfe04 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyWeightedUnaryMMTest.java
@@ -19,10 +19,16 @@
 
 package org.apache.sysds.test.functions.rewrite;
 
+import java.util.HashMap;
+
 import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 
 public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
@@ -31,9 +37,8 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
        private static final String TEST_CLASS_DIR =
                TEST_DIR + 
RewriteSimplifyWeightedUnaryMMTest.class.getSimpleName() + "/";
 
-       private static final int rows = 100;
-       private static final int cols = 100;
-       //private static final double eps = Math.pow(10, -7);
+       private static final int rows = 1123; //larger than blocksize needed
+       private static final int cols = 1245;
 
        @Override
        public void setUp() {
@@ -103,166 +108,28 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
                testRewriteSimplifyWeightedUnaryMM(5, true);    //pattern: 
2*(W*(U%*%t(V)))
        }
 
-       /**
-        * These tests cover the case for the third pattern
-        * W * sop(U%*%t(V), c) or W * sop(U%*%t(V), c), where
-        * sop stands for scalar operation (+, -, *, /) and c represents
-        * some constant scalar.
-        * */
-
-       @Test
-       public void testWeightedUnaryMMAddLeftNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(6, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMAddLeftRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(6, true);    //pattern: W * 
(c + U%*%t(V))
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusLeftNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(7, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusLeftRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(7, true);    //pattern: W * 
(c - U%*%t(V))
-       }
-
        @Test
        public void testWeightedUnaryMMMultLeftNoRewrite(){
                testRewriteSimplifyWeightedUnaryMM(8, false);
        }
 
        @Test
+       @Ignore //FIXME non-applied rewrite
        public void testWeightedUnaryMMMultLeftRewrite(){
                testRewriteSimplifyWeightedUnaryMM(8, true);    //pattern: W * 
(c * (U%*%t(V)))
        }
 
-       @Test
-       public void testWeightedUnaryMMDivLeftNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(9, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMDivLeftRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(9, true);    //pattern: W * 
(c / (U%*%t(V)))
-       }
-
-       // Same pattern but scalar from right instead of left
-
-       @Test
-       public void testWeightedUnaryMMAddRightNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(10, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMAddRightRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(10, true);   //pattern: W * 
(U%*%t(V) + c)
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusRightNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(11, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusRightRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(11, true);   //pattern: W * 
(U%*%t(V) - c)
-       }
-
        @Test
        public void testWeightedUnaryMMMulRightNoRewrite(){
                testRewriteSimplifyWeightedUnaryMM(12, false);
        }
 
        @Test
+       @Ignore //FIXME non-applied rewrite
        public void testWeightedUnaryMMMultRightRewrite(){
                testRewriteSimplifyWeightedUnaryMM(12, true);   //pattern: W * 
((U%*%t(V)) * c)
        }
 
-       @Test
-       public void testWeightedUnaryMMDivRightNoRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(13, false);
-       }
-
-       @Test
-       public void testWeightedUnaryMMDivRightRewrite(){
-               testRewriteSimplifyWeightedUnaryMM(13, true);   //pattern: W * 
((U%*%t(V)) / c)
-       }
-
-       /**
-        * Here, we omit the transpose in the dml script. The rewrite should 
catch the missing transpose
-        * and replace V with t(V).
-        **/
-
-       @Test
-       public void testWeightedUnaryMMExpNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(14, true);   //pattern: W * 
exp(U%*%V)
-       }
-
-       @Test
-       public void testWeightedUnaryMMAbsNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(15, true);   //pattern: W * 
abs(U%*%V)
-       }
-
-       @Test
-       public void testWeightedUnaryMMSinNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(16, true);   //pattern: W * 
sin(U%*%V)
-       }
-
-       @Test
-       public void testWeightedUnaryMMScalarRightNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(17, true);   //pattern: 
(W*(U%*%V))*2
-       }
-
-       @Test
-       public void testWeightedUnaryMMScalarLeftNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(18, true);   //pattern: 
2*(W*(U%*%V))
-       }
-
-       @Test
-       public void testWeightedUnaryMMAddLeftNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(19, true);   //pattern: W * 
(c + U%*%V)
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusLeftNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(20, true);   //pattern: W * 
(c - U%*%V)
-       }
-
-       @Test
-       public void testWeightedUnaryMMMultLeftNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(21, true);   //pattern: W * 
(c * (U%*%V))
-       }
-
-       @Test
-       public void testWeightedUnaryMMDivLeftNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(22, true);   //pattern: W * 
(c / (U%*%V))
-       }
-
-       @Test
-       public void testWeightedUnaryMMAddRightNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(23, true);   //pattern: W * 
(U%*%V + c)
-       }
-
-       @Test
-       public void testWeightedUnaryMMMinusRightNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(24, true);   //pattern: W * 
(U%*%V - c)
-       }
-
-       @Test
-       public void testWeightedUnaryMMMultRightNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(25, true);   //pattern: W * 
((U%*%V) * c)
-       }
-
-       @Test
-       public void testWeightedUnaryMMDivRightNoTranspose(){
-               testRewriteSimplifyWeightedUnaryMM(26, true);   //pattern: W * 
((U%*%V) / c)
-       }
-
-
 
        private void testRewriteSimplifyWeightedUnaryMM(int ID, boolean 
rewrites) {
                boolean oldFlag1 = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -280,11 +147,13 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
                        OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
+                       Recompiler.reinitRecompiler();
 
                        //create matrices
-                       double[][] U = getRandomMatrix(rows, cols, -1, 1, 
0.80d, 3);
-                       double[][] V = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 4);
-                       double[][] W = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 5);
+                       int rank = 50;
+                       double[][] U = getRandomMatrix(rows, rank, -1, 1, 
0.80d, 3);
+                       double[][] V = getRandomMatrix(cols, rank, -1, 1, 
0.70d, 4);
+                       double[][] W = getRandomMatrix(rows, cols, -1, 1, 
0.01d, 5);
                        writeInputMatrixWithMTD("U", U, true);
                        writeInputMatrixWithMTD("V", V, true);
                        writeInputMatrixWithMTD("W", W, true);
@@ -293,15 +162,10 @@ public class RewriteSimplifyWeightedUnaryMMTest extends AutomatedTestBase {
                        runRScript(true);
 
                        //compare matrices
-// FIXME
-//                     HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
-//                     HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
-//                     TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
-//                     if(rewrites)
-//                             
Assert.assertTrue(heavyHittersContainsString("wumm"));
-//                     else
-//                             
Assert.assertFalse(heavyHittersContainsString("wumm"));
-
+                       HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, 1e-8, 
"Stat-DML", "Stat-R");
+                       
Assert.assertTrue(heavyHittersContainsString("wumm")==rewrites);
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
oldFlag1;
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
index 300d2d11ea..bda9da8d06 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyWeightedUnaryMM.dml
@@ -28,83 +28,27 @@ c = 4.0
 
 # Perform operations
 if(type == 1){
-    R = W * exp(U%*%t(V))
+  R = W * exp(U%*%t(V))
 }
 else if(type == 2){
-    R = W * abs(U%*%t(V))
+  R = W * abs(U%*%t(V))
 }
 else if(type == 3){
-    R = W * sin(U%*%t(V))
+  R = W * sin(U%*%t(V))
 }
 else if(type == 4){
-    R = (W*(U%*%t(V)))*2
+  R = (W*(U%*%t(V)))*2
 }
 else if(type == 5){
-    R = 2*(W*(U%*%t(V)))
-}
-else if(type == 6){
-    R = W * (c + U%*%t(V))
-}
-else if(type == 7){
-    R = W * (c - U%*%t(V))
+  R = 2*(W*(U%*%t(V)))
 }
 else if(type == 8){
-    R = W * (c * (U%*%t(V)))
-}
-else if(type == 9){
-    R = W * (c / (U%*%t(V)))
-}
-else if(type == 10){
-    R = W * (U%*%t(V) + c)
-}
-else if(type == 11){
-    R = W * (U%*%t(V) - c)
+  R = W * (c * (U%*%t(V)))
 }
 else if(type == 12){
-    R = W * ((U%*%t(V)) * c)
-}
-else if(type == 13){
-    R = W * ((U%*%t(V)) / c)
-}
-else if(type == 14){
-    R = W * exp(U%*%V)
-}
-else if(type == 15){
-    R = W * abs(U%*%V)
-}
-else if(type == 16){
-    R = W * sin(U%*%V)
-}
-else if(type == 17){
-    R = (W*(U%*%V))*2
-}
-else if(type == 18){
-    R = 2*(W*(U%*%V))
-}
-else if(type == 19){
-    R = W * (c + U%*%V)
-}
-else if(type == 20){
-    R = W * (c - U%*%V)
-}
-else if(type == 21){
-    R = W * (c * (U%*%V))
-}
-else if(type == 22){
-    R = W * (c / (U%*%V))
-}
-else if(type == 23){
-    R = W * (U%*%V + c)
-}
-else if(type == 24){
-    R = W * (U%*%V - c)
-}
-else if(type == 25){
-    R = W * ((U%*%V) * c)
-}
-else if(type == 26){
-    R = W * ((U%*%V) / c)
+  R = W * ((U%*%t(V)) * c)
 }
 
 # Write the result matrix R
 write(R, $5)
+

Reply via email to