This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 082cf89617 [SYSTEMDS-3804] New rewrite for reverse sequences
082cf89617 is described below
commit 082cf89617e116c49e6337666862c1c85e94584d
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Dec 9 09:54:07 2024 +0100
[SYSTEMDS-3804] New rewrite for reverse sequences
This patch adds a new rewrite rev(seq(1,n)) -> seq(n,1), a pattern
we recently saw in a script on vectorized time series forecasting.
---
.../RewriteAlgebraicSimplificationStatic.java | 25 ++++++-
.../rewrite/RewriteRemoveUnnecessaryRevTest.java | 81 ++++++++++++++++++++++
.../rewrite/RewriteRemoveUnnecessaryRev.dml | 31 +++++++++
3 files changed, 136 insertions(+), 1 deletion(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 8053ddc78a..5a79bdee33 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -153,8 +153,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = canonicalizeMatrixMultScalarAdd(hi);
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps)
hi = simplifyCTableWithConstMatrixInputs(hi);
//e.g., table(X, matrix(1,...)) -> table(X, 1)
hi = removeUnnecessaryCTable(hop, hi, i);
//e.g., sum(table(X, 1)) -> nrow(X) and sum(table(1, Y)) -> nrow(Y) and
sum(table(X, Y)) -> nrow(X)
- hi = simplifyConstantConjunction(hop, hi, i);
//e.g., a & !a -> FALSE
+ hi = simplifyConstantConjunction(hop, hi, i);
//e.g., a & !a -> FALSE
hi = simplifyReverseOperation(hop, hi, i);
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
+ hi = simplifyReverseSequence(hop, hi, i);
//e.g., rev(seq(1,n)) -> seq(n,1)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi);
//e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi,
i);//e.g., (X-Y*X) -> (1-Y)*X
@@ -798,6 +799,28 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
+
+	/**
+	 * Rewrites a reverse over a basic sequence into a descending sequence,
+	 * i.e., rev(seq(1,n)) -> seq(n,1,-1), which removes the rev operator.
+	 * The seq operator is modified in place; this is only valid because
+	 * the rev is required to be its only consumer.
+	 *
+	 * @param parent parent of hi
+	 * @param hi     candidate hop, rewritten if it is rev over a sequence
+	 *               accepted by HopRewriteUtils.isBasic1NSequence
+	 * @param pos    position of hi in the inputs of parent
+	 * @return the modified seq hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop simplifyReverseSequence( Hop parent, Hop hi, int pos )
+	{
+		if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
+			&& HopRewriteUtils.isBasic1NSequence(hi.getInput(0))
+			&& hi.getInput(0).getParent().size() == 1) //only consumer
+		{
+			DataGenOp seq = (DataGenOp) hi.getInput(0);
+			int ixFrom = seq.getParamIndex(Statement.SEQ_FROM);
+			int ixTo = seq.getParamIndex(Statement.SEQ_TO);
+			int ixIncr = seq.getParamIndex(Statement.SEQ_INCR);
+			Hop from = seq.getInput(ixFrom);
+			Hop to = seq.getInput(ixTo);
+
+			//swap from and to; both remain inputs of seq, so their
+			//parent references to seq stay valid
+			seq.getInput().set(ixFrom, to);
+			seq.getInput().set(ixTo, from);
+			//replace the increment via replaceChildReference (not a raw list
+			//set) so the old literal is unlinked from seq and the new -1
+			//literal gets seq as parent, keeping the DAG links consistent
+			HopRewriteUtils.replaceChildReference(seq,
+				seq.getInput(ixIncr), new LiteralOp(-1), ixIncr);
+
+			HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
+			HopRewriteUtils.cleanupUnreferenced(hi, seq);
+			hi = seq;
+			LOG.debug("Applied simplifyReverseSequence (line "+hi.getBeginLine()+").");
+		}
+
+		return hi;
+	}
private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi )
{
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java
new file mode 100644
index 0000000000..1dea49a0b8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryRevTest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+
+/**
+ * Tests the static algebraic rewrite rev(seq(1,n)) -> seq(n,1,-1).
+ * Runs the DML script with and without algebraic simplification, checks
+ * that sum(X) equals n*(n+1)/2 in both cases, and, with rewrites enabled,
+ * that no rev operator shows up among the heavy hitters.
+ */
+public class RewriteRemoveUnnecessaryRevTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteRemoveUnnecessaryRev";
+
+	private static final String TEST_DIR = "functions/rewrite/";
+	//derived from this test class (not RewritePushdownSumBinaryMult) so the
+	//two tests do not share the same TEST_CLASS_DIR
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteRemoveUnnecessaryRevTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testRemoveSeqRevRewrite() {
+		testRewriteRemoveSeqRev( TEST_NAME1, true );
+	}
+
+	@Test
+	public void testRemoveSeqRevNoRewrite() {
+		testRewriteRemoveSeqRev( TEST_NAME1, false );
+	}
+
+	/**
+	 * Runs the given DML test script and validates its scalar output.
+	 *
+	 * @param testname name of the test configuration and DML script
+	 * @param rewrites whether algebraic simplification rewrites are enabled
+	 */
+	private void testRewriteRemoveSeqRev( String testname, boolean rewrites )
+	{
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		int rows = 1001;
+
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			//-stats is needed for the heavy-hitter check below
+			programArgs = new String[]{ "-stats","-args", String.valueOf(rows), output("Scalar") };
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			runTest(true, false, null, -1);
+
+			//compare scalars: sum(1..rows) = rows*(rows+1)/2
+			int ret = (int)readDMLScalarFromOutputDir("Scalar").get(new CellIndex(1,1)).doubleValue();
+			Assert.assertEquals(rows*(rows+1)/2, ret); //(expected, actual)
+			//with rewrites, rev(seq(1,n)) must have been replaced by a seq
+			if( rewrites )
+				Assert.assertFalse(heavyHittersContainsString("rev"));
+		}
+		finally {
+			//restore the global optimizer flag for subsequent tests
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml
new file mode 100644
index 0000000000..484e6833c4
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryRev.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# $1: number of rows n of the sequence; $2: output file of the scalar result
+rows = $1;
+
+# to be rewritten to: seq(rows,1,-1), i.e., without rev
+X = rev(seq(1,rows))
+
+# NOTE(review): empty loop presumably serves as a statement-block cut so
+# that X is compiled in its own block before the aggregation — confirm
+while(FALSE){}
+
+# sum of 1..rows, i.e., rows*(rows+1)/2 with or without the rewrite
+R = sum(X);
+write(R, $2)
+