This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9e649c8254 [SYSTEMDS-3884] Additional rewrites subtraction and addition
9e649c8254 is described below

commit 9e649c8254b2d20ff8bcd8f66c76e0aeff47e1d0
Author: aarna <aarnatya...@gmail.com>
AuthorDate: Tue May 13 08:00:02 2025 +0200

    [SYSTEMDS-3884] Additional rewrites subtraction and addition
    
    -(B-A)->A-B
    t(A+1)+2 -> t(A)+1+2 -> t(A)+3
    
    Closes #2258.
---
 .../RewriteAlgebraicSimplificationStatic.java      | 103 ++++++++++++++++++++-
 .../RewriteSimplifyNegatedSubtractionTest.java     |  90 ++++++++++++++++++
 .../RewriteSimplifyTransposeAdditionTest.java      |  93 +++++++++++++++++++
 .../functions/rewrite/RewriteNegatedSubtraction.R  |  31 +++++++
 .../rewrite/RewriteNegatedSubtraction.dml          |  27 ++++++
 .../rewrite/RewriteSimplifyTransposeAddition.R     |  30 ++++++
 .../rewrite/RewriteSimplifyTransposeAddition.dml   |  26 ++++++
 7 files changed, 399 insertions(+), 1 deletion(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f59d334d17..b8bf05184a 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -199,7 +199,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyBinaryComparisonChain(hop, hi, i);      
//e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> 
outer(v1,v2,"!="),
                        hi = simplifyCumsumColOrFullAggregates(hi);          
//e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
                        hi = simplifyCumsumReverse(hop, hi, i);              
//e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
-
+                       hi = simplifyNegatedSubtraction(hop, hi, i);         
//e.g., -(B-A)->A-B
+                       hi = simplifyTransposeAddition(hop, hi, i);          
//e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3
                        hi = simplifyNotOverComparisons(hop, hi, i);         
//e.g., !(A>B) -> (A<=B)
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
@@ -211,6 +212,106 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                hop.setVisited();
        }
 
+       private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int 
pos) {
+               if (!(hi instanceof BinaryOp)
+                               || ((BinaryOp)hi).getOp() != OpOp2.PLUS
+                               || hi.getDataType() != DataType.MATRIX)
+                       return hi;
+
+               BinaryOp bop = (BinaryOp)hi;
+
+               ReorgOp tSide = null;
+               LiteralOp litSide = null;
+               Hop in0 = bop.getInput().get(0), in1 = bop.getInput().get(1);
+               if (in0 instanceof ReorgOp && ((ReorgOp)in0).getOp() == 
ReOrgOp.TRANS
+                               && in1 instanceof LiteralOp) {
+                       tSide = (ReorgOp)in0;
+                       litSide = (LiteralOp)in1;
+               }
+               else if (in1 instanceof ReorgOp && ((ReorgOp)in1).getOp() == 
ReOrgOp.TRANS
+                               && in0 instanceof LiteralOp) {
+                       tSide = (ReorgOp)in1;
+                       litSide = (LiteralOp)in0;
+               }
+               else
+                       return hi;
+
+               //check if only consumer
+               if (tSide.getParent().size() > 1) {
+                       return hi;
+               }
+
+               Hop inner = tSide.getInput().get(0);
+               if (!(inner instanceof BinaryOp)
+                               || ((BinaryOp)inner).getOp() != OpOp2.PLUS
+                               || inner.getDataType() != DataType.MATRIX)
+                       return hi;
+
+               BinaryOp ib = (BinaryOp)inner;
+
+               Hop X = null;
+               LiteralOp lit1 = null;
+               Hop i0 = ib.getInput().get(0), i1 = ib.getInput().get(1);
+               if (i0 instanceof LiteralOp) {
+                       lit1 = (LiteralOp)i0;
+                       X = i1;
+               }
+               else if (i1 instanceof LiteralOp) {
+                       lit1 = (LiteralOp)i1;
+                       X = i0;
+               }
+               else
+                       return hi;
+
+               double c = lit1.getDoubleValue() + litSide.getDoubleValue();
+
+               ReorgOp newT = HopRewriteUtils.createTranspose(X);
+               newT.setDim1(tSide.getDim1());
+               newT.setDim2(tSide.getDim2());
+
+               LiteralOp newLit = new LiteralOp(c);
+               newLit.setDim1(1);
+               newLit.setDim2(1);
+
+               //creating new binaryOp
+               BinaryOp newPlus = HopRewriteUtils.createBinary(newT, newLit, 
OpOp2.PLUS);
+               newPlus.setDim1(bop.getDim1());
+               newPlus.setDim2(bop.getDim2());
+
+               HopRewriteUtils.replaceChildReference(parent, bop, newPlus, 
pos);
+               HopRewriteUtils.cleanupUnreferenced(bop, tSide, ib, litSide);
+
+               LOG.debug("Applied simplifyTransposeAddition (line " + 
hi.getBeginLine() + ").");
+
+               return newPlus;
+       }
+
+       private static Hop simplifyNegatedSubtraction(Hop parent, Hop hi, int 
pos) {
+               if (hi instanceof BinaryOp
+                               && ((BinaryOp) hi).getOp() == OpOp2.MINUS
+                               && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 0)
+                               && hi.getParent().size() == 1
+                               && hi.getInput().get(1) instanceof BinaryOp
+                               && ((BinaryOp) hi.getInput().get(1)).getOp() == 
OpOp2.MINUS
+                               && hi.getInput().get(1).getParent().size() == 1)
+               {
+                       Hop innerMinus = hi.getInput().get(1);
+                       Hop B = innerMinus.getInput().get(0);
+                       Hop A = innerMinus.getInput().get(1);
+
+                       BinaryOp newHop = HopRewriteUtils.createBinary(A, B, 
OpOp2.MINUS);
+
+                       HopRewriteUtils.copyLineNumbers(hi, newHop);
+                       HopRewriteUtils.replaceChildReference(parent, hi, 
newHop, pos);
+                       HopRewriteUtils.cleanupUnreferenced(hi);
+                       hi = newHop;
+
+                       LOG.debug("Applied simplifyNegatedSubtraction (line " + 
hi.getBeginLine() + ").");
+               }
+               return hi;
+       }
+
+
        private static Hop removeUnnecessaryVectorizeOperation(Hop hi)
        {
                //applies to all binary matrix operations, if one input is 
unnecessarily vectorized
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
new file mode 100644
index 0000000000..da5876e343
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Test;
+import org.junit.Assert;
+import java.util.HashMap;
+
+public class RewriteSimplifyNegatedSubtractionTest extends AutomatedTestBase {
+       private static final String TEST_NAME = "RewriteNegatedSubtraction";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSimplifyNegatedSubtractionTest.class.getSimpleName() + "/";
+       private static final int rows = 100;
+       private static final int cols = 100;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME,
+                               new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME, new String[]{"R"}));
+       }
+
+       @Test
+       public void testRewriteEnabled() {
+               runRewriteTest(true);
+       }
+
+       @Test
+       public void testRewriteDisabled() {
+               runRewriteTest(false);
+       }
+
+       private void runRewriteTest(boolean rewriteEnabled) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       programArgs = new String[]{"-stats", "-args", 
input("A"), input("B"), output("R")};
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewriteEnabled;
+
+                       // Generate input matrices
+                       double[][] A = getRandomMatrix(rows, cols, -10, 10, 
0.7, 3);
+                       double[][] B = getRandomMatrix(rows, cols, -10, 10, 
0.7, 7);
+                       writeInputMatrixWithMTD("A", A, true);
+                       writeInputMatrixWithMTD("B", B, true);
+
+                       // Run DML script
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       HashMap<MatrixValue.CellIndex, Double> dml = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> r = 
readRMatrixFromExpectedDir("R");
+
+                       Assert.assertEquals("DML and R outputs do not match", 
r, dml);
+                       if( rewriteEnabled )
+                               Assert.assertEquals(1, 
Statistics.getCPHeavyHitterCount("-"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
new file mode 100644
index 0000000000..9247e07d4f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteSimplifyTransposeAdditionTest extends AutomatedTestBase {
+       private static final String TEST_NAME = 
"RewriteSimplifyTransposeAddition";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSimplifyTransposeAdditionTest.class.getSimpleName() + "/";
+
+       private static final int rows = 100;
+       private static final int cols = 100;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+       }
+
+       @Test
+       public void testRewriteEnabled() {
+               runRewriteTest(true);
+       }
+
+       @Test
+       public void testRewriteDisabled() {
+               runRewriteTest(false);
+       }
+
+       private void runRewriteTest(boolean rewriteEnabled) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+
+                       // DML script parameters
+                       programArgs = new String[]{"-stats", "-args", 
input("A"), output("R")};
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       // Set optimizer flags
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewriteEnabled;
+
+                       // Generate input matrix
+                       double[][] A = getRandomMatrix(rows, cols, -10, 10, 
0.7, 3);
+                       writeInputMatrixWithMTD("A", A, true);
+
+                       // Run DML and R scripts
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       // Compare output matrices
+                       HashMap<CellIndex, Double> dml = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<CellIndex, Double> r = 
readRMatrixFromExpectedDir("R");
+
+                       Assert.assertEquals("DML and R outputs do not match", 
r, dml);
+                       if( rewriteEnabled )
+                               Assert.assertEquals(1, 
Statistics.getCPHeavyHitterCount("+"));
+               }
+               finally {
+                       // Reset optimizer flags
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R 
b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R
new file mode 100644
index 0000000000..26492f9ec8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
# Reference result for RewriteNegatedSubtraction.dml: the DML script evaluates
# 0 - (B - A), which equals A - B, so the expected output is computed as A - B.
library("Matrix")

args <- commandArgs(TRUE)
inputDir <- args[1]
outputDir <- args[2]

# dense copies of the MatrixMarket inputs written by the Java test
A <- as.matrix(readMM(paste0(inputDir, "A.mtx")))
B <- as.matrix(readMM(paste0(inputDir, "B.mtx")))

writeMM(as(A - B, "CsparseMatrix"), paste0(outputDir, "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml 
b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml
new file mode 100644
index 0000000000..a25e40f2db
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Input script for RewriteSimplifyNegatedSubtractionTest: evaluates 0 - (B - A).
+# Keep this exact expression shape; with rewrites enabled the Java test asserts
+# exactly one "-" heavy hitter, i.e., that the expression became A - B.
+A = read($1);
+B = read($2);
+
+# Pattern rewritten by simplifyNegatedSubtraction: 0 - (B - A) -> A - B
+R = 0 - (B - A);
+
+write(R, $3);
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R
new file mode 100644
index 0000000000..6bc82690aa
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
# Reference result for RewriteSimplifyTransposeAddition.dml: the DML script
# evaluates t(A+1)+2, which equals t(A)+3 (the form produced by the rewrite).
args <- commandArgs(TRUE)
library("Matrix")

inputDir <- args[1]
outputDir <- args[2]

# dense copy of the MatrixMarket input written by the Java test
A <- as.matrix(readMM(paste0(inputDir, "A.mtx")))
expected <- t(A) + 3

writeMM(as(expected, "CsparseMatrix"), paste0(outputDir, "R"))
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml
new file mode 100644
index 0000000000..d27a471238
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Input script for RewriteSimplifyTransposeAdditionTest. Keep the expression
+# shape t(A+1)+2; with rewrites enabled the Java test asserts exactly one "+"
+# heavy hitter, i.e., that both scalar additions were folded into one.
+A = read($1);
+
+# Compute t(A+1)+2 which should be rewritten to t(A)+3
+result = t(A+1)+2;
+
+write(result, $2);
\ No newline at end of file

Reply via email to