This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 3ce16d058f [SYSTEMDS-3864] Additional trace simplification rewrites 3ce16d058f is described below commit 3ce16d058f4bdb0bf4bec9cb9cc79458ae7519b6 Author: aarna <aarnatya...@gmail.com> AuthorDate: Tue Apr 22 09:41:48 2025 +0200 [SYSTEMDS-3864] Additional trace simplification rewrites Closes #2254. --- .../java/org/apache/sysds/hops/AggBinaryOp.java | 1 - .../RewriteAlgebraicSimplificationStatic.java | 42 +++++++++- .../rewrite/RewriteSimplifyTraceSumTest.java | 89 +++++++++++++++++++++ .../rewrite/RewriteSimplifyTraceTransposeTest.java | 90 ++++++++++++++++++++++ .../functions/rewrite/RewriteSimplifyTraceSum.R | 39 ++++++++++ .../functions/rewrite/RewriteSimplifyTraceSum.dml | 34 ++++++++ .../rewrite/RewriteSimplifyTraceTranspose.R | 31 ++++++++ .../rewrite/RewriteSimplifyTraceTranspose.dml | 31 ++++++++ 8 files changed, 355 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 5f9c6b41b3..0be3143206 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -43,7 +43,6 @@ import org.apache.sysds.lops.MatMultCP; import org.apache.sysds.lops.PMMJ; import org.apache.sysds.lops.PMapMult; import org.apache.sysds.lops.Transform; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; import org.apache.sysds.runtime.matrix.data.MatrixBlock; diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index c46bc62400..f59d334d17 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -176,6 +176,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = fuseBinarySubDAGToUnaryOperation(hop, hi, i); //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> selp(X) hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y)); + hi = simplifyTraceSum(hop, hi, i); //e.g. , trace(A+B)->trace(A)+trace(B); + hi = simplifyTraceTranspose(hop, hi, i); //e.g. , trace(t(A))->trace(A) hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1]; hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1] hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output @@ -201,7 +203,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B) //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) - //process childs recursively after rewrites (to investigate pattern newly created by rewrites) if( !descendFirst ) rule_AlgebraicSimplification(hi, descendFirst); @@ -1603,6 +1604,45 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private static Hop simplifyTraceSum(Hop parent, Hop hi, int pos) { + if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == AggOp.TRACE) { + Hop hi2 = hi.getInput().get(0); + if (HopRewriteUtils.isBinary(hi2, OpOp2.PLUS) && hi2.getParent().size() == 1) { + Hop left = hi2.getInput().get(0); + Hop right = hi2.getInput().get(1); + + // Create trace nodes + AggUnaryOp traceLeft = HopRewriteUtils.createAggUnaryOp(left, AggOp.TRACE, Direction.RowCol); + AggUnaryOp traceRight = HopRewriteUtils.createAggUnaryOp(right, AggOp.TRACE, Direction.RowCol); + + // Add them + BinaryOp sum = HopRewriteUtils.createBinary(traceLeft, traceRight, OpOp2.PLUS); + + // Replace in DAG + HopRewriteUtils.replaceChildReference(parent, hi, sum, pos); + HopRewriteUtils.cleanupUnreferenced(hi, hi2); + + LOG.debug("Applied simplifyTraceSum rewrite"); + return sum; + } + } + return hi; + } + + private static Hop simplifyTraceTranspose(Hop parent, Hop hi, int pos) { + // Check if the current Hop is a trace operation + if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.TRACE) ) { + Hop input = hi.getInput().get(0); + + // Check if input is a transpose and it is only consumer + if (HopRewriteUtils.isReorg(input, ReOrgOp.TRANS) && input.getParent().size() == 1) { + HopRewriteUtils.replaceChildReference(hi, input, input.getInput(0)); + LOG.debug("Applied simplifyTraceTranspose rewrite"); + } + } + return hi; + } + private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos) { //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java new file mode 100644 index 0000000000..e561b8e002 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyTraceSumTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteSimplifyTraceSum"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceSumTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = 1e-10; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testSimplifyTraceSumRewrite() { + runTraceRewriteTest(TEST_NAME, true); + } + + @Test + public void testSimplifyTraceSumNoRewrite() { + runTraceRewriteTest(TEST_NAME, false); + } + + private void runTraceRewriteTest(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + fullRScriptName = HOME + testname + ".R"; + + programArgs = new String[]{"-explain", "-stats", "-args", input("A"), input("B"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7); + double[][] B = getRandomMatrix(cols, rows, -1, 1, 0.70d, 6); + writeInputMatrixWithMTD("A", A, true); + writeInputMatrixWithMTD("B", B, true); + // Run SystemDS and R scripts + runTest(true, false, null, -1); + runRScript(true); + + // Compare DML and R outputs + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRScalarFromExpectedDir("R"); + + // Ensure they're equal (within tolerance) + TestUtils.compareMatrices(dmlfile, rfile, eps, "DMLResult", "RResult"); + Assert.assertEquals(rewrites?2:1, Statistics.getCPHeavyHitterCount("uaktrace")); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java new file mode 100644 index 0000000000..80abce0319 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyTraceTransposeTest extends AutomatedTestBase { + private static final String TEST_NAME = "RewriteSimplifyTraceTranspose"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyTraceTransposeTest.class.getSimpleName() + "/"; + + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = 1e-6; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"})); + } + + @Test + public void testRewriteEnabled() { + runRewriteTest(true); + } + + @Test + public void testRewriteDisabled() { + runRewriteTest(false); + } + + private void runRewriteTest(boolean rewriteEnabled) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + fullRScriptName = HOME + TEST_NAME + ".R"; + programArgs = new String[]{"-stats", "-args", input("A"), output("R")}; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewriteEnabled; + double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7); + writeInputMatrixWithMTD("A", A, true); + runTest(true, false, null, -1); + runRScript(true); + + // Read DML scalar output + HashMap<MatrixValue.CellIndex, Double> dmlMap = readDMLScalarFromOutputDir("R"); + double dmlTrace = dmlMap.get(new MatrixValue.CellIndex(1, 1)); + + // Read R scalar output + HashMap<MatrixValue.CellIndex, Double> rMap = readRScalarFromExpectedDir("R"); + double rTrace = rMap.get(new MatrixValue.CellIndex(1, 1)); + + // Compare the scalar values within the given tolerance + Assert.assertEquals("Trace result mismatch", rTrace, dmlTrace, eps); + Assert.assertTrue(heavyHittersContainsString("r'")!=rewriteEnabled); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R new file mode 100644 index 0000000000..82abad71be --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +library("Matrix") +library("matrixStats") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = as.matrix(readMM(paste(args[1], "B.mtx", sep=""))) + +# Perform the matrix operation +R = sum(diag(A))+sum(diag(B)) + +# Write the result scalar R +write(R, paste(args[2], "R" ,sep="")) + + + diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml new file mode 100644 index 0000000000..9eaf4fcb84 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# Load matrices A, B +A = read($1) +B = read($2) + +# Perform the operation +R = trace(A+B) + +# Write the result R +write(R, $3) + + + + + diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R new file mode 100644 index 0000000000..3bbb28f649 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +args <- commandArgs(TRUE) + +library("Matrix") +library("matrixStats") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) + +R <- sum(diag(t(A))) + +# Write the result scalar R +write(R, paste(args[2], "R" ,sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml new file mode 100644 index 0000000000..2b2b3e6dd0 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# Read input matrix A +A = read($1); + +# Compute trace of transpose +result = trace(t(A)); + +# Write scalar result to output +write(result, $2); + + +