This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c36a8369e5 [SYSTEMDS-3666] New simplification rewrite 
not-over-comparisons
c36a8369e5 is described below

commit c36a8369e5f394a362ac69eb90e3ab62b50c8db9
Author: ReneEnjilian <[email protected]>
AuthorDate: Sun Mar 17 18:37:21 2024 +0100

    [SYSTEMDS-3666] New simplification rewrite not-over-comparisons
    
    Closes #1988.
---
 .../RewriteAlgebraicSimplificationStatic.java      | 36 ++++++++++
 .../rewrite/RewriteDistributiveMatrixMultTest.java |  2 +-
 ...est.java => RewriteNotOverComparisonsTest.java} | 78 ++++++++++++----------
 .../functions/rewrite/RewriteNotOverComparisons.R  | 45 +++++++++++++
 .../rewrite/RewriteNotOverComparisons.dml          | 35 ++++++++++
 5 files changed, 160 insertions(+), 36 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index b74e93d6cf..a867735e50 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -190,6 +190,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyCumsumColOrFullAggregates(hi);          
//e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
                        hi = simplifyCumsumReverse(hop, hi, i);              
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
 
+                       hi = simplifyNotOverComparisons(hop, hi, i);         
//e.g., !(A>B) -> (A<=B)
+                       
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
@@ -1980,6 +1982,61 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                return hi;
        }
        
+       /**
+        * Replaces a logical NOT over a comparison by the complementary
+        * comparison, e.g., !(A>B) -> A<=B, !(A<B) -> A>=B, !(A==B) -> A!=B,
+        * and !(A!=B) -> A==B. The rewrite is only applied if the NOT is the
+        * only consumer of the comparison, and if a parent is given.
+        *
+        * NOTE: the (in)equality cases are exact negations and hence NaN-safe,
+        * but the ordering cases are not under IEEE 754 comparison semantics,
+        * e.g., !(NaN>B) yields 1 whereas (NaN<=B) yields 0.
+        *
+        * @param parent parent of hi, or null (then no rewrite is applied)
+        * @param hi     current hop, potentially a NOT over a comparison
+        * @param pos    position of hi in the inputs of parent
+        * @return the new comparison if the rewrite was applied, otherwise hi
+        */
+       private static Hop simplifyNotOverComparisons(Hop parent, Hop hi, int pos){
+               if(HopRewriteUtils.isUnary(hi, OpOp1.NOT) && hi.getInput(0) instanceof BinaryOp
+                       && hi.getInput(0).getParent().size() == 1) //NOT is only consumer
+               {
+                       Hop binaryOperator = hi.getInput(0);
+                       Hop A = binaryOperator.getInput(0);
+                       Hop B = binaryOperator.getInput(1);
+                       Hop newHop = null;
+
+                       // !(A>B) -> A<=B (not NaN-safe, see above)
+                       if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.GREATER)) {
+                               newHop = HopRewriteUtils.createBinary(A, B, OpOp2.LESSEQUAL);
+                       }
+                       // !(A<B) -> A>=B (not NaN-safe, see above)
+                       else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.LESS)) {
+                               newHop = HopRewriteUtils.createBinary(A, B, OpOp2.GREATEREQUAL);
+                       }
+                       // !(A==B) -> A!=B, including !(A==0) -> A!=0
+                       else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.EQUAL)) {
+                               newHop = HopRewriteUtils.createBinary(A, B, OpOp2.NOTEQUAL);
+                       }
+                       // !(A!=B) -> A==B (exact negation, hence NaN-safe)
+                       else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.NOTEQUAL)) {
+                               newHop = HopRewriteUtils.createBinary(A, B, OpOp2.EQUAL);
+                       }
+                       //TODO add remaining cases of comparison operators (>=, <=)
+
+                       if(parent != null && newHop != null) {
+                               HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
+                               //remove the NOT and then the comparison (order matters), so the
+                               //dead comparison is also detached from its inputs A and B
+                               HopRewriteUtils.cleanupUnreferenced(hi, binaryOperator);
+                               hi = newHop;
+                               LOG.debug("Applied simplifyNotOverComparisons (line " + hi.getBeginLine() + ")");
+                       }
+               }
+
+               return hi;
+       }
+       
        /**
         * NOTE: currently disabled since this rewrite is INVALID in the
         * presence of NaNs (because (NaN!=NaN) is true). 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
index 7f40a2bef3..2721afc01f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
@@ -34,7 +34,7 @@ public class RewriteDistributiveMatrixMultTest extends 
AutomatedTestBase {
        private static final String TEST_NAME1 = 
"RewriteDistributiveMatrixMult";
        private static final String TEST_DIR = "functions/rewrite/";
        private static final String TEST_CLASS_DIR =
-               TEST_DIR + 
RewriteSimplifyRowColSumMVMultTest.class.getSimpleName() + "/";
+               TEST_DIR + 
RewriteDistributiveMatrixMultTest.class.getSimpleName() + "/";
 
        private static final int rows = 500;
        private static final int cols = 500;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
similarity index 56%
copy from 
src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
index 7f40a2bef3..f0c167236a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
@@ -30,55 +30,73 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
 
-public class RewriteDistributiveMatrixMultTest extends AutomatedTestBase {
-       private static final String TEST_NAME1 = 
"RewriteDistributiveMatrixMult";
+public class RewriteNotOverComparisonsTest extends AutomatedTestBase {
+
+       private static final String TEST_NAME = "RewriteNotOverComparisons";
        private static final String TEST_DIR = "functions/rewrite/";
-       private static final String TEST_CLASS_DIR =
-               TEST_DIR + 
RewriteSimplifyRowColSumMVMultTest.class.getSimpleName() + "/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteNotOverComparisonsTest.class.getSimpleName() + "/";
 
-       private static final int rows = 500;
-       private static final int cols = 500;
+       private static final int rows = 10;
+       private static final int cols = 10;
        private static final double eps = Math.pow(10, -10);
 
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+       }
+
+       @Test
+       public void testNotOverComparisonsGreaterNoRewrite() {
+               testRewriteNotOverComparisons(1, false);
+       }
 
+       @Test
+       public void testNotOverComparisonsGreaterRewrite() {
+               testRewriteNotOverComparisons(1, true);
+       }
+
+       @Test
+       public void testNotOverComparisonsLessNoRewrite() {
+               testRewriteNotOverComparisons(2, false);
        }
 
        @Test
-       public void testDistributiveMatrixMultNoRewrite() {
-               testRewriteDistributiveMatrixMult(TEST_NAME1, false);
+       public void testNotOverComparisonsLessRewrite() {
+               testRewriteNotOverComparisons(2, true);
        }
 
        @Test
-       public void testDistributiveMatrixMultRewrite() {
-               testRewriteDistributiveMatrixMult(TEST_NAME1, true);
+       public void testNotOverComparisonsEqualNoRewrite() {
+               testRewriteNotOverComparisons(3, false);
        }
 
-       private void testRewriteDistributiveMatrixMult(String testname, boolean 
rewrites) {
+       @Test
+       public void testNotOverComparisonsEqualRewrite() {
+               testRewriteNotOverComparisons(3, true);
+       }
+
+       private void testRewriteNotOverComparisons(int ID, boolean rewrites) {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                try {
-                       TestConfiguration config = 
getTestConfiguration(testname);
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
                        loadTestConfiguration(config);
 
                        String HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[] {"-stats", "-args", 
input("A"), input("B"), input("C"), output("R")};
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "-args",
+                               input("A"), input("B"), String.valueOf(ID), 
output("R")};
 
-                       fullRScriptName = HOME + testname + ".R";
-                       rCmd = getRCmd(inputDir(), expectedDir());
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = getRCmd(inputDir(), String.valueOf(ID), 
expectedDir());
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
-                       //create dense matrices so that rewrites are possible
+
                        double[][] A = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 7);
+                       writeInputMatrixWithMTD("A", A, 65, true);
                        double[][] B = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 6);
-                       double[][] C = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 3);
-                       writeInputMatrixWithMTD("A", A, 174522, true);
-                       writeInputMatrixWithMTD("B", B, 174935, true);
-                       writeInputMatrixWithMTD("C", C, 174848, true);
-
+                       writeInputMatrixWithMTD("B", B, 74, true);
+                       
                        runTest(true, false, null, -1);
                        runRScript(true);
 
@@ -87,21 +105,11 @@ public class RewriteDistributiveMatrixMultTest extends 
AutomatedTestBase {
                        HashMap<CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
 
-                       //check matrix mult existence
-                       String ba = "ba+*";
-                       long numMatMul = Statistics.getCPHeavyHitterCount(ba);
-
-                       if(rewrites == true) {
-                               Assert.assertTrue(numMatMul == 1);
-                       }
-                       else {
-                               Assert.assertTrue(numMatMul == 2);
-                       }
-
+                       long count = Statistics.getCPHeavyHitterCount("!");
+                       Assert.assertTrue(count == (rewrites ? 0 : 1));
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
                }
-
        }
 }
diff --git a/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R 
b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R
new file mode 100644
index 0000000000..a8e138734e
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required library (readMM, writeMM)
+library("Matrix")
+
+# Read matrices A and B (Matrix Market format) from args[1], the pattern ID
+# (1: !(A>B), 2: !(A<B), 3: !(A==0)) from args[2]; args[3] is the output dir
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+type = as.integer(args[2])
+
+if( type == 1 ) {
+  R = !(A > B)
+} else if( type == 2 ) {
+  R = !(A < B)
+} else if( type == 3 ) {
+  R = !(A == 0)
+} else stop("unsupported pattern ID: ", type)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml 
b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml
new file mode 100644
index 0000000000..c09f7c11e5
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Args: $1, $2 paths of the input matrices A and B; $3 pattern ID
+# (1: !(A>B), 2: !(A<B), 3: !(A==0)); $4 output path of R
+# Load input matrices A and B
+A = read($1)
+B = read($2)
+
+if( $3 == 1 )
+  R = !(A > B)
+else if( $3 == 2 )
+  R = !(A < B)
+else if( $3 == 3 )
+  R = !(A == 0)
+
+write(R, $4)

Reply via email to