Repository: incubator-systemml
Updated Branches:
  refs/heads/master 16e7b1c88 -> ce84288f0


[SYSTEMML-766][SYSTEMML-774] Cleanup imports/tests and simplifications

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ce84288f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ce84288f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ce84288f

Branch: refs/heads/master
Commit: ce84288f0b981a1c1b7e4264b27d124eff126ae6
Parents: 16e7b1c
Author: Matthias Boehm <[email protected]>
Authored: Sat Jul 16 18:06:36 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 16 18:06:36 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/lops/PlusMult.java    |  7 +--
 .../instructions/CPInstructionParser.java       |  1 -
 .../instructions/SPInstructionParser.java       |  3 --
 .../instructions/cp/PlusMultCPInstruction.java  | 10 +---
 .../spark/PlusMultSPInstruction.java            |  9 ----
 .../sysml/runtime/matrix/data/LibMatrixAgg.java |  7 ++-
 .../runtime/matrix/data/LibMatrixDatagen.java   |  8 +--
 .../runtime/matrix/data/LibMatrixMult.java      | 38 +++++---------
 .../sysml/runtime/matrix/data/MatrixBlock.java  | 29 +++++------
 .../misc/RewriteFuseBinaryOpChainTest.java      | 55 +++++++++-----------
 10 files changed, 58 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/lops/PlusMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java 
b/src/main/java/org/apache/sysml/lops/PlusMult.java
index 2dc16e9..65e6440 100644
--- a/src/main/java/org/apache/sysml/lops/PlusMult.java
+++ b/src/main/java/org/apache/sysml/lops/PlusMult.java
@@ -25,7 +25,6 @@ import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.lops.compile.JobType;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
-import org.apache.sysml.parser.Expression.*;
 
 
 /**
@@ -100,8 +99,4 @@ public class PlusMult extends Lop
                
                return sb.toString();
        }
-       
-       
-       
-       
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index c91ad8c..2df3615 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -64,7 +64,6 @@ import 
org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
 import 
org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
 import 
org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
 import 
org.apache.sysml.runtime.instructions.cpfile.ParameterizedBuiltinCPFileInstruction;
-import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 
 public class CPInstructionParser extends InstructionParser 
 {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index a9a34f5..11e0ea0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -34,9 +34,6 @@ import org.apache.sysml.lops.WeightedSquaredLossR;
 import org.apache.sysml.lops.WeightedUnaryMM;
 import org.apache.sysml.lops.WeightedUnaryMMR;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction;
-import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction;
-import 
org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
 import 
org.apache.sysml.runtime.instructions.spark.AggregateTernarySPInstruction;
 import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction;
 import org.apache.sysml.runtime.instructions.spark.AppendGAlignedSPInstruction;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
index 8b01cb7..7cb75e5 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java
@@ -1,19 +1,13 @@
 package org.apache.sysml.runtime.instructions.cp;
 
-import org.apache.sysml.parser.Expression.DataType;
-import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.functionobjects.CM;
 import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
-import org.apache.sysml.runtime.matrix.operators.CMOperator;
-import 
org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
 
 public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction {
        public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, 
CPOperand in2, 
@@ -41,8 +35,7 @@ public class PlusMultCPInstruction extends 
ArithmeticBinaryCPInstruction {
        @Override
        public void processInstruction( ExecutionContext ec )
                throws DMLRuntimeException
-       {
-               
+       {               
                String output_name = output.getName();
 
                //get all the inputs
@@ -50,7 +43,6 @@ public class PlusMultCPInstruction extends 
ArithmeticBinaryCPInstruction {
                MatrixBlock matrix2 = ec.getMatrixInput(input2.getName());
                ScalarObject lambda = ec.getScalarInput(input3.getName(), 
input3.getValueType(), input3.isLiteral()); 
                
-               
                //execution
                ((ValueFunctionWithConstant) 
((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue());
                MatrixBlock out = (MatrixBlock) 
matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
index 89de821..4b73679 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysml.runtime.instructions.spark;
 
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
@@ -29,15 +27,8 @@ import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
-import 
org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
-import 
org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
-import org.apache.sysml.runtime.matrix.operators.Operator;
 
 public class PlusMultSPInstruction extends  ArithmeticBinarySPInstruction
 {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index 1dc6a3c..cf689b8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -543,9 +543,6 @@ public class LibMatrixAgg
        public static void groupedAggregate(MatrixBlock groups, MatrixBlock 
target, MatrixBlock weights, MatrixBlock result, int numGroups, Operator op, 
int k) 
                throws DMLRuntimeException
        {
-               //preprocessing
-               result.sparse = false;  // Do not need to check for 
isThreadSafe, because dense is assumed to be thread safe
-               
                //fall back to sequential version if necessary
                boolean rowVector = (target.getNumRows()==1 && 
target.getNumColumns()>1);
                if( k <= 1 || (long)target.rlen*target.clen < 
PAR_NUMCELL_THRESHOLD || rowVector || target.clen==1) {
@@ -556,7 +553,9 @@ public class LibMatrixAgg
                if( !(op instanceof CMOperator || op instanceof 
AggregateOperator) ) {
                        throw new DMLRuntimeException("Invalid operator (" + op 
+ ") encountered while processing groupedAggregate.");
                }
-               
+       
+               //preprocessing (no need to check isThreadSafe)
+               result.sparse = false;
                result.allocateDenseBlock();
                
                //core multi-threaded grouped aggregate computation

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java
index 19b2dff..c59b3e7 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDatagen.java
@@ -392,18 +392,14 @@ public class LibMatrixDatagen
                boolean lsparse = MatrixBlock.evalSparseFormatInMemory( rows, 
cols, estnnz );
                
                //fallback to sequential if single rowblock or too few cells or 
if MatrixBlock is not thread safe
-               if( k<=1 || (rows <= rpb && lsparse) || (long)rows*cols < 
PAR_NUMCELL_THRESHOLD) {
+               if( k<=1 || (rows <= rpb && lsparse) || (long)rows*cols < 
PAR_NUMCELL_THRESHOLD 
+                       || !MatrixBlock.isThreadSafe(lsparse) ) {
                        generateRandomMatrix(out, rgen, nnzInBlocks, bigrand, 
bSeed);
                        return;
                }
 
                out.reset(rows, cols, lsparse);
                
-               if (!out.isThreadSafe()) {
-                       generateRandomMatrix(out, rgen, nnzInBlocks, bigrand, 
bSeed);
-                       return;
-               }
-               
                //special case shortcuts for efficiency
                if ( rgen._pdf.equalsIgnoreCase(RAND_PDF_UNIFORM)) {
                        if ( min == 0.0 && max == 0.0 ) { //all zeros

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
index 27438d9..57ce557 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixMult.java
@@ -300,9 +300,6 @@ public class LibMatrixMult
                        ret.examSparsity(); //turn empty dense into sparse
                        return;
                }
-
-               //pre-processing
-               ret.sparse = false; // MatrixBlock is assumed to be thread safe 
if dense
                
                //check too high additional memory requirements (fallback to 
sequential)
                //check too small workload in terms of flops (fallback to 
sequential too)
@@ -315,7 +312,8 @@ public class LibMatrixMult
                
                //Timing time = new Timing(true);
                                
-               //pre-processing
+               //pre-processing (no need to check isThreadSafe)
+               ret.sparse = false;
                ret.allocateDenseBlock();
                
                //core matrix mult chain computation
@@ -399,9 +397,6 @@ public class LibMatrixMult
                        return;
                }
                
-               // pre-processing
-               ret.sparse = false;     // MatrixBlock is assumed to be thread 
safe if dense
-               
                //check no parallelization benefit (fallback to sequential)
                //check too small workload in terms of flops (fallback to 
sequential too)
                if( ret.rlen == 1 
@@ -414,8 +409,9 @@ public class LibMatrixMult
                
                //Timing time = new Timing(true);
                
-               //pre-processing
+               //pre-processing (no need to check isThreadSafe)
                m1 = prepMatrixMultTransposeSelfInput(m1, leftTranspose);
+               ret.sparse = false;     
                ret.allocateDenseBlock();
        
                //core multi-threaded matrix mult computation
@@ -500,9 +496,6 @@ public class LibMatrixMult
                if( pm1.isEmptyBlock(false) || m2.isEmptyBlock(false) )
                        return;
 
-               //pre-processing
-               ret1.sparse = false;    // MatrixBlock is assumed to be thread 
safe if dense
-
                //check no parallelization benefit (fallback to sequential)
                if (pm1.rlen == 1) {
                        matrixMultPermute(pm1, m2, ret1, ret2);
@@ -512,6 +505,7 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //allocate first output block (second allocated if needed)
+               ret1.sparse = false;      // no need to check isThreadSafe
                ret1.allocateDenseBlock();
                
                try
@@ -598,7 +592,8 @@ public class LibMatrixMult
                }
                
                //check no parallelization benefit (fallback to sequential)
-               if (mX.rlen == 1 || !ret.isThreadSafe()) {
+               //no need to check isThreadSafe (scalar output)
+               if( mX.rlen == 1 ) {
                        matrixMultWSLoss(mX, mU, mV, mW, ret, wt);
                        return;
                }
@@ -684,11 +679,8 @@ public class LibMatrixMult
                        return; 
                }
 
-               //pre-processing
-               ret.sparse = mW.sparse;
-               
                //check no parallelization benefit (fallback to sequential)
-               if (mW.rlen == 1 || !ret.isThreadSafe()) {
+               if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) {
                        matrixMultWSigmoid(mW, mU, mV, ret, wt);
                        return;
                }
@@ -696,6 +688,7 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
 
                //pre-processing
+               ret.sparse = mW.sparse;
                ret.allocateDenseOrSparseBlock();
                
                try 
@@ -905,15 +898,10 @@ public class LibMatrixMult
 
                //Timing time = new Timing(true);
 
-               //pre-processing
+               //pre-processing (no need to check isThreadSafe)
                ret.sparse = false;
                ret.allocateDenseBlock();
                
-               if (!ret.isThreadSafe()){
-                       matrixMultWCeMM(mW, mU, mV, eps, ret, wt);
-                       return;
-               }
-               
                try 
                {                       
                        ExecutorService pool = Executors.newFixedThreadPool(k);
@@ -993,11 +981,8 @@ public class LibMatrixMult
                        return; 
                }
                
-               //pre-processing
-               ret.sparse = mW.sparse;
-               
                //check no parallelization benefit (fallback to sequential)
-               if (mW.rlen == 1 || !ret.isThreadSafe()) {
+               if (mW.rlen == 1 || !MatrixBlock.isThreadSafe(mW.sparse)) {
                        matrixMultWuMM(mW, mU, mV, ret, wt, fn);
                        return;
                }
@@ -1005,6 +990,7 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
 
                //pre-processing
+               ret.sparse = mW.sparse;
                ret.allocateDenseOrSparseBlock();
                
                try 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 720aed1..68cb43a 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -6134,26 +6134,25 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
        }
        
        /**
-        * Whether concurrent modification operations are allowed
-        * This method is to be used by methods that attempt to do a task concurrently,
-        * like in {@link LibMatrixDatagen#generateRandomMatrix(MatrixBlock, RandomMatrixGenerator, long[], Well1024a, long, int)}
+        * Indicates if concurrent modifications of disjoint rows are thread-safe.
+        * 
-        * @return
+        * @return true if dense; otherwise the thread-safety of the existing sparse block, or (if none is allocated) whether the default sparse block type is MCSR
         */
        public boolean isThreadSafe() {
-               if (sparse){
-                       if (sparseBlock == null){
-                               // It is assumed that MCSR is the only safe sparse block implementation available.
-                               return DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
-                       } 
-                       else {
-                               return sparseBlock.isThreadSafe();      
-                       }
-               } 
-               else {
-                       return true;
-               }
+               return !sparse || ((sparseBlock != null) ? sparseBlock.isThreadSafe() : 
+                       DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR); //only MCSR thread-safe (ternary parenthesized: || binds tighter than ?:)
        }
        
+       /**
+        * Indicates if concurrent modifications of disjoint rows are thread-safe.
+        * 
+        * @param sparse whether the block in question is (or will be) in sparse format
+        * @return true if dense, or if the default sparse block type is MCSR
+        */
+       public static boolean isThreadSafe(boolean sparse) {
+               return !sparse || DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; //only MCSR thread-safe 
+       } 
+       
        public void print()
        {
                System.out.println("sparse = "+sparse);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ce84288f/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index e010083..7fec6b0 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -27,6 +27,7 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -39,7 +40,6 @@ import org.apache.sysml.utils.Statistics;
  */
 public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase 
 {
-       
        private static final String TEST_NAME1 = 
"RewriteFuseBinaryOpChainTest1";
        private static final String TEST_NAME2 = 
"RewriteFuseBinaryOpChainTest2";
 
@@ -51,71 +51,59 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
        private static final double eps = Math.pow(10, -10);
        
        @Override
-       public void setUp() 
-       {
+       public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
        }
        
        @Test
-       public void testFuseBinaryPlusNoRewrite() 
-       {
+       public void testFuseBinaryPlusNoRewrite() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
        }
        
        
        @Test
-       public void testFuseBinaryPlusRewrite() 
-       {
+       public void testFuseBinaryPlusRewrite() {
                testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
        }
+       
        @Test
-       public void testFuseBinaryMinusNoRewrite() 
-       {
+       public void testFuseBinaryMinusNoRewrite() {
                testFuseBinaryChain( TEST_NAME2, false, ExecType.CP );
        }
        
        @Test
-       public void testFuseBinaryMinusRewrite() 
-       {
+       public void testFuseBinaryMinusRewrite() {
                testFuseBinaryChain( TEST_NAME2, true, ExecType.CP );
        }
        
-       
-       
        @Test
-       public void testSpFuseBinaryPlusNoRewrite() 
-       {
+       public void testSpFuseBinaryPlusNoRewrite() {
                testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
        }
        
-       
        @Test
-       public void testSpFuseBinaryPlusRewrite() 
-       {
+       public void testSpFuseBinaryPlusRewrite() {
                testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK );
        }
        
-       
        @Test
-       public void testSpFuseBinaryMinusNoRewrite() 
-       {
+       public void testSpFuseBinaryMinusNoRewrite() {
                testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK  );
        }
        
        @Test
-       public void testSpFuseBinaryMinusRewrite() 
-       {
+       public void testSpFuseBinaryMinusRewrite() {
                testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK  );
        }
        
        
        /**
         * 
-        * @param condition
-        * @param branchRemoval
-        * @param IPA
+        * @param testname
+        * @param rewrites
+        * @param instType
         */
        private void testFuseBinaryChain( String testname, boolean rewrites, 
ExecType instType )
        {       
@@ -123,21 +111,21 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                switch( instType ){
                        case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
                        case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+                       default: rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; 
break;
                }
                
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                if( rtplatform == RUNTIME_PLATFORM.SPARK )
                        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               
                boolean rewritesOld = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+               
                try
-               {
-                       
+               {       
                        TestConfiguration config = 
getTestConfiguration(testname);
                        loadTestConfiguration(config);
                        
-                       
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
                        programArgs = new String[]{"-explain", 
"-stats","-args", output("S") };
@@ -152,6 +140,13 @@ public class RewriteFuseBinaryOpChainTest extends 
AutomatedTestBase
                        HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("S");
                        HashMap<CellIndex, Double> rfile  = 
readRMatrixFromFS("S");
                        Assert.assertTrue(TestUtils.compareMatrices(dmlfile, 
rfile, eps, "Stat-DML", "Stat-R"));
+                       
+                       //check for applies rewrites
+                       if( rewrites ) {
+                               String prefix = (instType==ExecType.SPARK) ? 
Instruction.SP_INST_PREFIX  : "";
+                               Assert.assertTrue("Rewrite not 
applied.",Statistics.getCPHeavyHitterOpCodes()
+                                               
.contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));
+                       }
                }
                finally
                {

Reply via email to