This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new cfbe19062a [SYSTEMDS-3774] Improved test coverage of simplification rewrites cfbe19062a is described below commit cfbe19062ab706c8acfc5bd688c04e20c3e9cbc5 Author: ReneEnjilian <enjilianr...@gmail.com> AuthorDate: Mon Mar 3 16:43:01 2025 +0100 [SYSTEMDS-3774] Improved test coverage of simplification rewrites Closes #2240. --- .../java/org/apache/sysds/hops/OptimizerUtils.java | 2 +- ...RewriteCanonicalizeMatrixMultScalarAddTest.java | 117 ++++++ .../RewriteFuseOrderOperationChainTest.java | 98 +++++ ...ewriteRemoveUnnecessaryBinaryOperationTest.java | 168 +++++++++ .../rewrite/RewriteRemoveUnnecessaryMinusTest.java | 96 +++++ ...RewriteRemoveUnnecessaryReorgOperationTest.java | 106 ++++++ ...iteRemoveUnnecessaryVectorizeOperationTest.java | 106 ++++++ ...iteSimplifyBinaryMatrixScalarOperationTest.java | 125 +++++++ .../RewriteSimplifyBinaryToUnaryOperationTest.java | 132 +++++++ ...writeSimplifyCumsumColOrFullAggregatesTest.java | 96 +++++ ...teSimplifyMultiBinaryToBinaryOperationTest.java | 100 +++++ .../rewrite/RewriteSimplifyOuterSeqExpandTest.java | 104 ++++++ .../RewriteSimplifyReverseOperationTest.java | 100 +++++ .../RewriteSimplifyTransposedAppendTest.java | 108 ++++++ .../RewriteSimplifyUnaryAggReorgOperationTest.java | 95 +++++ .../RewriteSimplifyUnaryPPredOperationTest.java | 409 +++++++++++++++++++++ .../RewriteCanonicalizeMatrixMultScalarAdd.R | 46 +++ .../RewriteCanonicalizeMatrixMultScalarAdd.dml | 37 ++ .../rewrite/RewriteFuseOrderOperationChain.R | 43 +++ .../rewrite/RewriteFuseOrderOperationChain.dml | 29 ++ .../RewriteRemoveUnnecessaryBinaryOperation.R | 51 +++ .../RewriteRemoveUnnecessaryBinaryOperation.dml | 47 +++ .../rewrite/RewriteRemoveUnnecessaryMinus.R | 39 ++ .../rewrite/RewriteRemoveUnnecessaryMinus.dml | 29 ++ .../RewriteRemoveUnnecessaryReorgOperation.R | 44 +++ .../RewriteRemoveUnnecessaryReorgOperation.dml | 35 ++ .../RewriteRemoveUnnecessaryVectorizeOperation.R 
| 45 +++ .../RewriteRemoveUnnecessaryVectorizeOperation.dml | 37 ++ .../RewriteSimplifyBinaryMatrixScalarOperation.R | 50 +++ .../RewriteSimplifyBinaryMatrixScalarOperation.dml | 42 +++ .../RewriteSimplifyBinaryToUnaryOperation.R | 45 +++ .../RewriteSimplifyBinaryToUnaryOperation.dml | 38 ++ .../RewriteSimplifyCumsumColOrFullAggregates.R | 38 ++ .../RewriteSimplifyCumsumColOrFullAggregates.dml | 29 ++ .../RewriteSimplifyMultiBinaryToBinaryOperation.R | 39 ++ ...RewriteSimplifyMultiBinaryToBinaryOperation.dml | 30 ++ .../rewrite/RewriteSimplifyOuterSeqExpand.R | 45 +++ .../rewrite/RewriteSimplifyOuterSeqExpand.dml | 36 ++ .../rewrite/RewriteSimplifyReverseOperation.R | 38 ++ .../rewrite/RewriteSimplifyReverseOperation.dml | 29 ++ .../rewrite/RewriteSimplifyTransposedAppend.R | 45 +++ .../rewrite/RewriteSimplifyTransposedAppend.dml | 36 ++ .../RewriteSimplifyUnaryAggReorgOperation.R | 38 ++ .../RewriteSimplifyUnaryAggReorgOperation.dml | 29 ++ .../rewrite/RewriteSimplifyUnaryPPredOperation.R | 101 +++++ .../rewrite/RewriteSimplifyUnaryPPredOperation.dml | 127 +++++++ 46 files changed, 3278 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index a3161c5723..30df23dba2 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -199,7 +199,7 @@ public class OptimizerUtils public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true; /** - * Enables additional mmchain optimizations. in the future, this might be merged with + * Enables additional mmchain optimizations. In the future, this might be merged with * ALLOW_SUM_PRODUCT_REWRITES. 
*/ public static boolean ALLOW_ADVANCED_MMCHAIN_REWRITES = false; diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java new file mode 100644 index 0000000000..c9a580d5e3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAddTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteCanonicalizeMatrixMultScalarAddTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteCanonicalizeMatrixMultScalarAdd"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteCanonicalizeMatrixMultScalarAddTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddPosNoRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(1, false); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddPosRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(1, true); // (z + U%*%V) -> (U%*%V + z) + } + + @Test + public void testCanonicalizeMatrixMultScalarAddNegNoRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(2, false); + } + + @Test + public void testCanonicalizeMatrixMultScalarAddNegRewrite() { + testRewriteCanonicalizeMatrixMultScalarAdd(2, true); // (U%*%V - z) -> (U%*%V + (-z)) + } + + private void testRewriteCanonicalizeMatrixMultScalarAdd(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = 
SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("U"), input("V"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] U = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] V = getRandomMatrix(rows, cols, -1, 1, 0.60d, 4); + writeInputMatrixWithMTD("U", U, true); + writeInputMatrixWithMTD("V", V, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(ID == 1) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + else if(ID == 2) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.PLUS.toString())); + else + Assert.assertFalse(heavyHittersContainsString(Opcodes.PLUS.toString())); + } + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java new file mode 100644 index 0000000000..281fc75cd3 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseOrderOperationChainTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteFuseOrderOperationChainTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteFuseOrderOperationChain"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteFuseOrderOperationChainTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testFuseOrderOperationChainNoRewrite() { + testRewriteFuseOrderOperationChain(false); + } + + @Test + public void testFuseOrderOperationChainRewrite() { + testRewriteFuseOrderOperationChain(true); + } + + private void 
testRewriteFuseOrderOperationChain(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + long numOrder = Statistics.getCPHeavyHitterCount(Opcodes.SORT.toString()); + if(rewrites) + Assert.assertEquals(numOrder, 1); + else + Assert.assertEquals(numOrder, 2); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java new file mode 100644 index 0000000000..b505e846aa --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperationTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryBinaryOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryBinaryOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryBinaryOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationDivNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationDivRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(1, true); // X/1 + } + + @Test + 
public void testRemoveUnnecessaryBinaryOperationMultRightNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultRightRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(2, true); // X*1 + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultLeftNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(3, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMultLeftRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(3, true); // 1*X + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMinusNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(4, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationMinusRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(4, true); // X-0 + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultLeftNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(5, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultLeftRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(5, true); // -1*X + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultRightNoRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(6, false); + } + + @Test + public void testRemoveUnnecessaryBinaryOperationNegMultRightRewrite() { + testRewriteRemoveUnnecessaryBinaryOperation(6, true); // X*-1 + } + + private void testRewriteRemoveUnnecessaryBinaryOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), 
expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(ID == 1) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.DIV.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.DIV.toString())); + } + else if(ID == 2 || ID == 3) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MULT.toString())); + } + else if(ID == 4) { + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MINUS.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString())); + } + else if(ID == 5 || ID == 6) { + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS.toString()) && + !heavyHittersContainsString(Opcodes.MULT.toString())); + else + Assert.assertTrue(!heavyHittersContainsString(Opcodes.MINUS.toString()) && + heavyHittersContainsString(Opcodes.MULT.toString())); + } + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java new file mode 100644 index 0000000000..dbd1a31ce4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryMinusTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryMinusTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryMinus"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryMinusTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryMinusNoRewrite() { + testRewriteRemoveUnnecessaryMinus(false); + } + + @Test + public void testRemoveUnnecessaryMinusRewrite() { + 
testRewriteRemoveUnnecessaryMinus(true); + } + + private void testRewriteRemoveUnnecessaryMinus(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.POW.toString())); + else + Assert.assertTrue((heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.POW.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java new file mode 100644 index 0000000000..733c9da7e6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryReorgOperationTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryReorgOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryReorgOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryReorgOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryReorgOperationTransposeNoRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryReorgOperationTransposeRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(1, true); + } + + @Test + 
public void testRemoveUnnecessaryReorgOperationReverseNoRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryReorgOperationReverseRewrite() { + testRewriteRemoveUnnecessaryReorgOperation(2, true); + } + + private void testRewriteRemoveUnnecessaryReorgOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString(), Opcodes.REV.toString())); + else + Assert.assertTrue((heavyHittersContainsString(Opcodes.MULT.toString(), Opcodes.REV.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java new file mode 100644 index 0000000000..d231b0ff93 --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperationTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteRemoveUnnecessaryVectorizeOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteRemoveUnnecessaryVectorizeOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteRemoveUnnecessaryVectorizeOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 100; + private static final int cols = 100; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationLeftNoRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(1, false); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationLeftRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(1, true); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationRightNoRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(2, false); + } + + @Test + public void testRemoveUnnecessaryVectorizeOperationRightRewrite() { + testRewriteRemoveUnnecessaryVectorizeOperation(2, true); + } + + private void testRewriteRemoveUnnecessaryVectorizeOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, 1, 2, 1.00d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.RANDOM.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.RANDOM.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java new file mode 100644 index 0000000000..8772ce923c --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperationTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +
package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

/**
 * Tests the simplification that pushes {@code as.scalar} through element-wise multiplications
 * (see the per-case comments below). Each case compares the DML result against an R reference
 * and then inspects the executed opcodes with and without algebraic simplification.
 */
public class RewriteSimplifyBinaryMatrixScalarOperationTest extends AutomatedTestBase {

	private static final String TEST_NAME = "RewriteSimplifyBinaryMatrixScalarOperation";
	private static final String TEST_DIR = "functions/rewrite/";
	private static final String TEST_CLASS_DIR =
		TEST_DIR + RewriteSimplifyBinaryMatrixScalarOperationTest.class.getSimpleName() + "/";

	private static final double eps = Math.pow(10, -10);

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"});
		addTestConfiguration(TEST_NAME, config);
	}

	@Test
	public void testSimplifyBinaryMatrixScalarOperationMMNoRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(1, false);
	}

	// as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
	@Test
	public void testSimplifyBinaryMatrixScalarOperationMMRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(1, true);
	}

	@Test
	public void testSimplifyBinaryMatrixScalarOperationRightNoRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(2, false);
	}

	// as.scalar(X*s) -> as.scalar(X) * s
	@Test
	public void testSimplifyBinaryMatrixScalarOperationRightRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(2, true);
	}

	@Test
	public void testSimplifyBinaryMatrixScalarOperationLeftNoRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(3, false);
	}

	// as.scalar(s*X) -> s * as.scalar(X)
	@Test
	public void testSimplifyBinaryMatrixScalarOperationLeftRewrite() {
		testRewriteSimplifyBinaryMatrixScalarOperation(3, true);
	}

	/**
	 * Runs the DML script and its R counterpart for one pattern, then checks the executed plan.
	 * The global simplification flag is always restored afterwards.
	 *
	 * @param ID       pattern selector passed to both scripts: 1 = matrix*matrix,
	 *                 2 = matrix*scalar, 3 = scalar*matrix
	 * @param rewrites whether algebraic simplification is enabled during compilation
	 */
	private void testRewriteSimplifyBinaryMatrixScalarOperation(int ID, boolean rewrites) {
		final boolean savedFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
		try {
			loadTestConfiguration(getTestConfiguration(TEST_NAME));

			final String scriptHome = SCRIPT_DIR + TEST_DIR;
			final String patternId = String.valueOf(ID);
			fullDMLScriptName = scriptHome + TEST_NAME + ".dml";
			programArgs = new String[] {"-stats", "-args", patternId, output("R")};
			fullRScriptName = scriptHome + TEST_NAME + ".R";
			rCmd = getRCmd(inputDir(), patternId, expectedDir());

			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

			runTest(true, false, null, -1);
			runRScript(true);

			// both scripts write their result as a scalar named R
			HashMap<MatrixValue.CellIndex, Double> dmlResult = readDMLScalarFromOutputDir("R");
			HashMap<MatrixValue.CellIndex, Double> rResult = readRScalarFromExpectedDir("R");
			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");

			switch(ID) {
				case 1: {
					// the rewrite casts each operand separately, so one cast becomes two
					long castCount = Statistics.getCPHeavyHitterCount(Opcodes.CAST_AS_SCALAR.toString());
					Assert.assertEquals(rewrites ? 2 : 1, castCount);
					break;
				}
				case 2:
					Assert.assertTrue(heavyHittersContainsString(
						rewrites ? Opcodes.MULT.toString() : Opcodes.MULT2.toString()));
					break;
				case 3:
					Assert.assertTrue(heavyHittersContainsString(
						rewrites ? Opcodes.NM.toString() : Opcodes.MULT.toString()));
					break;
				default:
					break;
			}
		}
		finally {
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedFlag;
		}
	}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java new file mode 100644 index 0000000000..f91afbbc46 --- /dev/null +++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBinaryToUnaryOperationTest.java @@ -0,0 +1,132 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

/**
 * Tests rewrites that replace a binary operation with a cheaper unary one:
 * {@code X+X -> X*2}, {@code X*X -> X^2} and {@code (X>0)-(X<0) -> sign(X)}.
 * Each case compares the DML result against R and checks the executed opcodes.
 */
public class RewriteSimplifyBinaryToUnaryOperationTest extends AutomatedTestBase {

	private static final String TEST_NAME = "RewriteSimplifyBinaryToUnaryOperation";
	private static final String TEST_DIR = "functions/rewrite/";
	private static final String TEST_CLASS_DIR =
		TEST_DIR + RewriteSimplifyBinaryToUnaryOperationTest.class.getSimpleName() + "/";

	private static final int rows = 500;
	private static final int cols = 500;
	private static final double eps = Math.pow(10, -10);

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"});
		addTestConfiguration(TEST_NAME, config);
	}

	@Test
	public void testSimplifyBinaryToUnaryOperationAddNoRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(1, false);
	}

	// X+X -> X*2
	@Test
	public void testSimplifyBinaryToUnaryOperationAddRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(1, true);
	}

	@Test
	public void testSimplifyBinaryToUnaryOperationMultNoRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(2, false);
	}

	// X*X -> X^2
	@Test
	public void testSimplifyBinaryToUnaryOperationMultRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(2, true);
	}

	@Test
	public void testSimplifyBinaryToUnaryOperationSignNoRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(3, false);
	}

	// (X>0)-(X<0) -> sign(X)
	@Test
	public void testSimplifyBinaryToUnaryOperationSignRewrite() {
		testRewriteSimplifyBinaryToUnaryOperation(3, true);
	}

	/**
	 * Writes a random input X, runs the DML and R scripts for one pattern, and checks the plan.
	 * The global simplification flag is always restored afterwards.
	 *
	 * @param ID       pattern selector passed to both scripts: 1 = add, 2 = mult, 3 = sign
	 * @param rewrites whether algebraic simplification is enabled during compilation
	 */
	private void testRewriteSimplifyBinaryToUnaryOperation(int ID, boolean rewrites) {
		final boolean savedFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
		try {
			loadTestConfiguration(getTestConfiguration(TEST_NAME));

			final String scriptHome = SCRIPT_DIR + TEST_DIR;
			final String patternId = String.valueOf(ID);
			fullDMLScriptName = scriptHome + TEST_NAME + ".dml";
			programArgs = new String[] {"-stats", "-args", input("X"), patternId, output("R")};
			fullRScriptName = scriptHome + TEST_NAME + ".R";
			rCmd = getRCmd(inputDir(), patternId, expectedDir());

			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

			// random values in [-1,1], sparsity 0.7, fixed seed 5
			double[][] inputX = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5);
			writeInputMatrixWithMTD("X", inputX, true);

			runTest(true, false, null, -1);
			runRScript(true);

			HashMap<MatrixValue.CellIndex, Double> dmlResult = readDMLMatrixFromOutputDir("R");
			HashMap<MatrixValue.CellIndex, Double> rResult = readRMatrixFromExpectedDir("R");
			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");

			switch(ID) {
				case 1:
					Assert.assertTrue(heavyHittersContainsString(
						rewrites ? Opcodes.MULT2.toString() : Opcodes.PLUS.toString()));
					break;
				case 2:
					Assert.assertTrue(heavyHittersContainsString(
						rewrites ? Opcodes.POW2.toString() : Opcodes.MULT.toString()));
					break;
				case 3:
					if(rewrites)
						Assert.assertTrue(heavyHittersContainsString(Opcodes.SIGN.toString()));
					else
						Assert.assertTrue(heavyHittersContainsAllString(Opcodes.GREATER.toString(),
							Opcodes.LESS.toString(), Opcodes.MINUS.toString()));
					break;
				default:
					break;
			}
		}
		finally {
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedFlag;
		}
	}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java new file mode 100644 index 0000000000..c77db1c148 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregatesTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyCumsumColOrFullAggregatesTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyCumsumColOrFullAggregates"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyCumsumColOrFullAggregatesTest.class.getSimpleName() + "/"; + + private static final int rows = 10; + private static final int cols = 10; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyCumsumColOrFullAggregatesNoRewrite() { + testRewriteSimplifyCumsumColOrFullAggregates(false); + } + + @Test + public void testSimplifyCumsumColOrFullAggregatesRewrite() { + testRewriteSimplifyCumsumColOrFullAggregates(true); //colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1)) + } + + private void testRewriteSimplifyCumsumColOrFullAggregates(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue((heavyHittersContainsString(Opcodes.SEQUENCE.toString()))); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.UCUMKP.toString(), Opcodes.UACKP.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java new file mode 100644 index 0000000000..395fcf3d3a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyMultiBinaryToBinaryOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyMultiBinaryToBinaryOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyMultiBinaryToBinaryOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperationNoRewrite() { + testRewriteSimplifyMultiBinaryToBinaryOperation(false); + } + + @Test + public void testSimplifyMultiBinaryToBinaryOperationRewrite() { + testRewriteSimplifyMultiBinaryToBinaryOperation(true); + } + + private void testRewriteSimplifyMultiBinaryToBinaryOperation(boolean rewrites) { + boolean oldFlag1 = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + boolean oldFlag2 = OptimizerUtils.ALLOW_OPERATOR_FUSION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 4); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.MINUS1_MULT.toString())); + else + Assert.assertTrue(heavyHittersContainsAllString(Opcodes.MINUS.toString(), Opcodes.MULT.toString())); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag1; + OptimizerUtils.ALLOW_OPERATOR_FUSION = oldFlag2; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java new file mode 100644 index 0000000000..a9c3e4ffe7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyOuterSeqExpandTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyOuterSeqExpandTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyOuterSeqExpand"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyOuterSeqExpandTest.class.getSimpleName() + "/"; + + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyOuterSeqExpandRightNoRewrite() { + testRewriteSimplifyOuterSeqExpand(1, false); + } + + // outer(v, t(seq(1,m)), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) + @Test + public void testSimplifyOuterSeqExpandRightRewrite() { + testRewriteSimplifyOuterSeqExpand(1, true); + } + + @Test + public void testSimplifyOuterSeqExpandLeftNoRewrite() { + testRewriteSimplifyOuterSeqExpand(2, false); + } + + // outer(seq(1,m), t(v), "==") -> rexpand(m, max=v, dir=row, ignore=true, cast=false) + @Test + public void 
testSimplifyOuterSeqExpandLeftRewrite() { + testRewriteSimplifyOuterSeqExpand(2, true); + } + + private void testRewriteSimplifyOuterSeqExpand(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.REXPAND.toString())); + else + Assert.assertTrue( + (heavyHittersContainsString(Opcodes.SEQUENCE.toString(), Opcodes.TRANSPOSE.toString()))); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java new file mode 100644 index 0000000000..ed2f53bc51 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyReverseOperationTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyReverseOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyReverseOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyReverseOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyReverseOperationNoRewrite() { + testRewriteSimplifyReverseOperation(false); + } + + @Test + public void testSimplifyReverseOperationRewrite() { + testRewriteSimplifyReverseOperation(true); + } + + private void testRewriteSimplifyReverseOperation(boolean rewrites) { + boolean 
oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertTrue(heavyHittersContainsString(Opcodes.REV.toString()) && + !heavyHittersContainsAllString(Opcodes.MMULT.toString(), Opcodes.SEQUENCE.toString(), + Opcodes.CTABLEEXPAND.toString())); + else + Assert.assertTrue(heavyHittersContainsAllString(Opcodes.MMULT.toString(), Opcodes.SEQUENCE.toString(), + Opcodes.CTABLEEXPAND.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java new file mode 100644 index 0000000000..2c6721bcec --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTransposedAppendTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +
package org.apache.sysds.test.functions.rewrite;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

/**
 * Tests the removal of transposes around appends:
 * {@code t(cbind(t(A),t(B))) -> rbind(A,B)} and {@code t(rbind(t(A),t(B))) -> cbind(A,B)}.
 * With rewrites enabled, no transpose operator may remain in the executed plan.
 */
public class RewriteSimplifyTransposedAppendTest extends AutomatedTestBase {

	private static final String TEST_NAME = "RewriteSimplifyTransposedAppend";
	private static final String TEST_DIR = "functions/rewrite/";
	private static final String TEST_CLASS_DIR =
		TEST_DIR + RewriteSimplifyTransposedAppendTest.class.getSimpleName() + "/";

	private static final int rows = 500;
	private static final int cols = 500;
	private static final double eps = Math.pow(10, -10);

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"});
		addTestConfiguration(TEST_NAME, config);
	}

	@Test
	public void testSimplifyTransposedAppendTransposeCBindNoRewrite() {
		testRewriteSimplifyTransposedAppend(1, false);
	}

	// t(cbind(t(A),t(B))) --> rbind(A,B)
	@Test
	public void testSimplifyTransposedAppendTransposeCBindRewrite() {
		testRewriteSimplifyTransposedAppend(1, true);
	}

	@Test
	public void testSimplifyTransposedAppendTransposeRBindNoRewrite() {
		testRewriteSimplifyTransposedAppend(2, false);
	}

	// t(rbind(t(A),t(B))) --> cbind(A,B)
	@Test
	public void testSimplifyTransposedAppendTransposeRBindRewrite() {
		testRewriteSimplifyTransposedAppend(2, true);
	}

	/**
	 * Writes random inputs A and B, runs the DML and R scripts for one pattern, and checks
	 * whether a transpose was executed. The global simplification flag is always restored.
	 *
	 * @param ID       pattern selector passed to both scripts: 1 = cbind case, 2 = rbind case
	 * @param rewrites whether algebraic simplification is enabled during compilation
	 */
	private void testRewriteSimplifyTransposedAppend(int ID, boolean rewrites) {
		final boolean savedFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
		try {
			loadTestConfiguration(getTestConfiguration(TEST_NAME));

			final String scriptHome = SCRIPT_DIR + TEST_DIR;
			final String patternId = String.valueOf(ID);
			fullDMLScriptName = scriptHome + TEST_NAME + ".dml";
			programArgs = new String[] {"-stats", "-args", input("A"), input("B"), patternId, output("R")};
			fullRScriptName = scriptHome + TEST_NAME + ".R";
			rCmd = getRCmd(inputDir(), patternId, expectedDir());

			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

			double[][] inputA = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5);
			double[][] inputB = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3);
			writeInputMatrixWithMTD("A", inputA, true);
			writeInputMatrixWithMTD("B", inputB, true);

			runTest(true, false, null, -1);
			runRScript(true);

			HashMap<MatrixValue.CellIndex, Double> dmlResult = readDMLMatrixFromOutputDir("R");
			HashMap<MatrixValue.CellIndex, Double> rResult = readRMatrixFromExpectedDir("R");
			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");

			final boolean transposeExecuted = heavyHittersContainsString(Opcodes.TRANSPOSE.toString());
			if(rewrites)
				Assert.assertFalse(transposeExecuted);
			else
				Assert.assertTrue(transposeExecuted);
		}
		finally {
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = savedFlag;
		}
	}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java new file mode 100644 index
0000000000..8c58a21a15 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryAggReorgOperationTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyUnaryAggReorgOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyUnaryAggReorgOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyUnaryAggReorgOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + 
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"})); + } + + @Test + public void testSimplifyUnaryAggReorgOperationNoRewrite() { + testRewriteSimplifyUnaryAggReorgOperation(false); + } + + @Test + public void testSimplifyUnaryAggReorgOperationRewrite() { + testRewriteSimplifyUnaryAggReorgOperation(true); // sum(t(X)) -> sum(X) + } + + private void testRewriteSimplifyUnaryAggReorgOperation(boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrix + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRScalarFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.TRANSPOSE.toString())); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java new file mode 100644 index 0000000000..940ec77f8c --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnaryPPredOperationTest.java @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class RewriteSimplifyUnaryPPredOperationTest extends AutomatedTestBase { + + private static final String TEST_NAME = "RewriteSimplifyUnaryPPredOperation"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = + TEST_DIR + RewriteSimplifyUnaryPPredOperationTest.class.getSimpleName() + "/"; + + private static final int rows = 500; + private static final int cols = 500; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME, new String[] {"R"})); + } + + /** + * (1) Rewrites for Less + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationLessAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(1, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(1, true); // abs(X<Y) -> (X<Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(2, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(2, true); // round(X<Y) -> (X<Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(3, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(3, true); // ceil(X<Y) -> (X<Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(4, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(4, true); // floor(X<Y) -> (X<Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(5, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(5, true); // sign(X<Y) -> (X<Y) + } + + /** + * (2) Rewrites for LessEqual + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationLessEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(6, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(6, true); // abs(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualRoundNoRewrite() { + 
testRewriteSimplifyUnaryPPredOperation(7, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(7, true); // round(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(8, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(8, true); // ceil(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(9, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(9, true); // floor(X<=Y) -> (X<=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(10, false); + } + + @Test + public void testSimplifyUnaryPPredOperationLessEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(10, true); // sign(X<=Y) -> (X<=Y) + } + + /** + * (3) Rewrites for Greater + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationGreaterAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(11, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(11, true); // abs(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(12, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(12, true); // round(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(13, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterCeilRewrite() { + 
testRewriteSimplifyUnaryPPredOperation(13, true); // ceil(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(14, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(14, true); // floor(X>Y) -> (X>Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(15, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(15, true); // sign(X>Y) -> (X>Y) + } + + /** + * (4) Rewrites for GreaterEqual + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(16, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(16, true); // abs(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(17, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(17, true); // round(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(18, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(18, true); // ceil(X>=Y) -> (X>=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(19, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(19, true); // floor(X>=Y) -> (X>=Y) + } + + @Test + public void 
testSimplifyUnaryPPredOperationGreaterEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(20, false); + } + + @Test + public void testSimplifyUnaryPPredOperationGreaterEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(20, true); // sign(X>=Y) -> (X>=Y) + } + + /** + * (5) Rewrites for Equal + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationEqualAbsNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(21, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(21, true); // abs(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(22, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(22, true); // round(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(23, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(23, true); // ceil(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(24, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(24, true); // floor(X==Y) -> (X==Y) + } + + @Test + public void testSimplifyUnaryPPredOperationEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(25, false); + } + + @Test + public void testSimplifyUnaryPPredOperationEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(25, true); // sign(X==Y) -> (X==Y) + } + + /** + * (6) Rewrites for NotEqual + {abs, round, ceil, floor, sign} + */ + @Test + public void testSimplifyUnaryPPredOperationNotEqualAbsNoRewrite() { + 
testRewriteSimplifyUnaryPPredOperation(26, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualAbsRewrite() { + testRewriteSimplifyUnaryPPredOperation(26, true); // abs(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualRoundNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(27, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualRoundRewrite() { + testRewriteSimplifyUnaryPPredOperation(27, true); // round(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualCeilNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(28, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualCeilRewrite() { + testRewriteSimplifyUnaryPPredOperation(28, true); // ceil(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualFloorNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(29, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualFloorRewrite() { + testRewriteSimplifyUnaryPPredOperation(29, true); // floor(X!=Y) -> (X!=Y) + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualSignNoRewrite() { + testRewriteSimplifyUnaryPPredOperation(30, false); + } + + @Test + public void testSimplifyUnaryPPredOperationNotEqualSignRewrite() { + testRewriteSimplifyUnaryPPredOperation(30, true); // sign(X!=Y) -> (X!=Y) + } + + private void testRewriteSimplifyUnaryPPredOperation(int ID, boolean rewrites) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(TEST_NAME); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "-args", input("X"), input("Y"), String.valueOf(ID), output("R")}; + fullRScriptName = HOME + TEST_NAME + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(ID), expectedDir()); + + 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + // create and write matrices + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.70d, 5); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.60d, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<MatrixValue.CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("R"); + HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromExpectedDir("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + if(rewrites) + Assert.assertFalse(heavyHittersContainsString(Opcodes.ABS.toString(), Opcodes.ROUND.toString(), + Opcodes.CEIL.toString(), Opcodes.FLOOR.toString(), Opcodes.SIGN.toString())); + else + Assert.assertTrue(heavyHittersContainsString(Opcodes.ABS.toString(), Opcodes.ROUND.toString(), + Opcodes.CEIL.toString(), Opcodes.FLOOR.toString(), Opcodes.SIGN.toString())); + + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + +} diff --git a/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R new file mode 100644 index 0000000000..447f2ef747 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.R @@ -0,0 +1,46 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrices U, V from Matrix Market format files +U = as.matrix(readMM(paste(args[1], "U.mtx", sep=""))) +V = as.matrix(readMM(paste(args[1], "V.mtx", sep=""))) +type = as.integer(args[2]) +eps = 0.5 + +# Perform the operations +if( type == 1 ) { + R = (eps + U%*%V) +} else if( type == 2 ) { + R = (U%*%V - eps) +} + + +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")) diff --git a/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml new file mode 100644 index 0000000000..fd8a1db7de --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteCanonicalizeMatrixMultScalarAdd.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read matrices U, V, and operation type +U = read($1) +V = read($2) +type = $3 +eps = 0.5 + +# Perform operations +if(type==1){ + R = (eps + U%*%V)*1 # (eps + U%*%V) -> (U%*%V + eps) +} +else if(type==2){ + R = (U%*%V - eps) # (U%*%V - eps) -> (U%*%V + (-eps)) +} + +# Write the result matrix R +write(R, $4) diff --git a/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R new file mode 100644 index 0000000000..b8c4e42986 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.R @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + +# Perform operation +# 1) Sort X by column 2 +temp = X[order(X[, 2]), ] + +# 2) Sort the result by column 1 +R = temp[order(temp[, 1]), ] + + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml new file mode 100644 index 0000000000..746b3a796a --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteFuseOrderOperationChain.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Read matrix +X = read($1) + +# Perform operation +R = order(target=order(target=X, by=2), by=1) + +# Write the result matrix R +write(R, $2) diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R new file mode 100644 index 0000000000..79ffbe6072 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.R @@ -0,0 +1,51 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix and operation type +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +type = as.integer(args[2]) + +# Perform operations +if(type==1){ + R = X/1 +} else if(type==2){ + R = X*1 +} else if(type==3){ + R = 1*X +} else if(type==4){ + R = X-0 +} else if(type==5){ + R = -1*X +} else if(type==6){ + R = X * -1 +} + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml new file mode 100644 index 0000000000..6ffe4e543a --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryBinaryOperation.dml @@ -0,0 +1,47 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Read matrix X and operation type +X = read($1) +type = $2 + +# Perform operations +if(type==1){ + R = X/1 +} +else if(type==2){ + R = X*1 +} +else if(type==3){ + R = 1*X +} +else if(type==4){ + R = X-0 +} +else if(type==5){ + R = -1*X +} +else if(type==6){ + R = X * -1 +} + +# Write the result matrix R +write(R, $3) diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R new file mode 100644 index 0000000000..5c4c14ab13 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.R @@ -0,0 +1,39 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) + +# Set options for numeric precision +options(digits=22) + +# Load required libraries +library("Matrix") +library("matrixStats") + +# Read matrix +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) + + +# Perform operation +R = -(-X) + +#Write result matrix R +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml new file mode 100644 index 0000000000..fa2262529e --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryMinus.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# Read matrix +X = read($1) + +# Perform operation +R = -(-X) # -(-X) -> X + +# Write the result matrix R +write(R, $2) diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R new file mode 100644 index 0000000000..0fec789ada --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.R @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Read matrix and operation type
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+type = as.integer(args[2])
+
+
+# Reference results: both reorg chains cancel out and must equal X
+if(type==1){
+  R = t(t(X))
+} else if(type==2) {
+  # reverse the rows twice; drop=FALSE keeps a single-row/column X a matrix
+  R = X[nrow(X):1, , drop=FALSE][nrow(X):1, , drop=FALSE]
+}
+# Write result matrix R
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml
new file mode 100644
index 0000000000..90774083f2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryReorgOperation.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read input matrix X ($1) and the operation selector type ($2)
+X = read($1)
+type = $2
+
+# Two self-cancelling reorg chains, selected by type
+if(type==1){
+  R = t(t(X))*1 # t(t(X)) -> X; NOTE(review): the trailing *1 looks deliberate, confirm before removing
+}
+else if(type==2) {
+  R = rev(rev(X)) # rev(rev(X)) -> X
+}
+
+# Write the result matrix R to the path given as $3
+write(R, $3)
diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R
new file mode 100644
index 0000000000..dd409e29fc
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries (readMM/writeMM come from Matrix)
+library("Matrix")
+library("matrixStats")
+
+# Input matrix X and the branch selector
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+type <- as.integer(args[2])
+
+ones <- matrix(1, nrow = nrow(X), ncol = ncol(X))
+
+# Element-wise division by/into an all-ones matrix of the same shape
+if (type == 1) {
+  R <- ones / X  # all-ones matrix is the left operand
+} else if (type == 2) {
+  R <- X / ones  # all-ones matrix is the right operand
+}
+
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml
new file mode 100644
index 0000000000..93f3ddb8ff
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteRemoveUnnecessaryVectorizeOperation.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read input matrix X ($1) and build an all-ones matrix Y of the same shape
+X = read($1)
+Y = matrix(1,nrow(X),ncol(X))
+
+type = $2
+
+# Divide by/into the all-ones matrix Y, i.e. a vectorized scalar 1
+if(type==1){
+  R = Y/X # Left vectorized scalar (Y is the left operand)
+}
+else if(type==2){
+  R = X/Y # Right vectorized scalar (Y is the right operand)
+}
+
+# Write the result matrix R to the path given as $3
+write(R, $3)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R
new file mode 100644
index 0000000000..2aa799dd18
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.R
@@ -0,0 +1,50 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# The operation type is the second command-line argument
+type <- as.integer(args[2])
+
+# Fixed inputs: two 1x1 matrices and a scalar
+X <- matrix(1, nrow = 1, ncol = 1)
+Y <- matrix(2, nrow = 1, ncol = 1)
+s <- 2
+
+# Reference results, reduced to plain numeric values
+if (type == 1) {
+  R <- as.numeric(X * Y)
+} else if (type == 2) {
+  R <- as.numeric(X * s)
+} else if (type == 3) {
+  R <- as.numeric(s * X)
+}
+
+# Write the scalar result as text
+write(R, paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml
new file mode 100644
index 0000000000..18da3c4f4f
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryMatrixScalarOperation.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read operation type ($1); no input matrices are read
+type = $1
+
+# Create 1x1 matrices X, Y and a scalar s
+X = matrix(1,1,1)
+Y = matrix(2,1,1)
+s = 2
+
+# Cast-to-scalar of a binary op on 1x1 inputs; trailing comments show the expected rewrite
+if(type==1){
+  R = as.scalar(X*Y) # -> as.scalar(X) * as.scalar(Y)
+}
+else if(type==2){
+  R = as.scalar(X*s) # -> as.scalar(X) * s
+}
+else if(type==3){
+  R = as.scalar(s*X)*1 # -> s * as.scalar(X); NOTE(review): trailing *1 presumably deliberate, confirm
+}
+
+# Write the scalar result to the path given as $2
+write(R, $2)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R
new file mode 100644
index 0000000000..91398bc3b1
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries (readMM/writeMM come from Matrix)
+library("Matrix")
+library("matrixStats")
+
+# Input matrix X and the branch selector
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+type <- as.integer(args[2])
+
+# Reference results for the binary-to-unary rewrites
+if (type == 1) {
+  R <- X + X
+} else if (type == 2) {
+  R <- X * X
+} else if (type == 3) {
+  R <- (X > 0) - (X < 0)
+}
+
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml
new file mode 100644
index 0000000000..f60899ff77
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyBinaryToUnaryOperation.dml
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read matrix X ($1) and operation type ($2)
+X = read($1)
+type = $2
+
+# Binary ops that have an equivalent unary/scalar form (see trailing comments)
+if(type==1){
+  R = X+X # X+X -> X*2
+}
+else if(type==2){
+  R = X*X # X*X -> X^2
+}
+else if(type==3){
+  R = (X>0)-(X<0) # (X>0)-(X<0) -> sign(X)
+}
+
+# Write the result matrix R to the path given as $3
+write(R, $3)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R
new file mode 100644
index 0000000000..e01a6f2808
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.R
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Set options for numeric precision
+options(digits=22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Read matrix
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+
+# Column sums of column-wise cumsums; colCumsums (unlike apply) keeps a 1-row X a matrix
+R = t(as.matrix(colSums(colCumsums(X))))
+
+# Write result matrix R (a 1 x ncol(X) row)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml
new file mode 100644
index 0000000000..8af853c84d
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyCumsumColOrFullAggregates.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read input matrix X from the path given as $1
+X = read($1)
+
+# Column sums of the column-wise cumulative sum of X
+R = colSums(cumsum(X))
+
+# Write the result matrix R to the path given as $2
+write(R, $2)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R
new file mode 100644
index 0000000000..c366852af7
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.R
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Both inputs live in the directory given by the first argument
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+Y <- as.matrix(readMM(paste0(args[1], "Y.mtx")))
+
+# One minus the element-wise product of X and Y
+R <- 1 - X * Y
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml
new file mode 100644
index 0000000000..2ec9e0c5bb
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyMultiBinaryToBinaryOperation.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read matrices X ($1) and Y ($2)
+X = read($1)
+Y = read($2)
+
+# Scalar 1 minus the element-wise product of X and Y
+R = 1-(X*Y)
+
+# Write the result matrix R to the path given as $3
+write(R, $3)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R
new file mode 100644
index 0000000000..bec60ac994
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Operation type plus fixed inputs: sequence length m and a 6x1 column of ones
+type <- as.integer(args[2])
+m <- 5
+v <- matrix(1, nrow = 6, ncol = 1)
+
+# Outer equality comparison between v and the sequence 1..m
+vals <- as.vector(v)
+if (type == 1) {
+  R <- outer(vals, seq_len(m), "==")
+} else if (type == 2) {
+  R <- outer(seq_len(m), vals, "==")
+}
+
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml
new file mode 100644
index 0000000000..1e5d0ef9c6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyOuterSeqExpand.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read operation type ($1); m and the 6x1 all-ones vector v are fixed inputs
+type = $1
+m = 5
+v = matrix(1, 6, 1)
+
+# Outer equality comparison of v against the sequence 1..m, in either operand order
+if(type==1){
+  R = outer(v, t(seq(1,m)), "==")
+}
+else if(type==2){
+  R = outer(seq(1,m), t(v), "==")
+}
+
+# Write the result matrix R to the path given as $2
+write(R, $2)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R
new file mode 100644
index 0000000000..1bd054186a
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.R
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Input matrix X in Matrix Market format
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+
+# Contingency table of the pairs (i, n-i+1): an n x n anti-diagonal 0/1 matrix
+P <- table(seq(1, nrow(X), 1), seq(nrow(X), 1, -1))
+R <- P %*% X
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml
new file mode 100644
index 0000000000..cb3004ff9d
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyReverseOperation.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read input matrix X from the path given as $1
+X = read($1)
+
+# Left-multiply by the table(1..n, n..1) anti-diagonal matrix, which reverses the rows of X
+R = table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X # Rewrite -> rev(X)
+
+# Write the result matrix R to the path given as $2
+write(R, $2)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R
new file mode 100644
index 0000000000..c0c59dcfc0
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.R
@@ -0,0 +1,45 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Input matrices A and B plus the operation type
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+B <- as.matrix(readMM(paste0(args[1], "B.mtx")))
+type <- as.integer(args[2])
+
+tA <- t(A); tB <- t(B)
+# Transpose of an append of the transposed inputs
+if (type == 1) {
+  R <- t(cbind(tA, tB))
+} else if (type == 2) {
+  R <- t(rbind(tA, tB))
+}
+
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml
new file mode 100644
index 0000000000..754613071b
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTransposedAppend.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read matrices A ($1), B ($2), and operation type ($3)
+A = read($1)
+B = read($2)
+type = $3
+
+# Transpose of an append of transposes, expected to become the opposite append of A and B
+if(type==1){
+  R = t(cbind(t(A),t(B))) # -> rbind(A, B)
+}
+else if(type==2) {
+  R = t(rbind(t(A),t(B))) # -> cbind(A, B)
+}
+
+# Write the result matrix R to the path given as $4
+write(R, $4)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R
new file mode 100644
index 0000000000..0faf876951
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.R
@@ -0,0 +1,38 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Input matrix X in Matrix Market format
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+
+# Sum over all cells of the transpose (mathematically equal to sum(X))
+tX <- t(X)
+R <- sum(tX)
+write(R, paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml
new file mode 100644
index 0000000000..abff755d73
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryAggReorgOperation.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read input matrix X from the path given as $1
+X = read($1)
+
+# Full sum over the transpose of X (mathematically equal to sum(X))
+R = sum(t(X))
+
+# Write the scalar result R to the path given as $2
+write(R, $2)
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R
new file mode 100644
index 0000000000..f7b27a5db1
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.R
@@ -0,0 +1,101 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+# Print numbers with up to 22 significant digits
+options(digits = 22)
+
+# Load required libraries
+library("Matrix")
+library("matrixStats")
+
+# Inputs X and Y share a directory; the second argument selects the case
+X <- as.matrix(readMM(paste0(args[1], "X.mtx")))
+Y <- as.matrix(readMM(paste0(args[1], "Y.mtx")))
+type <- as.integer(args[2])
+
+# Each case applies abs/round/ceiling/floor/sign to a logical comparison of X and Y
+# (1) Less: types 1-5
+if (type == 1) {
+  R <- abs(X < Y)
+} else if (type == 2) {
+  R <- round(X < Y)
+} else if (type == 3) {
+  R <- ceiling(X < Y)
+} else if (type == 4) {
+  R <- floor(X < Y)
+} else if (type == 5) {
+  R <- sign(X < Y)
+} else if (type == 6) { # (2) Less-Equal: types 6-10
+  R <- abs(X <= Y)
+} else if (type == 7) {
+  R <- round(X <= Y)
+} else if (type == 8) {
+  R <- ceiling(X <= Y)
+} else if (type == 9) {
+  R <- floor(X <= Y)
+} else if (type == 10) {
+  R <- sign(X <= Y)
+} else if (type == 11) { # (3) Greater: types 11-15
+  R <- abs(X > Y)
+} else if (type == 12) {
+  R <- round(X > Y)
+} else if (type == 13) {
+  R <- ceiling(X > Y)
+} else if (type == 14) {
+  R <- floor(X > Y)
+} else if (type == 15) {
+  R <- sign(X > Y)
+} else if (type == 16) { # (4) Greater-Equal: types 16-20
+  R <- abs(X >= Y)
+} else if (type == 17) {
+  R <- round(X >= Y)
+} else if (type == 18) {
+  R <- ceiling(X >= Y)
+} else if (type == 19) {
+  R <- floor(X >= Y)
+} else if (type == 20) {
+  R <- sign(X >= Y)
+} else if (type == 21) { # (5) Equal: types 21-25
+  R <- abs(X == Y)
+} else if (type == 22) {
+  R <- round(X == Y)
+} else if (type == 23) {
+  R <- ceiling(X == Y)
+} else if (type == 24) {
+  R <- floor(X == Y)
+} else if (type == 25) {
+  R <- sign(X == Y)
+} else if (type == 26) { # (6) Not-Equal: types 26-30
+  R <- abs(X != Y)
+} else if (type == 27) {
+  R <- round(X != Y)
+} else if (type == 28) {
+  R <- ceiling(X != Y)
+} else if (type == 29) {
+  R <- floor(X != Y)
+} else if (type == 30) {
+  R <- sign(X != Y)
+}
+
+# Store R in Matrix Market format
+writeMM(as(R, "CsparseMatrix"), paste0(args[3], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml
new file mode 100644
index 0000000000..e016bc65e3
---
/dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnaryPPredOperation.dml
@@ -0,0 +1,96 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read matrices X, Y and the integer operation type
+X = read($1)
+Y = read($2)
+type = $3
+
+# Apply one unary function to one comparison of X and Y. Type codes 1-30
+# enumerate (comparison, unary function) pairs: the comparison advances
+# every five codes (<, <=, >, >=, ==, !=) and, within each group of five,
+# the function cycles through abs, round, ceil, floor, sign.
+# NOTE: each branch spells out its expression directly, presumably so the
+# rewrite under test sees every unary(comparison) pattern verbatim.
+if (type == 1) { # (1) Less
+  R = abs(X < Y)
+} else if (type == 2) {
+  R = round(X < Y)
+} else if (type == 3) {
+  R = ceil(X < Y)
+} else if (type == 4) {
+  R = floor(X < Y)
+} else if (type == 5) {
+  R = sign(X < Y)
+} else if (type == 6) { # (2) Less-Equal
+  R = abs(X <= Y)
+} else if (type == 7) {
+  R = round(X <= Y)
+} else if (type == 8) {
+  R = ceil(X <= Y)
+} else if (type == 9) {
+  R = floor(X <= Y)
+} else if (type == 10) {
+  R = sign(X <= Y)
+} else if (type == 11) { # (3) Greater
+  R = abs(X > Y)
+} else if (type == 12) {
+  R = round(X > Y)
+} else if (type == 13) {
+  R = ceil(X > Y)
+} else if (type == 14) {
+  R = floor(X > Y)
+} else if (type == 15) {
+  R = sign(X > Y)
+} else if (type == 16) { # (4) Greater-Equal
+  R = abs(X >= Y)
+} else if (type == 17) {
+  R = round(X >= Y)
+} else if (type == 18) {
+  R = ceil(X >= Y)
+} else if (type == 19) {
+  R = floor(X >= Y)
+} else if (type == 20) {
+  R = sign(X >= Y)
+} else if (type == 21) { # (5) Equal
+  R = abs(X == Y)
+} else if (type == 22) {
+  R = round(X == Y)
+} else if (type == 23) {
+  R = ceil(X == Y)
+} else if (type == 24) {
+  R = floor(X == Y)
+} else if (type == 25) {
+  R = sign(X == Y)
+} else if (type == 26) { # (6) Not-Equal
+  R = abs(X != Y)
+} else if (type == 27) {
+  R = round(X != Y)
+} else if (type == 28) {
+  R = ceil(X != Y)
+} else if (type == 29) {
+  R = floor(X != Y)
+} else if (type == 30) {
+  R = sign(X != Y)
+}
+
+# Write the result matrix R
+write(R, $4)