This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 4a48a689af [SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops 4a48a689af is described below commit 4a48a689afb56239d3df81926606f720e9501fc4 Author: aarna <aarnatya...@gmail.com> AuthorDate: Fri Jun 13 15:37:33 2025 +0200 [SYSTEMDS-3889] New simplification rewrite for matrix-scalar ops e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A Closes #2272. --- .../RewriteAlgebraicSimplificationStatic.java | 35 ++++++++ ...RewriteSimplifyScalarMatrixPMOperationTest.java | 98 ++++++++++++++++++++++ .../rewrite/RewriteScalarMinusMatrixMinusScalar.R | 30 +++++++ .../RewriteScalarMinusMatrixMinusScalar.dml | 28 +++++++ .../rewrite/RewriteScalarPlusMatrixMinusScalar.R | 30 +++++++ .../rewrite/RewriteScalarPlusMatrixMinusScalar.dml | 28 +++++++ 6 files changed, 249 insertions(+) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 65c8805c7c..ef5670dda8 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -202,6 +202,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+s1)+s2 -> t(A)+(s1+s2) + potential constant folding hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) + hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -212,6 +213,40 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hop.setVisited(); 
} + private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) { + if (!(hi instanceof BinaryOp)) + return hi; + + BinaryOp outer = (BinaryOp) hi; + Hop left = outer.getInput(0); + Hop right = outer.getInput(1); + OpOp2 outerOp = outer.getOp(); + + if((outerOp != OpOp2.PLUS && outerOp != OpOp2.MINUS) || !(left instanceof BinaryOp)) + return hi; + + Hop a = left.getInput(0); + Hop A = left.getInput(1); + Hop b = right; + + java.util.function.Predicate<Hop> isScalar = h -> h.getDataType().isScalar(); + if (!isScalar.test(a) || !isScalar.test(b) || A.getDataType() != DataType.MATRIX) + return hi; + + // Determine the scalarOp (between a and b) and matrixOp (with A) + OpOp2 innerOp = ((BinaryOp)left).getOp(); + if( innerOp != OpOp2.PLUS && innerOp != OpOp2.MINUS ) + return hi; + OpOp2 scalarOp = (outerOp == OpOp2.PLUS) ? OpOp2.PLUS : OpOp2.MINUS; + OpOp2 matrixOp = (innerOp == OpOp2.PLUS) ? OpOp2.PLUS : OpOp2.MINUS; + Hop scalarCombined = HopRewriteUtils.createBinary(a, b, scalarOp); + Hop result = HopRewriteUtils.createBinary(scalarCombined, A, matrixOp); + + HopRewriteUtils.replaceChildReference(parent, hi, result, pos); + LOG.debug("Applied simplifyMatrixScalarPMOperation"); + return result; + } + private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { //pattern: t(A+s1)+s2 -> t(A)+(s1+s2), and subsequent constant folding if (HopRewriteUtils.isBinary(hi, OpOp2.PLUS) diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java new file mode 100644 index 0000000000..64d3b06544 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyScalarMatrixPMOperationTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyScalarMatrixPMOperationTest extends AutomatedTestBase { + private static final String TEST_NAME1 = "RewriteScalarMinusMatrixMinusScalar"; + private static final String TEST_NAME2 = "RewriteScalarPlusMatrixMinusScalar"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyScalarMatrixPMOperationTest.class.getSimpleName() + "/"; + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = 1e-6; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"A", "a", "b", "R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"A", "a", "b", "R"})); + } + + @Test + public void 
testScalarMinusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME1, true); + } + + @Test + public void testScalarMinusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME1, false); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteEnabled() { + runRewriteTest(TEST_NAME2, true); + } + + @Test + public void testScalarPlusMatrixMinusScalarRewriteDisabled() { + runRewriteTest(TEST_NAME2, false); + } + + private void runRewriteTest(String testName, boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testName); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + fullRScriptName = HOME + testName + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), input("a"), input("b"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + + double[][] A = getRandomMatrix(rows, cols, -100, 100, 0.9, 3); + double[][] a = getRandomMatrix(1, 1, -10, 10, 1.0, 7); + double[][] b = getRandomMatrix(1, 1, -10, 10, 1.0, 5); + + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("a", a, true); + writeInputMatrixWithMTD("b", b, true); + + runTest(true, false, null, -1); + runRScript(true); + + HashMap<MatrixValue.CellIndex, Double> dml = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> r = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dml, r, eps, "Stat-DML", "Stat-R"); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R new file mode 100644 index 0000000000..bd9ab23ed2 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.R 
@@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- (a-b)-A + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml new file mode 100644 index 0000000000..28cdb61dec --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarMinusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = read($2); +b = read($3); + +R = a - A - b; + +write(R, $4); + diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R new file mode 100644 index 0000000000..ec2764bb28 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +a <- as.numeric(readMM(paste(args[1], "a.mtx", sep=""))) +b <- as.numeric(readMM(paste(args[1], "b.mtx", sep=""))) + +R <- (a-b)+A + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml new file mode 100644 index 0000000000..5ba04566ef --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteScalarPlusMatrixMinusScalar.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +a = as.scalar(read($2)); +b = as.scalar(read($3)); + +# Original form: a + A - b +R = a + A - b; + +write(R, $4);