Review comments 3
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04f692df Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04f692df Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04f692df Branch: refs/heads/master Commit: 04f692dfcb25a032044dabb7064241073f959300 Parents: de469d2 Author: Dylan Hutchison <[email protected]> Authored: Sun Jun 18 16:54:51 2017 -0700 Committer: Dylan Hutchison <[email protected]> Committed: Sun Jun 18 17:43:54 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 4 +- .../sysml/hops/rewrite/ProgramRewriter.java | 2 +- ...ementwiseMultChainOptimizationChainTest.java | 127 ------------------- ...iteElementwiseMultChainOptimizationTest.java | 127 +++++++++++++++++++ .../functions/misc/ZPackageSuite.java | 2 +- 5 files changed, 132 insertions(+), 130 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index a207831..8e681c1 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -647,7 +647,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop handled = true; } else if (input11 instanceof BinaryOp ) { BinaryOp b11 = (BinaryOp)input11; - switch (b11.getOp()) { + switch( b11.getOp() ) { case MULT: // A*B*C case in1 = input11.getInput().get(0).constructLops(); in2 = input11.getInput().get(1).constructLops(); @@ -664,6 +664,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop handled = true; } break; + default: break; } } else if( input12 instanceof BinaryOp ) { BinaryOp b12 = (BinaryOp)input12; @@ -683,6 +684,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop handled = true; } break; + default: break; } } http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java index 1053850..7ee3ccb 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java @@ -97,7 +97,7 @@ public class ProgramRewriter if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) _dagRuleSet.add( new RewriteCommonSubexpressionElimination() ); if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) - _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse + _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse if( OptimizerUtils.ALLOW_CONSTANT_FOLDING ) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java deleted file mode 100644 index e490750..0000000 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.test.integration.functions.misc; - -import java.util.HashMap; - -import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.lops.LopProperties.ExecType; -import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.apache.sysml.test.integration.TestConfiguration; -import org.apache.sysml.test.utils.TestUtils; -import org.junit.Assert; -import org.junit.Test; - -/** - * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`. - */ -public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase -{ - private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; - private static final String TEST_DIR = "functions/misc/"; - private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/"; - - private static final int rows = 123; - private static final int cols = 321; - private static final double eps = Math.pow(10, -10); - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); - } - - @Test - public void testMatrixMultChainOptNoRewritesCP() { - testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); - } - - @Test - public void testMatrixMultChainOptNoRewritesSP() { - testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); - } - - @Test - public void testMatrixMultChainOptRewritesCP() { - testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); - } - - @Test - public void testMatrixMultChainOptRewritesSP() { - testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); - } - - private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) - { - RUNTIME_PLATFORM platformOld = rtplatform; - switch( et ){ - case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; - case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; - } - - boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; - if( rtplatform == RUNTIME_PLATFORM.SPARK ) - DMLScript.USE_LOCAL_SPARK_CONFIG = true; - - boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; - OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; - - try - { - TestConfiguration config = getTestConfiguration(testname); - loadTestConfiguration(config); - - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; - fullRScriptName = HOME + testname + ".R"; - rCmd = getRCmd(inputDir(), expectedDir()); - - double Xsparsity = 0.8, Ysparsity = 0.6; - double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); - double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); - writeInputMatrixWithMTD("X", X, true); - writeInputMatrixWithMTD("Y", Y, true); - - //execute tests - runTest(true, false, null, -1); - runRScript(true); - - //compare matrices - HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); - HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); - TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); - - //check for presence of power operator, if we did a rewrite - if( rewrites ) { - Assert.assertTrue(heavyHittersContainsSubString("^2")); - } - } - finally { - OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; - rtplatform = platformOld; - DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; - } - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java new file mode 100644 index 0000000..91cb4e0 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`. + */ +public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewriteEMultChainOpXYX"; + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationTest.class.getSimpleName() + "/"; + + private static final int rows = 123; + private static final int cols = 321; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testMatrixMultChainOptNoRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptNoRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testMatrixMultChainOptRewritesCP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testMatrixMultChainOptRewritesSP() { + testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK); + } + + private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; + + try + { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") }; + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + double Xsparsity = 0.8, Ysparsity = 0.6; + double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7); + double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3); + writeInputMatrixWithMTD("X", X, true); + writeInputMatrixWithMTD("Y", Y, true); + + //execute tests + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R"); + + //check for presence of power operator, if we did a rewrite + if( rewrites ) { + Assert.assertTrue(heavyHittersContainsSubString("^2")); + } + } + finally { + OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index deea784..860cdbe 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -50,7 +50,7 @@ import org.junit.runners.Suite; ReadAfterWriteTest.class, RewriteCSETransposeScalarTest.class, RewriteCTableToRExpandTest.class, - RewriteElementwiseMultChainOptimizationChainTest.class, + RewriteElementwiseMultChainOptimizationTest.class, RewriteEliminateAggregatesTest.class, RewriteFuseBinaryOpChainTest.class, RewriteFusedRandTest.class,
