This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 9e649c8254 [SYSTEMDS-3884] Additional rewrites subtraction and addition 9e649c8254 is described below commit 9e649c8254b2d20ff8bcd8f66c76e0aeff47e1d0 Author: aarna <aarnatya...@gmail.com> AuthorDate: Tue May 13 08:00:02 2025 +0200 [SYSTEMDS-3884] Additional rewrites subtraction and addition -(B-A)->A-B t(A+1)+2 -> t(A)+1+2 -> t(A)+3 Closes #2258. --- .../RewriteAlgebraicSimplificationStatic.java | 103 ++++++++++++++++++++- .../RewriteSimplifyNegatedSubtractionTest.java | 90 ++++++++++++++++++ .../RewriteSimplifyTransposeAdditionTest.java | 93 +++++++++++++++++++ .../functions/rewrite/RewriteNegatedSubtraction.R | 31 +++++++ .../rewrite/RewriteNegatedSubtraction.dml | 27 ++++++ .../rewrite/RewriteSimplifyTransposeAddition.R | 30 ++++++ .../rewrite/RewriteSimplifyTransposeAddition.dml | 26 ++++++ 7 files changed, 399 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index f59d334d17..b8bf05184a 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -199,7 +199,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyBinaryComparisonChain(hop, hi, i); //e.g., outer(v1,v2,"==")==1 -> outer(v1,v2,"=="), outer(v1,v2,"==")==0 -> outer(v1,v2,"!="), hi = simplifyCumsumColOrFullAggregates(hi); //e.g., colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) hi = simplifyCumsumReverse(hop, hi, i); //e.g., rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X) - + hi = simplifyNegatedSubtraction(hop, hi, i); //e.g., -(B-A)->A-B + hi = simplifyTransposeAddition(hop, hi, i); //e.g., t(A+1)+2 -> t(A)+1+2 -> t(A)+3 hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) @@ -211,6 +212,106 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hop.setVisited(); } + private static Hop simplifyTransposeAddition(Hop parent, Hop hi, int pos) { + if (!(hi instanceof BinaryOp) + || ((BinaryOp)hi).getOp() != OpOp2.PLUS + || hi.getDataType() != DataType.MATRIX) + return hi; + + BinaryOp bop = (BinaryOp)hi; + + ReorgOp tSide = null; + LiteralOp litSide = null; + Hop in0 = bop.getInput().get(0), in1 = bop.getInput().get(1); + if (in0 instanceof ReorgOp && ((ReorgOp)in0).getOp() == ReOrgOp.TRANS + && in1 instanceof LiteralOp) { + tSide = (ReorgOp)in0; + litSide = (LiteralOp)in1; + } + else if (in1 instanceof ReorgOp && ((ReorgOp)in1).getOp() == ReOrgOp.TRANS + && in0 instanceof LiteralOp) { + tSide = (ReorgOp)in1; + litSide = (LiteralOp)in0; + } + else + return hi; + + //check if only consumer + if (tSide.getParent().size() > 1) { + return hi; + } + + Hop inner = tSide.getInput().get(0); + if (!(inner instanceof BinaryOp) + || ((BinaryOp)inner).getOp() != OpOp2.PLUS + || inner.getDataType() != DataType.MATRIX) + return hi; + + BinaryOp ib = (BinaryOp)inner; + + Hop X = null; + LiteralOp lit1 = null; + Hop i0 = ib.getInput().get(0), i1 = ib.getInput().get(1); + if (i0 instanceof LiteralOp) { + lit1 = (LiteralOp)i0; + X = i1; + } + else if (i1 instanceof LiteralOp) { + lit1 = (LiteralOp)i1; + X = i0; + } + else + return hi; + + double c = lit1.getDoubleValue() + litSide.getDoubleValue(); + + ReorgOp newT = HopRewriteUtils.createTranspose(X); + newT.setDim1(tSide.getDim1()); + newT.setDim2(tSide.getDim2()); + + LiteralOp newLit = new LiteralOp(c); + newLit.setDim1(1); + newLit.setDim2(1); + + //creating new binaryOp + BinaryOp newPlus = HopRewriteUtils.createBinary(newT, newLit, OpOp2.PLUS); + newPlus.setDim1(bop.getDim1()); + newPlus.setDim2(bop.getDim2()); + + HopRewriteUtils.replaceChildReference(parent, bop, newPlus, pos); + HopRewriteUtils.cleanupUnreferenced(bop, tSide, ib, litSide); + + LOG.debug("Applied simplifyTransposeAddition (line " + hi.getBeginLine() + ")."); + + return newPlus; + } + + private static Hop simplifyNegatedSubtraction(Hop parent, Hop hi, int pos) { + if (hi instanceof BinaryOp + && ((BinaryOp) hi).getOp() == OpOp2.MINUS + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 0) + && hi.getParent().size() == 1 + && hi.getInput().get(1) instanceof BinaryOp + && ((BinaryOp) hi.getInput().get(1)).getOp() == OpOp2.MINUS + && hi.getInput().get(1).getParent().size() == 1) + { + Hop innerMinus = hi.getInput().get(1); + Hop B = innerMinus.getInput().get(0); + Hop A = innerMinus.getInput().get(1); + + BinaryOp newHop = HopRewriteUtils.createBinary(A, B, OpOp2.MINUS); + + HopRewriteUtils.copyLineNumbers(hi, newHop); + HopRewriteUtils.replaceChildReference(parent, hi, newHop, pos); + HopRewriteUtils.cleanupUnreferenced(hi); + hi = newHop; + + LOG.debug("Applied simplifyNegatedSubtraction (line " + hi.getBeginLine() + ")."); + } + return hi; + } + + private static Hop removeUnnecessaryVectorizeOperation(Hop hi) { //applies to all binary matrix operations, if one input is unnecessarily vectorized diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java new file mode 100644 index 0000000000..da5876e343 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyNegatedSubtractionTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Test; +import org.junit.Assert; +import java.util.HashMap; + +public class RewriteSimplifyNegatedSubtractionTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteNegatedSubtraction"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyNegatedSubtractionTest.class.getSimpleName() + "/"; + private static final int rows = 100; + private static final int cols = 100; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, + new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testRewriteEnabled() { + runRewriteTest(true); + } + + @Test + public void testRewriteDisabled() { + runRewriteTest(false); + } + + private void runRewriteTest(boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullRScriptName = HOME + TEST_NAME + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), input("B"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + + // Generate input matrices + double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3); + double[][] B = getRandomMatrix(rows, cols, -10, 10, 0.7, 7); + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("B", B, true); + + // Run DML script + runTest(true, false, null, -1); + runRScript(true); + + HashMap<MatrixValue.CellIndex, Double> dml = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> r = readRMatrixFromExpectedDir("R"); + + Assert.assertEquals("DML and R outputs do not match", r, dml); + if( rewriteEnabled ) + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("-")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java new file mode 100644 index 0000000000..9247e07d4f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposeAdditionTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import java.util.HashMap; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +public class RewriteSimplifyTransposeAdditionTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteSimplifyTransposeAddition"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTransposeAdditionTest.class.getSimpleName() + "/"; + + private static final int rows = 100; + private static final int cols = 100; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testRewriteEnabled() { + runRewriteTest(true); + } + + @Test + public void testRewriteDisabled() { + runRewriteTest(false); + } + + private void runRewriteTest(boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullRScriptName = HOME + TEST_NAME + ".R"; + + // DML script parameters + programArgs = new String[]{"-stats", "-args", input("A"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + // Set optimizer flags + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + + // Generate input matrix + double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.7, 3); + writeInputMatrixWithMTD("A", A, true); + + // Run DML and R scripts + runTest(true, false, null, -1); + runRScript(true); + + // Compare output matrices + HashMap<CellIndex, Double> dml = readDMLMatrixFromOutputDir("R"); + HashMap<CellIndex, Double> r = readRMatrixFromExpectedDir("R"); + + Assert.assertEquals("DML and R outputs do not match", r, dml); + if( rewriteEnabled ) + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("+")); + } + finally { + // Reset optimizer flags + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R new file mode 100644 index 0000000000..26492f9ec8 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +library("Matrix") + +args <- commandArgs(TRUE) + +A <- as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B <- as.matrix(readMM(paste(args[1], "B.mtx", sep=""))) + +R <- A - B + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) + diff --git a/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml new file mode 100644 index 0000000000..a25e40f2db --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNegatedSubtraction.dml @@ -0,0 +1,27 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +B = read($2); + +# Expression that will be rewritten +R = 0 - (B - A); + +write(R, $3); diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R new file mode 100644 index 0000000000..6bc82690aa --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.R @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) +library("Matrix") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) + +# Compute t(A)+3 +R <- t(A)+3 + +# Write the result matrix +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")) \ No newline at end of file diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml new file mode 100644 index 0000000000..d27a471238 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposeAddition.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); + +# Compute t(A+1)+2 which should be rewritten to t(A)+3 +result = t(A+1)+2; + +write(result, $2); \ No newline at end of file