This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9e649c8254 [SYSTEMDS-3884] Additional rewrites subtraction and addition
9e649c8254 is described below
commit 9e649c8254b2d20ff8bcd8f66c76e0aeff47e1d0
Author: aarna <[email protected]>
AuthorDate: Tue May 13 08:00:02 2025 +0200
[SYSTEMDS-3884] Additional rewrites subtraction and addition
-(B-A)->A-B
t(A+1)+2 -> t(A)+1+2 -> t(A)+3
Closes #2258.
---
.../RewriteAlgebraicSimplificationStatic.java | 103 ++++++++++++++++++++-
.../RewriteSimplifyNegatedSubtractionTest.java | 90 ++++++++++++++++++
.../RewriteSimplifyTransposeAdditionTest.java | 93 +++++++++++++++++++
.../functions/rewrite/RewriteNegatedSubtraction.R | 31 +++++++
.../rewrite/RewriteNegatedSubtraction.dml | 27 ++++++
.../rewrite/RewriteSimplifyTransposeAddition.R | 30 ++++++
.../rewrite/RewriteSimplifyTransposeAddition.dml | 26 ++++++
7 files changed, 399 insertions(+), 1 deletion(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f59d334d17..b8bf05184a 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -199,7 +199,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyBinaryComparisonChain(hop, hi, i);
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 ->
outer(v1,v2,"!="),
hi = simplifyCumsumColOrFullAggregates(hi);
//e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
hi = simplifyCumsumReverse(hop, hi, i);
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
-
+ hi = simplifyNegatedSubtraction(hop, hi, i);
//e.g., -(B-A)->A-B
+ hi = simplifyTransposeAddition(hop, hi, i);
//e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3
hi = simplifyNotOverComparisons(hop, hi, i);
//e.g., !(A>B) -> (A<=B)
//hi = removeUnecessaryPPred(hop, hi, i);
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
@@ -211,6 +212,106 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hop.setVisited();
}
+	/**
+	 * Folds the two scalar constants of an addition over a transposed addition,
+	 * i.e., {@code t(X+c1)+c2 -> t(X)+(c1+c2)}, e.g., {@code t(A+1)+2 -> t(A)+3}.
+	 * Each literal may appear on either side of its addition, and both additions
+	 * must produce matrices.
+	 *
+	 * The rewrite only applies if the transpose is consumed solely by the outer
+	 * addition. NOTE(review): folding c1+c2 may round differently than the two
+	 * separate additions of the original expression.
+	 *
+	 * @param parent parent hop of {@code hi}
+	 * @param hi     candidate hop, i.e., input {@code pos} of {@code parent}
+	 * @param pos    position of {@code hi} in the inputs of {@code parent}
+	 * @return the new {@code t(X)+(c1+c2)} hop if the rewrite was applied, otherwise {@code hi}
+	 */
+	private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) {
+		//match outer matrix addition
+		if( !(hi instanceof BinaryOp)
+			|| ((BinaryOp)hi).getOp() != OpOp2.PLUS
+			|| hi.getDataType() != DataType.MATRIX )
+			return hi;
+		BinaryOp bop = (BinaryOp)hi;
+
+		//match t(...)+c2 or c2+t(...)
+		ReorgOp tSide = null;
+		LiteralOp litSide = null;
+		Hop in0 = bop.getInput().get(0), in1 = bop.getInput().get(1);
+		if( in0 instanceof ReorgOp && ((ReorgOp)in0).getOp() == ReOrgOp.TRANS
+			&& in1 instanceof LiteralOp ) {
+			tSide = (ReorgOp)in0;
+			litSide = (LiteralOp)in1;
+		}
+		else if( in1 instanceof ReorgOp && ((ReorgOp)in1).getOp() == ReOrgOp.TRANS
+			&& in0 instanceof LiteralOp ) {
+			tSide = (ReorgOp)in1;
+			litSide = (LiteralOp)in0;
+		}
+		else
+			return hi;
+
+		//the transpose must only be consumed by the outer addition
+		if( tSide.getParent().size() > 1 )
+			return hi;
+
+		//match inner matrix addition X+c1 or c1+X
+		Hop inner = tSide.getInput().get(0);
+		if( !(inner instanceof BinaryOp)
+			|| ((BinaryOp)inner).getOp() != OpOp2.PLUS
+			|| inner.getDataType() != DataType.MATRIX )
+			return hi;
+		BinaryOp ib = (BinaryOp)inner;
+
+		Hop X = null;
+		LiteralOp lit1 = null;
+		Hop i0 = ib.getInput().get(0), i1 = ib.getInput().get(1);
+		if( i0 instanceof LiteralOp ) {
+			lit1 = (LiteralOp)i0;
+			X = i1;
+		}
+		else if( i1 instanceof LiteralOp ) {
+			lit1 = (LiteralOp)i1;
+			X = i0;
+		}
+		else
+			return hi;
+
+		//fold both constants into a single scalar literal (a scalar, so no
+		//matrix dimensions are attached to it)
+		LiteralOp newLit = new LiteralOp(lit1.getDoubleValue() + litSide.getDoubleValue());
+
+		//construct t(X)+(c1+c2) with the sizes of the original hops
+		ReorgOp newT = HopRewriteUtils.createTranspose(X);
+		newT.setDim1(tSide.getDim1());
+		newT.setDim2(tSide.getDim2());
+		BinaryOp newPlus = HopRewriteUtils.createBinary(newT, newLit, OpOp2.PLUS);
+		newPlus.setDim1(bop.getDim1());
+		newPlus.setDim2(bop.getDim2());
+
+		//attribute the new hops to the source line of the rewritten expression
+		HopRewriteUtils.copyLineNumbers(bop, newT);
+		HopRewriteUtils.copyLineNumbers(bop, newPlus);
+
+		//rewire parent and detach the replaced sub-DAG (outer-to-inner order, so
+		//each hop is already unreferenced when it is cleaned up)
+		HopRewriteUtils.replaceChildReference(parent, bop, newPlus, pos);
+		HopRewriteUtils.cleanupUnreferenced(bop, tSide, ib, litSide);
+
+		LOG.debug("Applied simplifyTransposeAddition (line " + hi.getBeginLine() + ").");
+
+		return newPlus;
+	}
+
+	/**
+	 * Rewrites a negated subtraction {@code 0-(B-A)} into {@code A-B}, which
+	 * replaces two subtractions by one.
+	 *
+	 * The rewrite only applies if both the outer and the inner subtraction have
+	 * a single consumer. Otherwise the inner {@code B-A} would still have to be
+	 * computed for its other consumers.
+	 *
+	 * @param parent parent hop of {@code hi}
+	 * @param hi     candidate hop, i.e., input {@code pos} of {@code parent}
+	 * @param pos    position of {@code hi} in the inputs of {@code parent}
+	 * @return the new {@code A-B} hop if the rewrite was applied, otherwise {@code hi}
+	 */
+	private static Hop simplifyNegatedSubtraction(Hop parent, Hop hi, int pos) {
+		if( hi instanceof BinaryOp
+			&& ((BinaryOp) hi).getOp() == OpOp2.MINUS
+			&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 0)
+			&& hi.getParent().size() == 1
+			&& hi.getInput().get(1) instanceof BinaryOp
+			&& ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MINUS
+			&& hi.getInput().get(1).getParent().size() == 1 )
+		{
+			Hop innerMinus = hi.getInput().get(1);
+			Hop B = innerMinus.getInput().get(0);
+			Hop A = innerMinus.getInput().get(1);
+
+			BinaryOp newHop = HopRewriteUtils.createBinary(A, B, OpOp2.MINUS);
+
+			HopRewriteUtils.copyLineNumbers(hi, newHop);
+			HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos);
+			//clean up both subtractions (outer first): cleaning up only hi would
+			//leave the detached inner B-A registered as a parent of A and B
+			HopRewriteUtils.cleanupUnreferenced(hi, innerMinus);
+			hi = newHop;
+
+			LOG.debug("Applied simplifyNegatedSubtraction (line " + hi.getBeginLine() + ").");
+		}
+		return hi;
+	}
+
+
private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
{
//applies to all binary matrix operations, if one input is
unnecessarily vectorized
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
new file mode 100644
index 0000000000..da5876e343
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Test;
+import org.junit.Assert;
+import java.util.HashMap;
+
+public class RewriteSimplifyNegatedSubtractionTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "RewriteNegatedSubtraction";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyNegatedSubtractionTest.class.getSimpleName() + "/";
+	private static final int rows = 100;
+	private static final int cols = 100;
+	//tolerance for comparing DML against R outputs
+	private static final double eps = 1e-10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+	}
+
+	@Test
+	public void testRewriteEnabled() {
+		runRewriteTest(true);
+	}
+
+	@Test
+	public void testRewriteDisabled() {
+		runRewriteTest(false);
+	}
+
+	/**
+	 * Runs the DML script computing 0 - (B - A) and compares its output against R.
+	 *
+	 * @param rewriteEnabled whether algebraic simplification is enabled; if so,
+	 *                       the executed plan must contain exactly one "-" instruction
+	 */
+	private void runRewriteTest(boolean rewriteEnabled) {
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		try {
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			programArgs = new String[]{"-stats", "-args", input("A"), input("B"), output("R")};
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
+
+			// Generate input matrices
+			double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3);
+			double[][] B = getRandomMatrix(rows, cols, -10, 10, 0.7, 7);
+			writeInputMatrixWithMTD("A", A, true);
+			writeInputMatrixWithMTD("B", B, true);
+
+			// Run DML and R scripts
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// Compare with a tolerance instead of exact double equality, which is
+			// fragile across serialization and floating-point rounding
+			HashMap<MatrixValue.CellIndex, Double> dml = readDMLMatrixFromOutputDir("R");
+			HashMap<MatrixValue.CellIndex, Double> r = readRMatrixFromExpectedDir("R");
+			TestUtils.compareMatrices(dml, r, eps, "Stat-DML", "Stat-R");
+
+			// With the rewrite, 0-(B-A) compiles to the single subtraction A-B
+			if( rewriteEnabled )
+				Assert.assertEquals("rewrite -(B-A) -> A-B not applied",
+					1, Statistics.getCPHeavyHitterCount("-"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
new file mode 100644
index 0000000000..9247e07d4f
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteSimplifyTransposeAdditionTest extends AutomatedTestBase {
+	private static final String TEST_NAME = "RewriteSimplifyTransposeAddition";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTransposeAdditionTest.class.getSimpleName() + "/";
+
+	private static final int rows = 100;
+	private static final int cols = 100;
+	//tolerance for comparing DML against R outputs; folding the constants of
+	//t(A+1)+2 into t(A)+3 may change rounding in the last bits
+	private static final double eps = 1e-10;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME,
+			new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+	}
+
+	@Test
+	public void testRewriteEnabled() {
+		runRewriteTest(true);
+	}
+
+	@Test
+	public void testRewriteDisabled() {
+		runRewriteTest(false);
+	}
+
+	/**
+	 * Runs the DML script computing t(A+1)+2 and compares its output against R.
+	 *
+	 * @param rewriteEnabled whether algebraic simplification is enabled; if so,
+	 *                       the executed plan must contain exactly one "+" instruction
+	 */
+	private void runRewriteTest(boolean rewriteEnabled) {
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		try {
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			fullRScriptName = HOME + TEST_NAME + ".R";
+
+			// DML script parameters
+			programArgs = new String[]{"-stats", "-args", input("A"), output("R")};
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			// Set optimizer flags
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
+
+			// Generate input matrix
+			double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3);
+			writeInputMatrixWithMTD("A", A, true);
+
+			// Run DML and R scripts
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// Compare output matrices with a tolerance instead of exact double equality
+			HashMap<CellIndex, Double> dml = readDMLMatrixFromOutputDir("R");
+			HashMap<CellIndex, Double> r = readRMatrixFromExpectedDir("R");
+			TestUtils.compareMatrices(dml, r, eps, "Stat-DML", "Stat-R");
+
+			// With the rewrite, t(A+1)+2 compiles to t(A)+3 with a single addition
+			if( rewriteEnabled )
+				Assert.assertEquals("rewrite t(A+1)+2 -> t(A)+3 not applied",
+					1, Statistics.getCPHeavyHitterCount("+"));
+		}
+		finally {
+			// Reset optimizer flags
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R
b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R
new file mode 100644
index 0000000000..26492f9ec8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+library("Matrix")
+
+args <- commandArgs(TRUE)
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B <- as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+
+# Reference result of the unrewritten DML expression 0 - (B - A). Computing it
+# as written (rather than the rewritten A - B) keeps this output an independent
+# oracle for the rewrite under test.
+R <- 0 - (B - A)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
+
diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml
b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml
new file mode 100644
index 0000000000..a25e40f2db
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test script for the static rewrite -(B-A) -> A-B.
+# Arguments: $1 input matrix A, $2 input matrix B, $3 output path.
+# The expression below is kept verbatim because its shape is the pattern the
+# rewrite must match; with the rewrite enabled, a single "-" is expected.
+A = read($1);
+B = read($2);
+
+# Expression that will be rewritten to A - B
+R = 0 - (B - A);
+
+write(R, $3);
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R
new file mode 100644
index 0000000000..6bc82690aa
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+# Reference result of the unrewritten DML expression t(A+1)+2. Computing it as
+# written (rather than the rewritten t(A)+3) keeps this output an independent
+# oracle for the rewrite under test.
+R <- t(A+1)+2
+
+# Write the result matrix
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
\ No newline at end of file
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml
new file mode 100644
index 0000000000..d27a471238
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test script for the static rewrite t(A+1)+2 -> t(A)+3.
+# Arguments: $1 input matrix A, $2 output path.
+# The expression below is kept verbatim because its shape is the pattern the
+# rewrite must match; with the rewrite enabled, a single "+" is expected.
+A = read($1);
+
+# Compute t(A+1)+2 which should be rewritten to t(A)+3
+result = t(A+1)+2;
+
+write(result, $2);
\ No newline at end of file