This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 082cf89617 [SYSTEMDS-3804] New rewrite for reverse sequences
082cf89617 is described below

commit 082cf89617e116c49e6337666862c1c85e94584d
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Dec 9 09:54:07 2024 +0100

    [SYSTEMDS-3804] New rewrite for reverse sequences
    
    This patch adds a new rewrite rev(seq(1,n)) -> seq(n,1), a pattern
    we recently saw in a script on vectorized time series forecasting.
---
 .../RewriteAlgebraicSimplificationStatic.java      | 25 ++++++-
 .../rewrite/RewriteRemoveUnnecessaryRevTest.java   | 81 ++++++++++++++++++++++
 .../rewrite/RewriteRemoveUnnecessaryRev.dml        | 31 +++++++++
 3 files changed, 136 insertions(+), 1 deletion(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 8053ddc78a..5a79bdee33 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -153,8 +153,9 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = canonicalizeMatrixMultScalarAdd(hi);            
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) 
                        hi = simplifyCTableWithConstMatrixInputs(hi);        
//e.g., table(X, matrix(1,...)) -> table(X, 1)
                        hi = removeUnnecessaryCTable(hop, hi, i);            
//e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and 
sum(table(X, Y)) -> nrow(X)
-                       hi = simplifyConstantConjunction(hop, hi, i);       
//e.g., a & !a -> FALSE 
+                       hi = simplifyConstantConjunction(hop, hi, i);        
//e.g., a & !a -> FALSE 
                        hi = simplifyReverseOperation(hop, hi, i);           
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
+                       hi = simplifyReverseSequence(hop, hi, i);            
//e.g., rev(seq(1,n)) -> seq(n,1)
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
                                hi = simplifyMultiBinaryToBinaryOperation(hi);  
     //e.g., 1-X*Y -> X 1-* Y
                        hi = simplifyDistributiveBinaryOperation(hop, hi, 
i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -798,6 +799,28 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
 
                return hi;
        }
+       
+       /**
+        * Simplifies rev(seq(1,n)) -> seq(n,1,-1) by swapping the from/to inputs
+        * of the sequence in place and replacing its increment with -1.
+        * 
+        * Because the sequence is modified in place, the rewrite only applies if
+        * the rev is its only consumer. Furthermore, n must be of integer value
+        * type, because for fractional n the reversed sequence differs, e.g.,
+        * rev(seq(1,3.5)) = (3,2,1) but seq(3.5,1,-1) = (3.5,2.5,1.5).
+        * 
+        * @param parent consumer of hi, whose input at position pos is replaced
+        * @param hi candidate rev operation
+        * @param pos input position of hi within parent
+        * @return the modified sequence if the rewrite applied, otherwise hi
+        */
+       private static Hop simplifyReverseSequence( Hop parent, Hop hi, int pos )
+       {
+               if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+                       && HopRewriteUtils.isBasic1NSequence(hi.getInput(0))
+                       && hi.getInput(0).getParent().size() == 1) //only consumer
+               {
+                       DataGenOp seq = (DataGenOp) hi.getInput(0);
+                       int ixFrom = seq.getParamIndex(Statement.SEQ_FROM);
+                       int ixTo = seq.getParamIndex(Statement.SEQ_TO);
+                       int ixIncr = seq.getParamIndex(Statement.SEQ_INCR);
+                       Hop from = seq.getInput(ixFrom);
+                       Hop to = seq.getInput(ixTo);
+                       
+                       //reject non-integer upper bounds, for which
+                       //rev(seq(1,n)) = seq(floor(n),1,-1) != seq(n,1,-1)
+                       if( to.getValueType() != org.apache.sysds.common.Types.ValueType.INT64 )
+                               return hi;
+                       
+                       //swap from and to (both remain inputs of seq, so the
+                       //parent lists of these two hops stay valid)
+                       seq.getInput().set(ixFrom, to);
+                       seq.getInput().set(ixTo, from);
+                       //replace the increment via replaceChildReference instead of
+                       //a raw list update, which would leave the new literal without
+                       //seq as parent and the old increment with a stale parent
+                       HopRewriteUtils.replaceChildReference(seq,
+                               seq.getInput(ixIncr), new LiteralOp(-1), ixIncr);
+                       
+                       HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
+                       HopRewriteUtils.cleanupUnreferenced(hi, seq);
+                       hi = seq;
+                       LOG.debug("Applied simplifyReverseSequence (line "+hi.getBeginLine()+").");
+               }
+
+               return hi;
+       }
 
        private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
        {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java
new file mode 100644
index 0000000000..1dea49a0b8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+/**
+ * Tests the static algebraic rewrite rev(seq(1,n)) -> seq(n,1). The test
+ * compares the result of sum(rev(seq(1,n))) against n*(n+1)/2 with and
+ * without algebraic simplification rewrites. It also checks that a rev
+ * instruction is executed only when the rewrites are disabled.
+ */
+public class RewriteRemoveUnnecessaryRevTest extends AutomatedTestBase 
+{
+       private static final String TEST_NAME1 = "RewriteRemoveUnnecessaryRev";
+
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RewriteRemoveUnnecessaryRevTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+       }
+       
+       @Test
+       public void testRemoveSeqRevRewrite() {
+               testRewriteRemoveSeqRev( TEST_NAME1, true );
+       }
+       
+       @Test
+       public void testRemoveSeqRevNoRewrite() {
+               testRewriteRemoveSeqRev( TEST_NAME1, false );
+       }
+
+       /**
+        * Runs the given test script with algebraic simplification rewrites
+        * enabled or disabled. It then checks the scalar result and whether a
+        * rev instruction was executed.
+        * 
+        * @param testname name of the test configuration and DML script
+        * @param rewrites whether algebraic simplification rewrites are enabled
+        */
+       private void testRewriteRemoveSeqRev( String testname, boolean rewrites )
+       {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               int rows = 1001;
+               
+               try
+               {
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats","-args", String.valueOf(rows), output("Scalar") };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+                       runTest(true, false, null, -1); 
+                       
+                       //compare scalars (sum of 1..rows)
+                       int ret = (int)readDMLScalarFromOutputDir("Scalar").get(new CellIndex(1,1)).doubleValue();
+                       Assert.assertEquals(rows*(rows+1)/2, ret);
+                       
+                       //with rewrites the rev must be removed; without rewrites it must
+                       //be executed, which guards against a vacuous absence check
+                       if( rewrites )
+                               Assert.assertFalse(heavyHittersContainsString("rev"));
+                       else
+                               Assert.assertTrue(heavyHittersContainsString("rev"));
+               }
+               finally {
+                       //restore the global optimizer flag for subsequent tests
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml 
b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml
new file mode 100644
index 0000000000..484e6833c4
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# length of the generated sequence (first script argument)
+rows = $1;
+
+# to be rewritten to: seq(rows,1)
+X = rev(seq(1,rows))
+
+# NOTE(review): the empty loop presumably introduces a statement-block
+# boundary, so that the rev/seq above and the sum below end up in
+# different DAGs and are not simplified together -- confirm intent
+while(FALSE){}
+
+# sum of 1..rows, i.e., rows*(rows+1)/2, written to the second argument
+R = sum(X);
+write(R, $2)
+

Reply via email to