This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 147519e495 [SYSTEMDS-3664] New simplification rewrite rev(seq()) 147519e495 is described below commit 147519e49558493e58e14f359befd56fcf74ffda Author: aarna <aarnatya...@gmail.com> AuthorDate: Sun Mar 16 18:08:49 2025 +0100 [SYSTEMDS-3664] New simplification rewrite rev(seq()) This patch introduces a new simplification rewrite for reversing a sequence rev(seq(1,n)) --> seq(n,1). Closes #2242. --- .../RewriteAlgebraicSimplificationStatic.java | 54 ++++++++++ .../RewriteSimplifyReverseSequenceStepTest.java | 109 +++++++++++++++++++++ .../rewrite/RewriteSimplifyReverseSequenceStep.dml | 35 +++++++ 3 files changed, 198 insertions(+) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 5d867bf0ff..c46bc62400 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -156,6 +156,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = simplifyConstantConjunction(hop, hi, i); //e.g., a & !a -> FALSE hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) hi = simplifyReverseSequence(hop, hi, i); //e.g., rev(seq(1,n)) -> seq(n,1) + hi = simplifyReverseSequenceStep(hop, hi, i); //e.g., rev(seq(1,9,2)) -> seq(9,1,-2), literal from/to/incr only if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y hi = simplifyDistributiveBinaryOperation(hop, hi, i);//e.g., (X-Y*X) -> (1-Y)*X @@ -824,6 +825,59 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + // Rewrites rev(seq(from,to,incr)) with literal from/to/incr into seq(last,from,-incr), where last is the final element of the original sequence; the seq hop is modified in place. + private static Hop simplifyReverseSequenceStep(Hop parent, Hop hi, int pos) { + if (HopRewriteUtils.isReorg(hi, ReOrgOp.REV) + && hi.getInput(0)
instanceof DataGenOp + && ((DataGenOp) hi.getInput(0)).getOp() == OpOpDG.SEQ + && hi.getInput(0).getParent().size() == 1) // only one consumer, since the seq hop is modified in place below + { + DataGenOp seq = (DataGenOp) hi.getInput(0); + Hop from = seq.getInput().get(seq.getParamIndex(Statement.SEQ_FROM)); + Hop to = seq.getInput().get(seq.getParamIndex(Statement.SEQ_TO)); + Hop incr = seq.getInput().get(seq.getParamIndex(Statement.SEQ_INCR)); + + if (from instanceof LiteralOp && to instanceof LiteralOp && incr instanceof LiteralOp) { + double fromVal = ((LiteralOp) from).getDoubleValue(); + double toVal = ((LiteralOp) to).getDoubleValue(); + double incrVal = ((LiteralOp) incr).getDoubleValue(); + + // Skip if increment is (numerically) zero, which does not define a valid sequence + if (Math.abs(incrVal) < 1e-10) + return hi; + + boolean isValidDirection = false; + + // Only rewrite if incr points from 'from' towards 'to' + if ((incrVal > 0 && fromVal <= toVal) || (incrVal < 0 && fromVal >= toVal)) { + isValidDirection = true; + } + + if (isValidDirection) { + // Calculate the number of elements and the last element. NOTE(review): the (int) cast saturates beyond Integer.MAX_VALUE elements (and the +1 then overflows), so a long would be safer; floor() of the floating-point ratio can also undercount for non-integer steps (e.g., (0.3-0)/0.1 evaluates to 2.9999999999999996) - confirm this matches the runtime's seq length computation + int numValues = (int)Math.floor(Math.abs((toVal - fromVal) / incrVal)) + 1; + double lastVal = fromVal + (numValues - 1) * incrVal; + + // New parameters: start at the actual last element (not at 'to'), end at 'from', negated increment + LiteralOp newFrom = new LiteralOp(lastVal); + LiteralOp newTo = new LiteralOp(fromVal); + LiteralOp newIncr = new LiteralOp(-incrVal); + + // Replace the parameters in place. NOTE(review): a raw List.set presumably bypasses the hop parent bookkeeping (old literals keep seq as a parent, new literals get none); HopRewriteUtils.replaceChildReference(seq, old, new, idx) would keep the DAG consistent - confirm + seq.getInput().set(seq.getParamIndex(Statement.SEQ_FROM), newFrom); + seq.getInput().set(seq.getParamIndex(Statement.SEQ_TO), newTo); + seq.getInput().set(seq.getParamIndex(Statement.SEQ_INCR), newIncr); + + // Rewire the consumer from the rev hop to the (now reversed) seq hop, then clean up hops left unreferenced + HopRewriteUtils.replaceChildReference(parent, hi, seq, pos); + HopRewriteUtils.cleanupUnreferenced(hi, seq); + hi = seq; + LOG.debug("Applied simplifyReverseSequenceStep (line " + hi.getBeginLine() + ")."); + } + } + } + return hi; + } private static Hop simplifyMultiBinaryToBinaryOperation( Hop hi ) { diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java new file mode 100644 index 0000000000..7d176c7aa8 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseSequenceStepTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + // Runs the rev(seq(from,to,incr)) test script with algebraic simplification enabled and disabled and checks the resulting sum. +public class RewriteSimplifyReverseSequenceStepTest extends AutomatedTestBase { + private static final String TEST_NAME1 = "RewriteSimplifyReverseSequenceStep"; + + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifyReverseSequenceStepTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); + } + + @Test + public void testRewriteReverseSeqStep() { + testRewriteReverseSeq(TEST_NAME1, true); + } + + @Test + public void testNoRewriteReverseSeqStep() { + testRewriteReverseSeq(TEST_NAME1, false); + } + + private void testRewriteReverseSeq(String testname, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + int rows = 10; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-stats", "-args", String.valueOf(rows), output("Scalar")}; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + + // Calculate expected sums for each sequence + double sum1 = calculateSum(0, rows-1, 1); // A1 = rev(seq(0, rows-1, 1)) + double sum2 = calculateSum(0, rows, 2); // A2 = rev(seq(0, rows, 2)) + double sum3 = calculateSum(2, rows, 2); // A3 = rev(seq(2, rows, 2)) + double sum4 = calculateSum(0, 100, 5); // A4 = rev(seq(0, 100, 5)) + double sum5 = 
calculateSum(15, 5, -0.5); // A5 = rev(seq(15, 5, -0.5)) + // NOTE(review): a sum is invariant under reversal, so this check cannot detect a rewrite that emits the elements in the wrong order (e.g., one that simply drops the rev); comparing the sequences element-wise would be stronger + double expected = sum1 + sum2 + sum3 + sum4 + sum5; + + double ret = readDMLScalarFromOutputDir("Scalar").get(new MatrixValue.CellIndex(1, 1)).doubleValue(); + + Assert.assertEquals("Incorrect sum computed", expected, ret, 1e-10); + + if (rewrites) { + // With rewrites enabled, every rev(seq(...)) with literal arguments should be rewritten into a seq, so no rev operator may appear in the -stats heavy hitters + Assert.assertFalse("Rewrite should have removed REV operation!", + heavyHittersContainsString("rev")); + } + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + + // Computes the expected sum of seq(from, to, incr) in closed form (arithmetic series); returns 0 if incr does not point from 'from' towards 'to' + private double calculateSum(double from, double to, double incr) { + double sum = 0; + int n = 0; + + if ((incr > 0 && from <= to) || (incr < 0 && from >= to)) { + // Calculate number of elements in the sequence + n = (int)Math.floor(Math.abs((to - from) / incr)) + 1; + + // Calculate the last element in the sequence + double last = from + (n - 1) * incr; + + // Use arithmetic sequence sum formula: n * (first + last) / 2 + sum = n * (from + last) / 2; + } + + return sum; + } +} diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml new file mode 100644 index 0000000000..e8f3314c26 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseSequenceStep.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +rows = as.integer($1) + +# Test sequences: A1-A4 use positive increments, A5 a negative one +A1 = rev(seq(0, rows-1, 1)) # Should become seq(rows-1, 0, -1) +A2 = rev(seq(0, rows, 2)) # Should become seq(lastVal, 0, -2), where lastVal = rows for even rows (e.g., rows=10) +A3 = rev(seq(2, rows, 2)) # Should become seq(lastVal, 2, -2), where lastVal is the last value in the sequence (rows for even rows) +A4 = rev(seq(0, 100, 5)) # Should become seq(100, 0, -5) +A5 = rev(seq(15, 5, -0.5)) # Should become seq(5, 15, 0.5) + +# Sum all sequences (sums are order-invariant, so this does not verify element order) +R = sum(A1) + sum(A2) + sum(A3) + sum(A4) + sum(A5) + +# Output +write(R, $2) \ No newline at end of file