This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a48a689af [SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops
4a48a689af is described below

commit 4a48a689afb56239d3df81926606f720e9501fc4
Author: aarna <aarnatya...@gmail.com>
AuthorDate: Fri Jun 13 15:37:33 2025 +0200

    [SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops
    
    e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
    Closes #2272.
---
 .../RewriteAlgebraicSimplificationStatic.java      | 35 ++++++++
 ...RewriteSimplifyScalarMatrixPMOperationTest.java | 98 ++++++++++++++++++++++
 .../rewrite/RewriteScalarMinusMatrixMinusScalar.R  | 30 +++++++
 .../RewriteScalarMinusMatrixMinusScalar.dml        | 28 +++++++
 .../rewrite/RewriteScalarPlusMatrixMinusScalar.R   | 30 +++++++
 .../rewrite/RewriteScalarPlusMatrixMinusScalar.dml | 28 +++++++
 6 files changed, 249 insertions(+)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 65c8805c7c..ef5670dda8 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -202,6 +202,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyNegatedSubtraction(hop, hi, i);         
//e.g., -(B-A)->A-B
                        hi = simplifyTransposeAddition(hop, hi, i);          
//e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding
                        hi = simplifyNotOverComparisons(hop, hi, i);         
//e.g., !(A>B) -> (A<=B)
+                       hi = simplifyMatrixScalarPMOperation(hop, hi, i);    
//e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
@@ -212,6 +213,40 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                hop.setVisited();
        }
 
+       /**
+        * Reassociates a plus/minus chain (a op1 A) op2 b, with scalars a, b and a
+        * matrix A, into (a op2 b) op1 A. The two scalars are then combined by a
+        * scalar operation and a single matrix-scalar operation remains,
+        * e.g., a-A-b -> (a-b)-A and a+A-b -> (a-b)+A.
+        *
+        * @param parent parent hop of hi
+        * @param hi     candidate hop (the outer plus/minus)
+        * @param pos    input position of hi within parent
+        * @return the rewritten hop, or hi if the pattern does not apply
+        */
+       private static Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) {
+               //pattern: (a +/- A) +/- b -> (a +/- b) +/- A, a and b scalars, A matrix
+               if( !HopRewriteUtils.isBinary(hi, OpOp2.PLUS, OpOp2.MINUS) )
+                       return hi;
+
+               Hop left = hi.getInput(0);
+               Hop b = hi.getInput(1);
+               //only rewrite if (a +/- A) is exclusively consumed by hi; otherwise the
+               //inner matrix op stays alive and no matrix operation is saved
+               if( !HopRewriteUtils.isBinary(left, OpOp2.PLUS, OpOp2.MINUS)
+                       || left.getParent().size() != 1 )
+                       return hi;
+
+               Hop a = left.getInput(0);
+               Hop A = left.getInput(1);
+               if( !a.getDataType().isScalar() || !b.getDataType().isScalar()
+                       || A.getDataType() != DataType.MATRIX )
+                       return hi;
+
+               //the outer op now combines the two scalars, the inner op is applied to A
+               OpOp2 scalarOp = ((BinaryOp) hi).getOp();
+               OpOp2 matrixOp = ((BinaryOp) left).getOp();
+               Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp);
+               Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp);
+
+               HopRewriteUtils.replaceChildReference(parent, hi, result, pos);
+               //drop the now unreferenced hi and (a +/- A) from their inputs' parent lists
+               HopRewriteUtils.cleanupUnreferenced(hi, left);
+               LOG.debug("Applied simplifyMatrixScalarPMOperation");
+               return result;
+       }
+
        private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int 
pos) {
                //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant 
folding
                if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
new file mode 100644
index 0000000000..64d3b06544
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Compares DML results of scalar-matrix-scalar plus/minus chains against R
+ * reference scripts, once with algebraic simplification rewrites enabled and
+ * once with them disabled.
+ */
+public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase {
+       private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar";
+       private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR =
+               TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/";
+       private static final int rows = 100;
+       private static final int cols = 100;
+       private static final double eps = 1e-6;
+
+       /** Registers one configuration per DML/R script pair. */
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               for( String name : new String[]{TEST_NAME1, TEST_NAME2} ) {
+                       String[] outputs = new String[]{"A", "a", "b", "R"};
+                       addTestConfiguration(name, new TestConfiguration(TEST_CLASS_DIR, name, outputs));
+               }
+       }
+
+       @Test
+       public void testScalarMinusMatrixMinusScalarRewriteEnabled() {
+               runRewriteTest(TEST_NAME1, true);
+       }
+
+       @Test
+       public void testScalarMinusMatrixMinusScalarRewriteDisabled() {
+               runRewriteTest(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testScalarPlusMatrixMinusScalarRewriteEnabled() {
+               runRewriteTest(TEST_NAME2, true);
+       }
+
+       @Test
+       public void testScalarPlusMatrixMinusScalarRewriteDisabled() {
+               runRewriteTest(TEST_NAME2, false);
+       }
+
+       /**
+        * Writes a random matrix A and two 1x1 inputs a and b, then runs the DML script
+        * with the given rewrite setting and runs the R script. It compares both outputs
+        * within eps. The global algebraic-simplification flag is always restored.
+        */
+       private void runRewriteTest(String testName, boolean rewriteEnabled) {
+               final boolean savedFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       loadTestConfiguration(getTestConfiguration(testName));
+
+                       final String home = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = home + testName + ".dml";
+                       fullRScriptName = home + testName + ".R";
+                       programArgs = new String[]{"-stats", "-args",
+                               input("A"), input("a"), input("b"), output("R")};
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled;
+
+                       //fixed seeds keep the inputs deterministic across runs
+                       writeInputMatrixWithMTD("A", getRandomMatrix(rows, cols, -100, 100, 0.9, 3), true);
+                       writeInputMatrixWithMTD("a", getRandomMatrix(1, 1, -10, 10, 1.0, 7), true);
+                       writeInputMatrixWithMTD("b", getRandomMatrix(1, 1, -10, 10, 1.0, 5), true);
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       HashMap<MatrixValue.CellIndex, Double> fromDml = readDMLMatrixFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> fromR = readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(fromDml, fromR, eps, "Stat-DML", "Stat-R");
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedFlag;
+               }
+       }
+}
diff --git 
a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R 
b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R
new file mode 100644
index 0000000000..bd9ab23ed2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+a <- as.numeric(readMM(paste(args[1], "a.mtx", sep="")))
+b <- as.numeric(readMM(paste(args[1], "b.mtx", sep="")))
+
+# evaluate the original (un-rewritten) expression so that the reference
+# result does not depend on the rewrite under test (a-A-b == (a-b)-A)
+R <- a - A - b
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git 
a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml 
b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml
new file mode 100644
index 0000000000..28cdb61dec
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+A = read($1);
+# a and b are stored as 1x1 matrices; cast them to scalars so the expression
+# below is a scalar-matrix-scalar chain that the rewrite can match
+a = as.scalar(read($2));
+b = as.scalar(read($3));
+
+# original form: a - A - b
+R = a - A - b;
+
+write(R, $4);
+
diff --git 
a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R 
b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R
new file mode 100644
index 0000000000..ec2764bb28
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+a <- as.numeric(readMM(paste(args[1], "a.mtx", sep="")))
+b <- as.numeric(readMM(paste(args[1], "b.mtx", sep="")))
+
+# evaluate the original (un-rewritten) expression so that the reference
+# result does not depend on the rewrite under test (a+A-b == (a-b)+A)
+R <- a + A - b
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git 
a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml 
b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml
new file mode 100644
index 0000000000..5ba04566ef
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = read($1);
+s1 = as.scalar(read($2));
+s2 = as.scalar(read($3));
+
+# scalar + matrix - scalar, written with its left-associative grouping made explicit
+R = (s1 + X) - s2;
+
+write(R, $4);

Reply via email to