Repository: incubator-systemml
Updated Branches:
  refs/heads/master 16e7b1c88 -> ce84288f0
[SYSTEMML-766][SYSTEMML-774] Cleanup imports/tests and simplifications Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ce84288f Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ce84288f Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ce84288f Branch: refs/heads/master Commit: ce84288f0b981a1c1b7e4264b27d124eff126ae6 Parents: 16e7b1c Author: Matthias Boehm <[email protected]> Authored: Sat Jul 16 18:06:36 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 16 18:06:36 2016 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/lops/PlusMult.java | 7 +-- .../instructions/CPInstructionParser.java | 1 - .../instructions/SPInstructionParser.java | 3 -- .../instructions/cp/PlusMultCPInstruction.java | 10 +--- .../spark/PlusMultSPInstruction.java | 9 ---- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 7 ++- .../runtime/matrix/data/LibMatrixDatagen.java | 8 +-- .../runtime/matrix/data/LibMatrixMult.java | 38 +++++--------- .../sysml/runtime/matrix/data/MatrixBlock.java | 29 +++++------ .../misc/RewriteFuseBinaryOpChainTest.java | 55 +++++++++----------- 10 files changed, 58 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/lops/PlusMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java index 2dc16e9..65e6440 100644 --- a/src/main/java/org/apache/sysml/lops/PlusMult.java +++ b/src/main/java/org/apache/sysml/lops/PlusMult.java @@ -25,7 +25,6 @@ import org.apache.sysml.lops.LopProperties.ExecType; import org.apache.sysml.lops.compile.JobType; import 
org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.parser.Expression.*; /** @@ -100,8 +99,4 @@ public class PlusMult extends Lop return sb.toString(); } - - - - -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index c91ad8c..2df3615 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -64,7 +64,6 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; import org.apache.sysml.runtime.instructions.cpfile.ParameterizedBuiltinCPFileInstruction; -import org.apache.sysml.runtime.matrix.operators.BinaryOperator; public class CPInstructionParser extends InstructionParser { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index a9a34f5..11e0ea0 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -34,9 +34,6 @@ import org.apache.sysml.lops.WeightedSquaredLossR; import 
org.apache.sysml.lops.WeightedUnaryMM; import org.apache.sysml.lops.WeightedUnaryMMR; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction; -import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; -import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.spark.AggregateTernarySPInstruction; import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction; import org.apache.sysml.runtime.instructions.spark.AppendGAlignedSPInstruction; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java index 8b01cb7..7cb75e5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java @@ -1,19 +1,13 @@ package org.apache.sysml.runtime.instructions.cp; -import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.functionobjects.CM; import org.apache.sysml.runtime.functionobjects.MinusMultiply; import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import 
org.apache.sysml.runtime.matrix.operators.BinaryOperator; -import org.apache.sysml.runtime.matrix.operators.CMOperator; -import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes; public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, @@ -41,8 +35,7 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { @Override public void processInstruction( ExecutionContext ec ) throws DMLRuntimeException - { - + { String output_name = output.getName(); //get all the inputs @@ -50,7 +43,6 @@ public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { MatrixBlock matrix2 = ec.getMatrixInput(input2.getName()); ScalarObject lambda = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); - //execution ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue()); MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock()); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java index 89de821..4b73679 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java @@ -19,8 +19,6 @@ package org.apache.sysml.runtime.instructions.spark; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.sysml.parser.Expression.DataType; import org.apache.sysml.runtime.DMLRuntimeException; import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; @@ -29,15 +27,8 @@ import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; -import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; import org.apache.sysml.runtime.instructions.cp.ScalarObject; -import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction; -import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; -import org.apache.sysml.runtime.matrix.operators.Operator; public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 1dc6a3c..cf689b8 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -543,9 +543,6 @@ public class LibMatrixAgg public static void groupedAggregate(MatrixBlock groups, MatrixBlock target, MatrixBlock weights, MatrixBlock result, int numGroups, Operator op, int k) throws DMLRuntimeException { - //preprocessing - result.sparse = false; // Do not need to check for isThreadSafe, because dense is assumed 
to be thread safe - //fall back to sequential version if necessary boolean rowVector = (target.getNumRows()==1 && target.getNumColumns()>1); if( k <= 1 || (long)target.rlen*target.clen < PAR_NUMCELL_THRESHOLD || rowVector || target.clen==1) { @@ -556,7 +553,9 @@ public class LibMatrixAgg if( !(op instanceof CMOperator || op instanceof AggregateOperator) ) { throw new DMLRuntimeException("Invalid operator (" + op + ") encountered while processing groupedAggregate."); } - + + //preprocessing (no need to check isThreadSafe) + result.sparse = false; result.allocateDenseBlock(); //core multi-threaded grouped aggregate computation http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java index 19b2dff..c59b3e7 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java @@ -392,18 +392,14 @@ public class LibMatrixDatagen boolean lsparse = MatrixBlock.evalSparseFormatInMemory( rows, cols, estnnz ); //fallback to sequential if single rowblock or too few cells or if MatrixBlock is not thread safe - if( k<=1 || (rows <= rpb && lsparse) || (long)rows*cols < PAR_NUMCELL_THRESHOLD) { + if( k<=1 || (rows <= rpb && lsparse) || (long)rows*cols < PAR_NUMCELL_THRESHOLD + || !MatrixBlock.isThreadSafe(lsparse) ) { generateRandomMatrix(out, rgen, nnzInBlocks, bigrand, bSeed); return; } out.reset(rows, cols, lsparse); - if (!out.isThreadSafe()) { - generateRandomMatrix(out, rgen, nnzInBlocks, bigrand, bSeed); - return; - } - //special case shortcuts for efficiency if ( rgen._pdf.equalsIgnoreCase(RAND_PDF_UNIFORM)) { if ( min == 0.0 && max == 0.0 ) { //all zeros 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java index 27438d9..57ce557 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java @@ -300,9 +300,6 @@ public class LibMatrixMult ret.examSparsity(); //turn empty dense into sparse return; } - - //pre-processing - ret.sparse = false; // MatrixBlock is assumed to be thread safe if dense //check too high additional memory requirements (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) @@ -315,7 +312,8 @@ public class LibMatrixMult //Timing time = new Timing(true); - //pre-processing + //pre-processing (no need to check isThreadSafe) + ret.sparse = false; ret.allocateDenseBlock(); //core matrix mult chain computation @@ -399,9 +397,6 @@ public class LibMatrixMult return; } - // pre-processing - ret.sparse = false; // MatrixBlock is assumed to be thread safe if dense - //check no parallelization benefit (fallback to sequential) //check too small workload in terms of flops (fallback to sequential too) if( ret.rlen == 1 @@ -414,8 +409,9 @@ public class LibMatrixMult //Timing time = new Timing(true); - //pre-processing + //pre-processing (no need to check isThreadSafe) m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose); + ret.sparse = false; ret.allocateDenseBlock(); //core multi-threaded matrix mult computation @@ -500,9 +496,6 @@ public class LibMatrixMult if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) return; - //pre-processing - ret1.sparse = false; // MatrixBlock is assumed to be thread safe if dense - //check no parallelization benefit (fallback 
to sequential) if (pm1.rlen == 1) { matrixMultPermute(pm1, m2, ret1, ret2); @@ -512,6 +505,7 @@ public class LibMatrixMult //Timing time = new Timing(true); //allocate first output block (second allocated if needed) + ret1.sparse = false; // no need to check isThreadSafe ret1.allocateDenseBlock(); try @@ -598,7 +592,8 @@ public class LibMatrixMult } //check no parallelization benefit (fallback to sequential) - if (mX.rlen == 1 || !ret.isThreadSafe()) { + //no need to check isThreadSafe (scalar output) + if( mX.rlen == 1 ) { matrixMultWSLoss(mX, mU, mV, mW, ret, wt); return; } @@ -684,11 +679,8 @@ public class LibMatrixMult return; } - //pre-processing - ret.sparse = mW.sparse; - //check no parallelization benefit (fallback to sequential) - if (mW.rlen == 1 || !ret.isThreadSafe()) { + if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) { matrixMultWSigmoid(mW, mU, mV, ret, wt); return; } @@ -696,6 +688,7 @@ public class LibMatrixMult //Timing time = new Timing(true); //pre-processing + ret.sparse = mW.sparse; ret.allocateDenseOrSparseBlock(); try @@ -905,15 +898,10 @@ public class LibMatrixMult //Timing time = new Timing(true); - //pre-processing + //pre-processing (no need to check isThreadSafe) ret.sparse = false; ret.allocateDenseBlock(); - if (!ret.isThreadSafe()){ - matrixMultWCeMM(mW, mU, mV, eps, ret, wt); - return; - } - try { ExecutorService pool = Executors.newFixedThreadPool(k); @@ -993,11 +981,8 @@ public class LibMatrixMult return; } - //pre-processing - ret.sparse = mW.sparse; - //check no parallelization benefit (fallback to sequential) - if (mW.rlen == 1 || !ret.isThreadSafe()) { + if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) { matrixMultWuMM(mW, mU, mV, ret, wt, fn); return; } @@ -1005,6 +990,7 @@ public class LibMatrixMult //Timing time = new Timing(true); //pre-processing + ret.sparse = mW.sparse; ret.allocateDenseOrSparseBlock(); try 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 720aed1..68cb43a 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -6134,26 +6134,25 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } /** - * Whether concurrent modification operations are allowed - * This method is to be used by methods that attempt to do a task concurrently, - * like in {@link LibMatrixDatagen#generateRandomMatrix(MatrixBlock, RandomMatrixGenerator, long[], Well1024a, long, int)} + * Indicates if concurrent modifications of disjoint rows are thread-safe. + * * @return */ public boolean isThreadSafe() { - if (sparse){ - if (sparseBlock == null){ - // It is assumed that MCSR is the only safe sparse block implementation available. - return DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - } - else { - return sparseBlock.isThreadSafe(); - } - } - else { - return true; - } + return !sparse || ((sparseBlock != null) ? sparseBlock.isThreadSafe() : + DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR); //only MCSR thread-safe (dense always thread-safe) } + /** + * Indicates if concurrent modifications of disjoint rows are thread-safe.
+ * + * @param sparse + * @return + */ + public static boolean isThreadSafe(boolean sparse) { + return !sparse || DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; //only MCSR thread-safe + } + public void print() { System.out.println("sparse = "+sparse); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java index e010083..7fec6b0 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java @@ -27,6 +27,7 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.instructions.Instruction; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; @@ -39,7 +40,6 @@ import org.apache.sysml.utils.Statistics; */ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; @@ -51,71 +51,59 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase private static final double eps = Math.pow(10, -10); @Override - public void setUp() - { + public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, 
TEST_NAME1, new String[] { "R" }) ); addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); } @Test - public void testFuseBinaryPlusNoRewrite() - { + public void testFuseBinaryPlusNoRewrite() { testFuseBinaryChain( TEST_NAME1, false, ExecType.CP ); } @Test - public void testFuseBinaryPlusRewrite() - { + public void testFuseBinaryPlusRewrite() { testFuseBinaryChain( TEST_NAME1, true, ExecType.CP); } + @Test - public void testFuseBinaryMinusNoRewrite() - { + public void testFuseBinaryMinusNoRewrite() { testFuseBinaryChain( TEST_NAME2, false, ExecType.CP ); } @Test - public void testFuseBinaryMinusRewrite() - { + public void testFuseBinaryMinusRewrite() { testFuseBinaryChain( TEST_NAME2, true, ExecType.CP ); } - - @Test - public void testSpFuseBinaryPlusNoRewrite() - { + public void testSpFuseBinaryPlusNoRewrite() { testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK ); } - @Test - public void testSpFuseBinaryPlusRewrite() - { + public void testSpFuseBinaryPlusRewrite() { testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK ); } - @Test - public void testSpFuseBinaryMinusNoRewrite() - { + public void testSpFuseBinaryMinusNoRewrite() { testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK ); } @Test - public void testSpFuseBinaryMinusRewrite() - { + public void testSpFuseBinaryMinusRewrite() { testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK ); } /** * - * @param condition - * @param branchRemoval - * @param IPA + * @param testname + * @param rewrites + * @param instType */ private void testFuseBinaryChain( String testname, boolean rewrites, ExecType instType ) { @@ -123,21 +111,21 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase switch( instType ){ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; - default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + default: rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; break; } boolean 
sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; if( rtplatform == RUNTIME_PLATFORM.SPARK ) DMLScript.USE_LOCAL_SPARK_CONFIG = true; + boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + try - { - + { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); - String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; programArgs = new String[]{"-explain", "-stats","-args", output("S") }; @@ -152,6 +140,13 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R")); + + //check for applies rewrites + if( rewrites ) { + String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX : ""; + Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes() + .contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" )); + } } finally {
