Fix RewriteEMult comparator. Add tests.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eb0599df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eb0599df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eb0599df

Branch: refs/heads/master
Commit: eb0599df4c3bcca15531b85a3d870a26e4653179
Parents: 7d57883
Author: Dylan Hutchison <[email protected]>
Authored: Fri Jun 9 11:18:32 2017 -0700
Committer: Dylan Hutchison <[email protected]>
Committed: Sun Jun 18 17:43:15 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/OptimizerUtils.java   |   9 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   3 +-
 .../apache/sysml/hops/rewrite/RewriteEMult.java |  10 +-
 .../functions/misc/RewriteEMultChainTest.java   | 127 +++++++++
 .../ternary/ABATernaryAggregateTest.java        | 268 +++++++++++++++++++
 .../functions/misc/RewriteEMultChainOp.R        |  33 +++
 .../functions/misc/RewriteEMultChainOp.dml      |  28 ++
 .../functions/ternary/ABATernaryAggregateC.R    |  32 +++
 .../functions/ternary/ABATernaryAggregateC.dml  |  30 +++
 .../functions/ternary/ABATernaryAggregateRC.R   |  33 +++
 .../functions/ternary/ABATernaryAggregateRC.dml |  30 +++
 11 files changed, 597 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index a40e36c..2a76d07 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -110,8 +110,13 @@ public class OptimizerUtils
         */
        public static boolean ALLOW_CONSTANT_FOLDING = true;
        
-       public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; 
-       public static boolean ALLOW_OPERATOR_FUSION = true; 
+       public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
+       /**
+        * Enables rewriting chains of element-wise multiplies that contain the same
+        * multiplicand more than once, as in `A*B*A ==> (A^2)*B`.
+        * When false, ProgramRewriter does not register the RewriteEMult rule.
+        */
+       public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
+       public static boolean ALLOW_OPERATOR_FUSION = true;
        
        /**
         * Enables if-else branch removal for constant predicates (original 
literals or 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8573dd7..b6aab38 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,7 +96,8 @@ public class ProgramRewriter
                        _dagRuleSet.add(     new 
RewriteRemoveUnnecessaryCasts()             );         
                        if( 
OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
                                _dagRuleSet.add( new 
RewriteCommonSubexpressionElimination()     );
-                       _dagRuleSet.add( new RewriteEMult()                     
             ); //dependency: cse
+                       if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
+                               _dagRuleSet.add( new RewriteEMult()             
                 ); //dependency: cse
                        if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
                                _dagRuleSet.add( new RewriteConstantFolding()   
                 ); //dependency: cse
                        if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 47c32a9..2c9e5cb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -50,7 +50,6 @@ public class RewriteEMult extends HopRewriteRule {
        public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, 
ProgramRewriteStatus state) throws HopsException {
                if( roots == null )
                        return null;
-
                for( int i=0; i<roots.size(); i++ ) {
                        Hop h = roots.get(i);
                        roots.set(i, rule_RewriteEMult(h));
@@ -83,6 +82,7 @@ public class RewriteEMult extends HopRewriteRule {
                        final Set<BinaryOp> emults = new HashSet<>();
                        final Multiset<Hop> leaves = HashMultiset.create();
                        findEMultsAndLeaves(r, emults, leaves);
+
                        // 2. Ensure it is profitable to do a rewrite.
                        if (isOptimizable(leaves)) {
                                // 3. Check for foreign parents.
@@ -93,8 +93,12 @@ public class RewriteEMult extends HopRewriteRule {
                                if (okay) {
                                        // 4. Construct replacement EMults for 
the leaves
                                        final Hop replacement = 
constructReplacement(leaves);
-
                                        // 5. Replace root with replacement
+                                       if (LOG.isDebugEnabled())
+                                               LOG.debug(String.format(
+                                                               "Element-wise 
multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+                                                               emults.size(), 
root.getHopID(), replacement.getHopID()));
+                                       replacement.setVisited();
                                        return HopRewriteUtils.replaceHop(root, 
replacement);
                                }
                        }
@@ -141,7 +145,7 @@ public class RewriteEMult extends HopRewriteRule {
                return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), 
Hop.OpOp2.POW);
        }
 
-       private static Comparator<Hop> compareByDataType = 
Comparator.comparing(Hop::getDataType);
+       private static Comparator<Hop> compareByDataType = 
Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
 
        private static boolean checkForeignParent(final Set<BinaryOp> emults, 
final BinaryOp child) {
                final ArrayList<Hop> parents = child.getParent();

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
new file mode 100644
index 0000000..e076c95
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ * <p>
+ * Runs the script with {@link OptimizerUtils#ALLOW_EMULT_CHAIN_REWRITE} enabled and
+ * disabled. Each run compares the DML result against R. The heavy-hitter statistics
+ * must contain the power operator ({@code ^2}) exactly when the rewrite is enabled.
+ */
+public class RewriteEMultChainTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME1 = "RewriteEMultChainOp";
+       private static final String TEST_DIR = "functions/misc/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
+
+       private static final int rows = 123;
+       private static final int cols = 321;
+       private static final double eps = Math.pow(10, -10);
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testMatrixMultChainOptNoRewritesCP() {
+               testRewriteEMultChainOp(TEST_NAME1, false, ExecType.CP);
+       }
+
+       @Test
+       public void testMatrixMultChainOptNoRewritesSP() {
+               testRewriteEMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testMatrixMultChainOptRewritesCP() {
+               testRewriteEMultChainOp(TEST_NAME1, true, ExecType.CP);
+       }
+
+       @Test
+       public void testMatrixMultChainOptRewritesSP() {
+               testRewriteEMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+       }
+
+       /**
+        * Runs one configuration of the e-mult chain test and restores global state afterwards.
+        *
+        * @param testname name of the registered test configuration and script
+        * @param rewrites whether to enable {@link OptimizerUtils#ALLOW_EMULT_CHAIN_REWRITE}
+        * @param et       execution type, mapped onto the runtime platform
+        */
+       private void testRewriteEMultChainOp(String testname, boolean rewrites, ExecType et)
+       {
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+               }
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+               OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+
+               try
+               {
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-explain", "hops", "-stats",
+                               "-args", input("X"), input("Y"), output("R") };
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
+                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+                       writeInputMatrixWithMTD("X", X, true);
+                       writeInputMatrixWithMTD("Y", Y, true);
+
+                       //execute tests
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       //compare matrices
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+                       HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+                       //X*Y*X contains no explicit power, so the ^2 operator should only
+                       //appear if the e-mult chain rewrite introduced it
+                       boolean hasPow = heavyHittersContainsSubString("^2");
+                       if( rewrites )
+                               Assert.assertTrue("expected ^2 operator after e-mult chain rewrite", hasPow);
+                       else
+                               Assert.assertFalse("unexpected ^2 operator with e-mult chain rewrite disabled", hasPow);
+               }
+               finally {
+                       OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
new file mode 100644
index 0000000..198e9f4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.ternary;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`
+ * (script ABATernaryAggregateRC) and `colSums(A*B*A)` (script ABATernaryAggregateC).
+ * <p>
+ * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ * The sum-product and e-mult chain rewrites are toggled together. When they are
+ * enabled (and the execution type is not MR), the fused ternary-aggregate opcode
+ * must appear in the heavy-hitter statistics.
+ */
+public class ABATernaryAggregateTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME1 = "ABATernaryAggregateRC";
+       private final static String TEST_NAME2 = "ABATernaryAggregateC";
+
+       private final static String TEST_DIR = "functions/ternary/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-8;
+
+       private final static int rows = 1111;
+       private final static int cols = 1011;
+
+       private final static double sparsity1 = 0.7;
+       private final static double sparsity2 = 0.3;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseVectorCP() {
+               runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseVectorCP() {
+               runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseMatrixCP() {
+               runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseMatrixCP() {
+               runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseVectorSP() {
+               runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseVectorSP() {
+               runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseMatrixSP() {
+               runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseMatrixSP() {
+               runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseVectorMR() {
+               runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.MR);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseVectorMR() {
+               runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.MR);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseMatrixMR() {
+               runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.MR);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseMatrixMR() {
+               runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.MR);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseVectorCP() {
+               runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseVectorCP() {
+               runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseMatrixCP() {
+               runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseMatrixCP() {
+               runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseVectorSP() {
+               runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseVectorSP() {
+               runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseMatrixSP() {
+               runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseMatrixSP() {
+               runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.SPARK);
+       }
+
+       //additional tests to check default without rewrites
+
+       @Test
+       public void testTernaryAggregateRCDenseVectorCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME1, false, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseVectorCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME1, true, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCDenseMatrixCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME1, false, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRCSparseMatrixCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME1, true, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseVectorCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME2, false, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseVectorCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME2, true, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCDenseMatrixCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME2, false, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
+               runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
+       }
+
+       /**
+        * Runs one configuration of the A*B*A ternary-aggregate test and compares against R.
+        * All modified global flags are restored afterwards.
+        *
+        * @param testname script/configuration name (TEST_NAME1 or TEST_NAME2)
+        * @param sparse   whether to generate the input with the lower sparsity
+        * @param vectors  whether the input is a single column of rows*cols cells
+        * @param rewrites whether to enable the sum-product and e-mult chain rewrites
+        * @param et       execution type, mapped onto the runtime platform
+        */
+       private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et)
+       {
+               //map the execution type onto the runtime platform
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+               boolean rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+
+               try {
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+                       OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-explain","hops","-stats","-args", input("A"), output("R")};
+
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       //generate actual dataset (a single column of rows*cols cells for vectors)
+                       double sparsity = sparse ? sparsity2 : sparsity1;
+                       double[][] A = getRandomMatrix(vectors ? rows*cols : rows,
+                                       vectors ? 1 : cols, 0, 1, sparsity, 17);
+                       writeInputMatrixWithMTD("A", A, true);
+
+                       //run test cases
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       //compare output matrices
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+                       HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+                       //check for rewritten patterns in statistics output (not checked for MR);
+                       //the single-column input expects the full-aggregate opcode even for colSums
+                       if( rewrites && et != ExecType.MR ) {
+                               String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") +
+                                       (((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*"));
+                               Assert.assertTrue("expected heavy-hitter opcode " + opcode,
+                                       Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+                       }
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+                       OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R 
b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# Reference result for RewriteEMultChainOp.dml: the plain element-wise product X * Y * X.
+# Only the Matrix package is needed (readMM/writeMM); matrixStats was never used.
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml 
b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Inputs: X and Y (paths passed as $1 and $2 by RewriteEMultChainTest).
+X = read($1);
+Y = read($2);
+
+# X occurs twice in this element-wise multiply chain. With ALLOW_EMULT_CHAIN_REWRITE
+# enabled, the test expects a ^2 operator in the compiled plan.
+# Keep the expression shape as is; it is what the rewrite is tested on.
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.R 
b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
new file mode 100644
index 0000000..9601089
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# Reference result for ABATernaryAggregateC.dml:
+# column sums of A * B * A with B = A * 2, written as a 1 x ncol(A) matrix.
+inDir <- args[1]
+outDir <- args[2]
+A <- as.matrix(readMM(paste0(inDir, "A.mtx")))
+B <- A * 2
+colTotals <- colSums(A * B * A)
+R <- t(as.matrix(colTotals))
+writeMM(as(R, "CsparseMatrix"), paste0(outDir, "R"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml 
b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
new file mode 100644
index 0000000..78285af
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Input A and a derived multiplicand B = A * 2.
+# (The unused C = A * 3 was removed; it did not contribute to the result.)
+A = read($1);
+B = A * 2;
+
+# Empty if-block: introduces a statement-block boundary between the definition
+# of B and its use below. NOTE(review): presumably this keeps B from being folded
+# into the aggregate's DAG at compile time; confirm intent.
+if(1==1){}
+
+# Column-wise ternary aggregate over an e-mult chain in which A occurs twice.
+R = colSums(A * B * A);
+
+write(R, $2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R 
b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
new file mode 100644
index 0000000..6552c7e
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# Reference result for ABATernaryAggregateRC.dml:
+# the full sum of A * B * A with B = A * 2, written as a 1 x 1 matrix.
+inDir <- args[1]
+outDir <- args[2]
+A <- as.matrix(readMM(paste0(inDir, "A.mtx")))
+B <- A * 2
+total <- sum(A * B * A)
+R <- as.matrix(total)
+writeMM(as(R, "CsparseMatrix"), paste0(outDir, "R"))

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml 
b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
new file mode 100644
index 0000000..965c8d3
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Input A and a derived multiplicand B = A * 2.
+A = read($1);
+B = A * 2;
+
+# Empty if-block: introduces a statement-block boundary between the definition
+# of B and its use below. NOTE(review): presumably this keeps B from being folded
+# into the aggregate's DAG at compile time; confirm intent.
+if(1==1){}
+
+# Full-sum ternary aggregate over an e-mult chain in which A occurs twice.
+# With rewrites enabled, ABATernaryAggregateTest expects the "tak+*" opcode.
+s = sum(A * B * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file

Reply via email to