This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ce16d058f [SYSTEMDS-3864] Additional trace simplification rewrites
3ce16d058f is described below

commit 3ce16d058f4bdb0bf4bec9cb9cc79458ae7519b6
Author: aarna <aarnatya...@gmail.com>
AuthorDate: Tue Apr 22 09:41:48 2025 +0200

    [SYSTEMDS-3864] Additional trace simplification rewrites
    
    Closes #2254.
---
 .../java/org/apache/sysds/hops/AggBinaryOp.java    |  1 -
 .../RewriteAlgebraicSimplificationStatic.java      | 42 +++++++++-
 .../rewrite/RewriteSimplifyTraceSumTest.java       | 89 +++++++++++++++++++++
 .../rewrite/RewriteSimplifyTraceTransposeTest.java | 90 ++++++++++++++++++++++
 .../functions/rewrite/RewriteSimplifyTraceSum.R    | 39 ++++++++++
 .../functions/rewrite/RewriteSimplifyTraceSum.dml  | 34 ++++++++
 .../rewrite/RewriteSimplifyTraceTranspose.R        | 31 ++++++++
 .../rewrite/RewriteSimplifyTraceTranspose.dml      | 31 ++++++++
 8 files changed, 355 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 5f9c6b41b3..0be3143206 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -43,7 +43,6 @@ import org.apache.sysds.lops.MatMultCP;
 import org.apache.sysds.lops.PMMJ;
 import org.apache.sysds.lops.PMapMult;
 import org.apache.sysds.lops.Transform;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index c46bc62400..f59d334d17 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -176,6 +176,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
                                hi = fuseBinarySubDAGToUnaryOperation(hop, hi, 
i);   //e.g., X*(1-X)-> sprop(X) || 1/(1+exp(-X)) -> sigmoid(X) || X*(X>0) -> 
selp(X)
                        hi = simplifyTraceMatrixMult(hop, hi, i);            
//e.g., trace(X%*%Y)->sum(X*t(Y));
+                       hi = simplifyTraceSum(hop, hi, i);                   
//e.g., trace(A+B)->trace(A)+trace(B);
+                       hi = simplifyTraceTranspose(hop, hi, i);             
//e.g., trace(t(A))->trace(A)
                        hi = simplifySlicedMatrixMult(hop, hi, i);           
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
                        hi = simplifyListIndexing(hi);                       
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
                        hi = simplifyScalarIndexing(hop, hi, i);             
//e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
@@ -201,7 +203,6 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyNotOverComparisons(hop, hi, i);         
//e.g., !(A>B) -> (A<=B)
                        //hi = removeUnecessaryPPred(hop, hi, i);            
//e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
-
                        //process childs recursively after rewrites (to 
investigate pattern newly created by rewrites)
                        if( !descendFirst )
                                rule_AlgebraicSimplification(hi, descendFirst);
@@ -1603,6 +1604,45 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                return hi;
        }
 
+       private static Hop simplifyTraceSum(Hop parent, Hop hi, int pos) {
+               if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() == 
AggOp.TRACE) {
+                       Hop hi2 = hi.getInput().get(0);
+                       if (HopRewriteUtils.isBinary(hi2, OpOp2.PLUS) && 
hi2.getParent().size() == 1) {
+                               Hop left = hi2.getInput().get(0);
+                               Hop right = hi2.getInput().get(1);
+
+                               // Create trace nodes
+                               AggUnaryOp traceLeft = 
HopRewriteUtils.createAggUnaryOp(left, AggOp.TRACE, Direction.RowCol);
+                               AggUnaryOp traceRight = 
HopRewriteUtils.createAggUnaryOp(right, AggOp.TRACE, Direction.RowCol);
+
+                               // Add them
+                               BinaryOp sum = 
HopRewriteUtils.createBinary(traceLeft, traceRight, OpOp2.PLUS);
+
+                               // Replace in DAG
+                               HopRewriteUtils.replaceChildReference(parent, 
hi, sum, pos);
+                               HopRewriteUtils.cleanupUnreferenced(hi, hi2);
+
+                               LOG.debug("Applied simplifyTraceSum rewrite");
+                               return sum;
+                       }
+               }
+               return hi;
+       }
+
+       private static Hop simplifyTraceTranspose(Hop parent, Hop hi, int pos) {
+               // Check if the current Hop is a trace operation
+               if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.TRACE) ) {
+                       Hop input = hi.getInput().get(0);
+
+                       // Check if input is a transpose whose only consumer is this trace
+                       if (HopRewriteUtils.isReorg(input, ReOrgOp.TRANS) && 
input.getParent().size() == 1) {
+                               HopRewriteUtils.replaceChildReference(hi, 
input, input.getInput(0));
+                               LOG.debug("Applied simplifyTraceTranspose 
rewrite");
+                       }
+               }
+               return hi;
+       }
+
        private static Hop simplifySlicedMatrixMult(Hop parent, Hop hi, int pos)
        {
                //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1] 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java
new file mode 100644
index 0000000000..e561b8e002
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceSumTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class RewriteSimplifyTraceSumTest extends AutomatedTestBase {
+       private static final String TEST_NAME = "RewriteSimplifyTraceSum";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSimplifyTraceSumTest.class.getSimpleName() + "/";
+
+       private static final int rows = 500;
+       private static final int cols = 500;
+       private static final double eps = 1e-10;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+       }
+
+       @Test
+       public void testSimplifyTraceSumRewrite() {
+               runTraceRewriteTest(TEST_NAME, true);
+       }
+
+       @Test
+       public void testSimplifyTraceSumNoRewrite() {
+               runTraceRewriteTest(TEST_NAME, false);
+       }
+
+       private void runTraceRewriteTest(String testname, boolean rewrites) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       fullRScriptName = HOME + testname + ".R";
+
+                       programArgs = new String[]{"-explain", "-stats", 
"-args", input("A"), input("B"), output("R")};
+                       rCmd = getRCmd(inputDir(), expectedDir());
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+                       double[][] A = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 7);
+                       double[][] B = getRandomMatrix(cols, rows, -1, 1, 
0.70d, 6);
+                       writeInputMatrixWithMTD("A", A, true);
+                       writeInputMatrixWithMTD("B", B, true);
+                       // Run SystemDS and R scripts
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       // Compare DML and R outputs
+                       HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLScalarFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> rfile = 
readRScalarFromExpectedDir("R");
+
+                       // Ensure they're equal (within tolerance)
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, 
"DMLResult", "RResult");
+                       Assert.assertEquals(rewrites?2:1, 
Statistics.getCPHeavyHitterCount("uaktrace"));
+               } finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java
new file mode 100644
index 0000000000..80abce0319
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceTransposeTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class RewriteSimplifyTraceTransposeTest extends AutomatedTestBase {
+       private static final String TEST_NAME = "RewriteSimplifyTraceTranspose";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSimplifyTraceTransposeTest.class.getSimpleName() + "/";
+
+       private static final int rows = 100;
+       private static final int cols = 100;
+       private static final double eps = 1e-6;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+       }
+
+       @Test
+       public void testRewriteEnabled() {
+               runRewriteTest(true);
+       }
+
+       @Test
+       public void testRewriteDisabled() {
+               runRewriteTest(false);
+       }
+
+       private void runRewriteTest(boolean rewriteEnabled) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       programArgs = new String[]{"-stats", "-args", 
input("A"), output("R")};
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewriteEnabled;
+                       double[][] A = getRandomMatrix(rows, cols, -1, 1, 
0.70d, 7);
+                       writeInputMatrixWithMTD("A", A, true);
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       // Read DML scalar output
+                       HashMap<MatrixValue.CellIndex, Double> dmlMap = 
readDMLScalarFromOutputDir("R");
+                       double dmlTrace = dmlMap.get(new 
MatrixValue.CellIndex(1, 1));
+
+                       // Read R scalar output
+                       HashMap<MatrixValue.CellIndex, Double> rMap = 
readRScalarFromExpectedDir("R");
+                       double rTrace = rMap.get(new MatrixValue.CellIndex(1, 
1));
+
+                       // Compare the scalar values within the given tolerance
+                       Assert.assertEquals("Trace result mismatch", rTrace, 
dmlTrace, eps);
+                       
Assert.assertTrue(heavyHittersContainsString("r'")!=rewriteEnabled);
+               } 
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R
new file mode 100644
index 0000000000..82abad71be
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+library("Matrix")
+library("matrixStats")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+
+# Perform the matrix operation
+R = sum(diag(A))+sum(diag(B))
+
+# Write the result scalar R
+write(R, paste(args[2], "R" ,sep=""))
+
+
+
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml
new file mode 100644
index 0000000000..9eaf4fcb84
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceSum.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Load matrices A, B
+A = read($1)
+B = read($2)
+
+# Perform the operation
+R = trace(A+B)
+
+# Write the result R
+write(R, $3)
+
+
+
+
+
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R
new file mode 100644
index 0000000000..3bbb28f649
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+R <- sum(diag(t(A)))
+
+# Write the result scalar R
+write(R, paste(args[2], "R" ,sep=""))
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml
new file mode 100644
index 0000000000..2b2b3e6dd0
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceTranspose.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Read input matrix A
+A = read($1);
+
+# Compute trace of transpose
+result = trace(t(A));
+
+# Write scalar result to output
+write(result, $2);
+
+
+

Reply via email to