This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c36a8369e5 [SYSTEMDS-3666] New simplification rewrite
not-over-comparisons
c36a8369e5 is described below
commit c36a8369e5f394a362ac69eb90e3ab62b50c8db9
Author: ReneEnjilian <[email protected]>
AuthorDate: Sun Mar 17 18:37:21 2024 +0100
[SYSTEMDS-3666] New simplification rewrite not-over-comparisons
Closes #1988.
---
.../RewriteAlgebraicSimplificationStatic.java | 36 ++++++++++
.../rewrite/RewriteDistributiveMatrixMultTest.java | 2 +-
...est.java => RewriteNotOverComparisonsTest.java} | 78 ++++++++++++----------
.../functions/rewrite/RewriteNotOverComparisons.R | 45 +++++++++++++
.../rewrite/RewriteNotOverComparisons.dml | 35 ++++++++++
5 files changed, 160 insertions(+), 36 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index b74e93d6cf..a867735e50 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -190,6 +190,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyCumsumColOrFullAggregates(hi);
//e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
hi = simplifyCumsumReverse(hop, hi, i);
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
+ hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
+
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process childs recursively after rewrites (to
investigate pattern newly created by rewrites)
@@ -1980,6 +1982,40 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
+	/**
+	 * Simplifies a logical NOT over a single comparison into the complementary
+	 * comparison, e.g., !(A>B) -> A<=B, !(A<B) -> A>=B, !(A==B) -> A!=B, and
+	 * !(A!=B) -> A==B.
+	 *
+	 * The rewrite only applies if the NOT is the sole consumer of the comparison,
+	 * so the original comparison does not have to be kept alive for other parents.
+	 *
+	 * NOTE: under IEEE-754 comparison semantics, !(A>B) -> A<=B and !(A<B) -> A>=B
+	 * are not exact in the presence of NaNs (e.g., !(NaN>1) yields 1, but NaN<=1
+	 * yields 0). The == / != cases are exact complements and NaN-safe.
+	 *
+	 * @param parent parent of hi; if null, no rewrite is applied
+	 * @param hi     candidate NOT hop
+	 * @param pos    position of hi in the inputs of parent
+	 * @return the replacement comparison hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop simplifyNotOverComparisons(Hop parent, Hop hi, int pos){
+		if(HopRewriteUtils.isUnary(hi, OpOp1.NOT) && hi.getInput(0) instanceof BinaryOp
+			&& hi.getInput(0).getParent().size() == 1) //NOT is only consumer
+		{
+			Hop binaryOperator = hi.getInput(0);
+			Hop A = binaryOperator.getInput(0);
+			Hop B = binaryOperator.getInput(1);
+			Hop newHop = null;
+
+			// !(A>B) -> A<=B (not exact for NaN inputs, see method doc)
+			if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.GREATER)) {
+				newHop = HopRewriteUtils.createBinary(A, B, OpOp2.LESSEQUAL);
+			}
+			// !(A<B) -> A>=B (not exact for NaN inputs, see method doc)
+			else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.LESS)) {
+				newHop = HopRewriteUtils.createBinary(A, B, OpOp2.GREATEREQUAL);
+			}
+			// !(A==B) -> A!=B, including !(A==0) -> A!=0
+			else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.EQUAL)) {
+				newHop = HopRewriteUtils.createBinary(A, B, OpOp2.NOTEQUAL);
+			}
+			// !(A!=B) -> A==B (exact complement of the case above, NaN-safe)
+			else if(HopRewriteUtils.isBinary(binaryOperator, OpOp2.NOTEQUAL)) {
+				newHop = HopRewriteUtils.createBinary(A, B, OpOp2.EQUAL);
+			}
+			//TODO: !(A>=B) -> A<B and !(A<=B) -> A>B (same NaN caveat as > and <)
+
+			if(parent != null && newHop != null) {
+				HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
+				HopRewriteUtils.cleanupUnreferenced(hi);
+				hi = newHop;
+				LOG.debug("Applied simplifyNotOverComparisons (line " + hi.getBeginLine() + ")");
+			}
+		}
+
+		return hi;
+	}
+
/**
* NOTE: currently disabled since this rewrite is INVALID in the
* presence of NaNs (because (NaN!=NaN) is true).
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
index 7f40a2bef3..2721afc01f 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
@@ -34,7 +34,7 @@ public class RewriteDistributiveMatrixMultTest extends
AutomatedTestBase {
private static final String TEST_NAME1 =
"RewriteDistributiveMatrixMult";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR =
- TEST_DIR +
RewriteSimplifyRowColSumMVMultTest.class.getSimpleName() + "/";
+ TEST_DIR +
RewriteDistributiveMatrixMultTest.class.getSimpleName() + "/";
private static final int rows = 500;
private static final int cols = 500;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
similarity index 56%
copy from
src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
copy to
src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
index 7f40a2bef3..f0c167236a 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDistributiveMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNotOverComparisonsTest.java
@@ -30,55 +30,73 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
-public class RewriteDistributiveMatrixMultTest extends AutomatedTestBase {
- private static final String TEST_NAME1 =
"RewriteDistributiveMatrixMult";
+public class RewriteNotOverComparisonsTest extends AutomatedTestBase {
+
+ private static final String TEST_NAME = "RewriteNotOverComparisons";
private static final String TEST_DIR = "functions/rewrite/";
- private static final String TEST_CLASS_DIR =
- TEST_DIR +
RewriteSimplifyRowColSumMVMultTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteNotOverComparisonsTest.class.getSimpleName() + "/";
- private static final int rows = 500;
- private static final int cols = 500;
+ private static final int rows = 10;
+ private static final int cols = 10;
private static final double eps = Math.pow(10, -10);
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void testNotOverComparisonsGreaterNoRewrite() {
+ testRewriteNotOverComparisons(1, false);
+ }
+ @Test
+ public void testNotOverComparisonsGreaterRewrite() {
+ testRewriteNotOverComparisons(1, true);
+ }
+
+ @Test
+ public void testNotOverComparisonsLessNoRewrite() {
+ testRewriteNotOverComparisons(2, false);
}
@Test
- public void testDistributiveMatrixMultNoRewrite() {
- testRewriteDistributiveMatrixMult(TEST_NAME1, false);
+ public void testNotOverComparisonsLessRewrite() {
+ testRewriteNotOverComparisons(2, true);
}
@Test
- public void testDistributiveMatrixMultRewrite() {
- testRewriteDistributiveMatrixMult(TEST_NAME1, true);
+ public void testNotOverComparisonsEqualNoRewrite() {
+ testRewriteNotOverComparisons(3, false);
}
- private void testRewriteDistributiveMatrixMult(String testname, boolean
rewrites) {
+ @Test
+ public void testNotOverComparisonsEqualRewrite() {
+ testRewriteNotOverComparisons(3, true);
+ }
+
+ private void testRewriteNotOverComparisons(int ID, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
- TestConfiguration config =
getTestConfiguration(testname);
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
loadTestConfiguration(config);
String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] {"-stats", "-args",
input("A"), input("B"), input("C"), output("R")};
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "-args",
+ input("A"), input("B"), String.valueOf(ID),
output("R")};
- fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = getRCmd(inputDir(), String.valueOf(ID),
expectedDir());
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
- //create dense matrices so that rewrites are possible
+
double[][] A = getRandomMatrix(rows, cols, -1, 1,
0.70d, 7);
+ writeInputMatrixWithMTD("A", A, 65, true);
double[][] B = getRandomMatrix(rows, cols, -1, 1,
0.70d, 6);
- double[][] C = getRandomMatrix(rows, cols, -1, 1,
0.70d, 3);
- writeInputMatrixWithMTD("A", A, 174522, true);
- writeInputMatrixWithMTD("B", B, 174935, true);
- writeInputMatrixWithMTD("C", C, 174848, true);
-
+ writeInputMatrixWithMTD("B", B, 74, true);
+
runTest(true, false, null, -1);
runRScript(true);
@@ -87,21 +105,11 @@ public class RewriteDistributiveMatrixMultTest extends
AutomatedTestBase {
HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
- //check matrix mult existence
- String ba = "ba+*";
- long numMatMul = Statistics.getCPHeavyHitterCount(ba);
-
- if(rewrites == true) {
- Assert.assertTrue(numMatMul == 1);
- }
- else {
- Assert.assertTrue(numMatMul == 2);
- }
-
+ long count = Statistics.getCPHeavyHitterCount("!");
+ Assert.assertTrue(count == (rewrites ? 0 : 1));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
-
}
}
diff --git a/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R
b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R
new file mode 100644
index 0000000000..a8e138734e
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required libraries (only Matrix is needed for readMM/writeMM)
+library("Matrix")
+
+# Read matrices A and B from Matrix Market format files, and the test case ID
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+type = as.integer(args[2])
+# Test case IDs must match those of RewriteNotOverComparisons.dml
+if( type == 1 ) {
+  R = !(A > B)
+} else if( type == 2 ) {
+  R = !(A < B)
+} else if( type == 3 ) {
+  R = !(A == 0)
+} else {
+  stop(paste("Unsupported test case ID:", type))
+}
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml
b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml
new file mode 100644
index 0000000000..c09f7c11e5
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNotOverComparisons.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+
+# Load matrices A and B; $3 selects the test case (IDs must match the R script)
+A = read($1)
+B = read($2)
+if( $3 == 1 )
+  R = !(A > B)
+else if( $3 == 2 )
+  R = !(A < B)
+else if( $3 == 3 )
+  R = !(A == 0)
+else
+  stop("Unsupported test case ID: " + $3)
+write(R, $4)