Fix RewriteEMult comparator to break ties between hops of the same data type; add ALLOW_EMULT_CHAIN_REWRITE optimizer flag and tests.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eb0599df Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eb0599df Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eb0599df Branch: refs/heads/master Commit: eb0599df4c3bcca15531b85a3d870a26e4653179 Parents: 7d57883 Author: Dylan Hutchison <[email protected]> Authored: Fri Jun 9 11:18:32 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:15 2017 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/OptimizerUtils.java | 9 +- .../sysml/hops/rewrite/ProgramRewriter.java | 3 +- .../apache/sysml/hops/rewrite/RewriteEMult.java | 10 +- .../functions/misc/RewriteEMultChainTest.java | 127 +++++++++ .../ternary/ABATernaryAggregateTest.java | 268 +++++++++++++++++++ .../functions/misc/RewriteEMultChainOp.R | 33 +++ .../functions/misc/RewriteEMultChainOp.dml | 28 ++ .../functions/ternary/ABATernaryAggregateC.R | 32 +++ .../functions/ternary/ABATernaryAggregateC.dml | 30 +++ .../functions/ternary/ABATernaryAggregateRC.R | 33 +++ .../functions/ternary/ABATernaryAggregateRC.dml | 30 +++ 11 files changed, 597 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/OptimizerUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java index a40e36c..2a76d07 100644 --- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java @@ -110,8 +110,13 @@ public class OptimizerUtils */ public static boolean ALLOW_CONSTANT_FOLDING = true; - public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; - public static 
boolean ALLOW_OPERATOR_FUSION = true; + public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; + /** + * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in + * `A*B*A ==> (A^2)*B`. + */ + public static boolean ALLOW_EMULT_CHAIN_REWRITE = true; + public static boolean ALLOW_OPERATOR_FUSION = true; /** * Enables if-else branch removal for constant predicates (original literals or http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 8573dd7..b6aab38 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -96,7 +96,8 @@ public class ProgramRewriter _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); - _dagRuleSet.add( new RewriteEMult() ); //dependency: cse + if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE ) + _dagRuleSet.add( new RewriteEMult() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java index 47c32a9..2c9e5cb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java +++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java @@ -50,7 +50,6 @@ public class RewriteEMult extends HopRewriteRule { public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException { if( roots == null ) return null; - for( int i=0; i<roots.size(); i++ ) { Hop h = roots.get(i); roots.set(i, rule_RewriteEMult(h)); @@ -83,6 +82,7 @@ public class RewriteEMult extends HopRewriteRule { final Set<BinaryOp> emults = new HashSet<>(); final Multiset<Hop> leaves = HashMultiset.create(); findEMultsAndLeaves(r, emults, leaves); + // 2. Ensure it is profitable to do a rewrite. if (isOptimizable(leaves)) { // 3. Check for foreign parents. @@ -93,8 +93,12 @@ public class RewriteEMult extends HopRewriteRule { if (okay) { // 4. Construct replacement EMults for the leaves final Hop replacement = constructReplacement(leaves); - // 5. Replace root with replacement + if (LOG.isDebugEnabled()) + LOG.debug(String.format( + "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d", + emults.size(), root.getHopID(), replacement.getHopID())); + replacement.setVisited(); return HopRewriteUtils.replaceHop(root, replacement); } } @@ -141,7 +145,7 @@ public class RewriteEMult extends HopRewriteRule { return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW); } - private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType); + private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode); private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) { final ArrayList<Hop> parents = child.getParent(); http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java ---------------------------------------------------------------------- diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java new file mode 100644 index 0000000..e076c95 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test whether `A*B*A` successfully rewrites to `(A^2)*B`. 
+ */ +public class RewriteEMultChainTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteEMultChainOp"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/"; + + private static final int rows = 123; + private static final int cols = 321; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testMatrixMultChainOptNoRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptNoRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testMatrixMultChainOptRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); + } + + private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE; + OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ 
"-explain", "hops", "-stats", + "-args", input("X"), input("Y"), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for presence of power operator, if we did a rewrite + if( rewrites ) { + Assert.assertTrue(heavyHittersContainsSubString("^2")); + } + } + finally { + OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java new file mode 100644 index 0000000..198e9f4 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.ternary; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.instructions.Instruction; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +/** + * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`. + * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}. 
+ */ +public class ABATernaryAggregateTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "ABATernaryAggregateRC"; + private final static String TEST_NAME2 = "ABATernaryAggregateC"; + + private final static String TEST_DIR = "functions/ternary/"; + private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + + private final static int rows = 1111; + private final static int cols = 1011; + + private final static double sparsity1 = 0.7; + private final static double sparsity2 = 0.3; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testTernaryAggregateRCDenseVectorCP() { + runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCSparseVectorCP() { + runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCDenseMatrixCP() { + runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCSparseMatrixCP() { + runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCDenseVectorSP() { + runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateRCSparseVectorSP() { + runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateRCDenseMatrixSP() { + runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateRCSparseMatrixSP() { + runTernaryAggregateTest(TEST_NAME1, true, 
false, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateRCDenseVectorMR() { + runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.MR); + } + + @Test + public void testTernaryAggregateRCSparseVectorMR() { + runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.MR); + } + + @Test + public void testTernaryAggregateRCDenseMatrixMR() { + runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.MR); + } + + @Test + public void testTernaryAggregateRCSparseMatrixMR() { + runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.MR); + } + + @Test + public void testTernaryAggregateCDenseVectorCP() { + runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateCSparseVectorCP() { + runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateCDenseMatrixCP() { + runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateCSparseMatrixCP() { + runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateCDenseVectorSP() { + runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateCSparseVectorSP() { + runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateCDenseMatrixSP() { + runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateCSparseMatrixSP() { + runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.SPARK); + } + + //additional tests to check default without rewrites + + @Test + public void testTernaryAggregateRCDenseVectorCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME1, false, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCSparseVectorCPNoRewrite() { + 
runTernaryAggregateTest(TEST_NAME1, true, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCDenseMatrixCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME1, false, false, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateRCSparseMatrixCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME1, true, false, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateCDenseVectorCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME2, false, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateCSparseVectorCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME2, true, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateCDenseMatrixCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME2, false, false, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateCSparseMatrixCPNoRewrite() { + runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP); + } + + + + private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et) + { + //rtplatform for MR + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES, + rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; + OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites; + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new 
String[]{"-explain","hops","-stats","-args", input("A"), output("R")}; + + fullRScriptName = HOME + testname + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + + inputDir() + " " + expectedDir(); + + //generate actual dataset + double sparsity = sparse ? sparsity2 : sparsity1; + double[][] A = getRandomMatrix(vectors ? rows*cols : rows, + vectors ? 1 : cols, 0, 1, sparsity, 17); + writeInputMatrixWithMTD("A", A, true); + + //run test cases + runTest(true, false, null, -1); + runRScript(true); + + //compare output matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for rewritten patterns in statistics output + if( rewrites && et != ExecType.MR ) { + String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") + + (((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*")); + Assert.assertTrue(Statistics.getCPHeavyHitterOpCodes().contains(opcode)); + } + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; + OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R new file mode 100644 index 0000000..6d94cc8 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") +library("matrixStats") + +X = as.matrix(readMM(paste(args[1], "X.mtx", sep=""))) +Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep=""))) + +R = X * Y * X; + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml new file mode 100644 index 0000000..3992403 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +X = read($1); +Y = read($2); + +R = X * Y * X; + +write(R, $3); \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R new file mode 100644 index 0000000..9601089 --- /dev/null +++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = A * 2; + +R = t(as.matrix(colSums(A * B * A))); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml new file mode 100644 index 0000000..78285af --- /dev/null +++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = read($1); +B = A * 2; +C = A * 3; + +if(1==1){} + +R = colSums(A * B * A); + +write(R, $2); http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R new file mode 100644 index 0000000..6552c7e --- /dev/null +++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = A * 2; + +s = sum(A * B * A); +R = as.matrix(s); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml new file mode 100644 index 0000000..965c8d3 --- /dev/null +++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1); +B = A * 2; + +if(1==1){} + +s = sum(A * B * A); +R = as.matrix(s); + +write(R, $2); \ No newline at end of file
