This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 147519e495 [SYSTEMDS-3664] New simplification rewrite rev(seq())
147519e495 is described below

commit 147519e49558493e58e14f359befd56fcf74ffda
Author: aarna <aarnatya...@gmail.com>
AuthorDate: Sun Mar 16 18:08:49 2025 +0100

    [SYSTEMDS-3664] New simplification rewrite rev(seq())
    
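    This patch introduces a new simplification rewrite that generalizes
    rev(seq(1,n)) --> seq(n,1) to sequences with arbitrary literal steps:
    rev(seq(from,to,incr)) --> seq(last,from,-incr), where last is the
    actual final element of the original sequence, e.g.,
    rev(seq(1,10,2)) --> seq(9,1,-2).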
    
    Closes #2242.
---
 .../RewriteAlgebraicSimplificationStatic.java      |  54 ++++++++++
 .../RewriteSimplifyReverseSequenceStepTest.java    | 109 +++++++++++++++++++++
 .../rewrite/RewriteSimplifyReverseSequenceStep.dml |  35 +++++++
 3 files changed, 198 insertions(+)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5d867bf0ff..c46bc62400 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -156,6 +156,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = simplifyConstantConjunction(hop, hi, i);        
//e.g., a & !a -> FALSE 
                        hi = simplifyReverseOperation(hop, hi, i);           
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
                        hi = simplifyReverseSequence(hop, hi, i);            
//e.g., rev(seq(1,n)) -> seq(n,1)
+                       hi = simplifyReverseSequenceStep(hop, hi, i);        
//e.g., rev(seq(1,n,2)) -> seq(last,1,-2)
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
                                hi = simplifyMultiBinaryToBinaryOperation(hi);  
     //e.g., 1-X*Y -> X 1-* Y
                        hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -824,6 +825,59 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
 
                return hi;
        }
+       
+       /**
+        * Simplifies rev(seq(from,to,incr)) with literal parameters into the
+        * equivalent sequence seq(last,from,-incr), where last is the actual final
+        * element of the original sequence, e.g., rev(seq(1,10,2)) -> seq(9,1,-2).
+        * The seq operator is modified in place, which is why the rewrite only
+        * applies if the rev operation is its only consumer. Sequences with a
+        * (near-)zero increment, a direction incompatible with the sign of the
+        * increment, or an element count that is not reproduced exactly by the
+        * reversed parameters are left untouched.
+        *
+        * @param parent parent hop of hi
+        * @param hi     current hop, the candidate rev operation
+        * @param pos    position of hi in the inputs of parent
+        * @return the rewritten seq hop if the rewrite applied, otherwise hi
+        */
+       private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) {
+               if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+                               && hi.getInput(0) instanceof DataGenOp
+                               && ((DataGenOp) hi.getInput(0)).getOp() == OpOpDG.SEQ
+                               && hi.getInput(0).getParent().size() == 1) // only one consumer
+               {
+                       DataGenOp seq = (DataGenOp) hi.getInput(0);
+                       int fromPos = seq.getParamIndex(Statement.SEQ_FROM);
+                       int toPos = seq.getParamIndex(Statement.SEQ_TO);
+                       int incrPos = seq.getParamIndex(Statement.SEQ_INCR);
+                       Hop from = seq.getInput().get(fromPos);
+                       Hop to = seq.getInput().get(toPos);
+                       Hop incr = seq.getInput().get(incrPos);
+
+                       if (from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp) {
+                               double fromVal = ((LiteralOp) from).getDoubleValue();
+                               double toVal = ((LiteralOp) to).getDoubleValue();
+                               double incrVal = ((LiteralOp) incr).getDoubleValue();
+
+                               // Skip if increment is zero (invalid sequence)
+                               if (Math.abs(incrVal) < 1e-10)
+                                       return hi;
+
+                               // Skip if the direction is incompatible with the sign of the increment
+                               boolean isValidDirection = (incrVal > 0 && fromVal <= toVal)
+                                       || (incrVal < 0 && fromVal >= toVal);
+                               if (!isValidDirection)
+                                       return hi;
+
+                               // Number of elements and actual last element; long avoids an
+                               // int overflow of the element count for very long sequences
+                               long numValues = (long) Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1;
+                               double lastVal = fromVal + (numValues - 1) * incrVal;
+
+                               // Guard against floating-point round-off (e.g., fractional increments):
+                               // the reversed sequence must yield exactly the same number of elements
+                               long numValuesRev = (long) Math.floor(Math.abs((fromVal - lastVal) / incrVal)) + 1;
+                               if (numValuesRev != numValues)
+                                       return hi;
+
+                               // Replace the seq parameters while keeping child/parent links consistent;
+                               // the output size is unchanged because the number of elements is equal
+                               replaceSeqInput(seq, fromPos, from, new LiteralOp(lastVal));
+                               replaceSeqInput(seq, toPos, to, new LiteralOp(fromVal));
+                               replaceSeqInput(seq, incrPos, incr, new LiteralOp(-incrVal));
+
+                               // Replace the rev operation by the modified sequence
+                               HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
+                               HopRewriteUtils.cleanupUnreferenced(hi, seq);
+                               hi = seq;
+                               LOG.debug("Applied simplifyReverseSequenceStep (line " + hi.getBeginLine() + ").");
+                       }
+               }
+               return hi;
+       }
+
+       /**
+        * Replaces the input at position pos of the given seq operator by newIn and
+        * updates both parent lists; setting only the input list would leave oldIn
+        * with a stale parent reference to seq and newIn without one.
+        */
+       private static void replaceSeqInput(DataGenOp seq, int pos, Hop oldIn, Hop newIn) {
+               seq.getInput().set(pos, newIn);
+               oldIn.getParent().remove(seq);
+               newIn.getParent().add(seq);
+       }
 
        private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
        {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
new file mode 100644
index 0000000000..7d176c7aa8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests the rewrite rev(seq(from,to,incr)) -> seq(last,from,-incr) for sequences
+ * with literal (non-unit and fractional) steps. The DML result is compared against
+ * a closed-form arithmetic-series sum; with algebraic simplification enabled, the
+ * test additionally checks that no rev instruction was executed.
+ */
+public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase {
+       private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR =
+               TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1,
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+       }
+
+       @Test
+       public void testRewriteReverseSeqStep() {
+               testRewriteReverseSeq(TEST_NAME1, true);
+       }
+
+       @Test
+       public void testNoRewriteReverseSeqStep() {
+               testRewriteReverseSeq(TEST_NAME1, false);
+       }
+
+       /**
+        * Runs the given DML test with algebraic simplification switched on or off,
+        * validates the computed scalar, and restores the optimizer flag afterwards.
+        */
+       private void testRewriteReverseSeq(String testname, boolean rewrites) {
+               final boolean savedSimplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               final int n = 10;
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(testname));
+                       fullDMLScriptName = SCRIPT_DIR + TEST_DIR + testname + ".dml";
+                       programArgs = new String[] {"-stats", "-args", String.valueOf(n), output("Scalar")};
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+                       runTest(true, false, null, -1);
+
+                       // Expected value: total over the five sequences A1..A5 of the DML script
+                       double expected = seriesSum(0, n - 1, 1) // A1 = rev(seq(0, n-1, 1))
+                               + seriesSum(0, n, 2)             // A2 = rev(seq(0, n, 2))
+                               + seriesSum(2, n, 2)             // A3 = rev(seq(2, n, 2))
+                               + seriesSum(0, 100, 5)           // A4 = rev(seq(0, 100, 5))
+                               + seriesSum(15, 5, -0.5);        // A5 = rev(seq(15, 5, -0.5))
+
+                       double actual = readDMLScalarFromOutputDir("Scalar")
+                               .get(new MatrixValue.CellIndex(1, 1)).doubleValue();
+                       Assert.assertEquals("Incorrect sum computed", expected, actual, 1e-10);
+
+                       // With the rewrite applied, no rev operation should be executed anymore
+                       if (rewrites)
+                               Assert.assertFalse("Rewrite should have removed REV operation!",
+                                       heavyHittersContainsString("rev"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedSimplification;
+               }
+       }
+
+       /**
+        * Sum of the arithmetic sequence seq(from, to, incr), or 0 if the direction
+        * does not match the sign of the increment.
+        */
+       private double seriesSum(double from, double to, double incr) {
+               boolean ascending = incr > 0 && from <= to;
+               boolean descending = incr < 0 && from >= to;
+               if (!ascending && !descending)
+                       return 0;
+
+               int count = (int) Math.floor(Math.abs((to - from) / incr)) + 1;
+               double last = from + (count - 1) * incr;
+               // arithmetic series: count * (first + last) / 2
+               return count * (from + last) / 2;
+       }
+}
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
new file mode 100644
index 0000000000..e8f3314c26
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rows = as.integer($1)
+
+# Reversed sequences with literal steps; with algebraic simplification enabled,
+# each rev(seq(from, to, incr)) is expected to become seq(last, from, -incr)
+A1 = rev(seq(0, rows-1, 1))   # -> seq(rows-1, 0, -1)
+A2 = rev(seq(0, rows, 2))     # -> seq(rows, 0, -2) for even rows
+A3 = rev(seq(2, rows, 2))     # -> seq(last, 2, -2), last = final element of seq(2, rows, 2)
+A4 = rev(seq(0, 100, 5))      # -> seq(100, 0, -5)
+A5 = rev(seq(15, 5, -0.5))    # negative fractional step -> seq(5, 15, 0.5)
+
+# Total over all sequences, written as a single scalar
+total = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5)
+write(total, $2)
\ No newline at end of file

Reply via email to