Repository: systemml
Updated Branches:
  refs/heads/master 0ae2b4f77 -> ec38b3790


[SYSTEMML-1777] MLContextTestBase class for MLContext testing

Create an abstract MLContextTestBase class that contains the common
setup and shutdown code for MLContext tests. Test classes that extend
MLContextTestBase no longer need to duplicate this boilerplate. Also
remove empty Javadoc stubs and rewrap some long lines in MLContextTest.

Closes #580.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ec38b379
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ec38b379
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ec38b379

Branch: refs/heads/master
Commit: ec38b3790f11792d3337c35439954d422d6eb60b
Parents: 0ae2b4f
Author: Deron Eriksson <de...@apache.org>
Authored: Wed Jul 19 11:13:51 2017 -0700
Committer: Deron Eriksson <de...@apache.org>
Committed: Wed Jul 19 11:13:51 2017 -0700

----------------------------------------------------------------------
 .../mlcontext/DataFrameVectorScriptTest.java    |  29 +---
 .../functions/mlcontext/FrameTest.java          |  40 +-----
 .../functions/mlcontext/GNMFTest.java           |  40 +-----
 .../mlcontext/MLContextFrameTest.java           |  41 +-----
 .../mlcontext/MLContextMultipleScriptsTest.java |   4 -
 .../mlcontext/MLContextOutputBlocksizeTest.java |  51 +------
 .../mlcontext/MLContextParforDatasetTest.java   |  52 +------
 .../mlcontext/MLContextScratchCleanupTest.java  |   4 -
 .../integration/mlcontext/MLContextTest.java    | 143 +++----------------
 .../mlcontext/MLContextTestBase.java            |  89 ++++++++++++
 .../test/integration/scripts/nn/NNTest.java     |  46 +-----
 11 files changed, 123 insertions(+), 416 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
index 65aee8e..55b8371 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
@@ -38,7 +38,6 @@ import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -49,15 +48,13 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.runtime.util.UtilFunctions;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class DataFrameVectorScriptTest extends AutomatedTestBase
+public class DataFrameVectorScriptTest extends MLContextTestBase
 {
        private final static String TEST_DIR = "functions/mlcontext/";
        private final static String TEST_NAME = "DataFrameConversion";
@@ -75,16 +72,6 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
        private final static double sparsity2 = 0.1;
        private final static double eps=0.0000000001;
 
-       private static SparkSession spark;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               spark = createSystemMLSparkSession("DataFrameVectorScriptTest", 
"local");
-               ml = new MLContext(spark);
-               ml.setExplain(true);
-       }
-
        @Override
        public void setUp() {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -343,16 +330,4 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
                JavaRDD<Row> rowRDD = sc.parallelize(list);
                return sparkSession.createDataFrame(rowRDD, dfSchema);
        }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests 
(otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               spark = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
index c93968c..382f433 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
@@ -29,10 +29,8 @@ import java.util.List;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
@@ -40,8 +38,6 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.api.mlcontext.ScriptFactory;
@@ -57,17 +53,14 @@ import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.runtime.util.UtilFunctions;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class FrameTest extends AutomatedTestBase 
+public class FrameTest extends MLContextTestBase
 {
        private final static String TEST_DIR = "functions/frame/";
        private final static String TEST_NAME = "FrameGeneral";
@@ -98,17 +91,6 @@ public class FrameTest extends AutomatedTestBase
                schemaMixedLarge = (ValueType[]) 
schemaMixedLargeList.toArray(schemaMixedLarge);
        }
 
-       private static SparkSession spark;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               spark = createSystemMLSparkSession("FrameTest", "local");
-               ml = new MLContext(spark);
-               sc = MLContextUtil.getJavaSparkContext(ml);
-       }
-
        @Override
        public void setUp() {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
@@ -373,22 +355,4 @@ public class FrameTest extends AutomatedTestBase
                                                        ", not same as the R 
value " + val2);
                        }
        }
-
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests 
(otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               sc = null;
-               spark = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index 76deec5..44f1f15 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -28,17 +28,13 @@ import java.util.List;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -55,19 +51,16 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.runtime.util.MapReduceTool;
-import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(value = Parameterized.class)
-public class GNMFTest extends AutomatedTestBase 
+public class GNMFTest extends MLContextTestBase
 {
        private final static String TEST_DIR = "applications/gnmf/";
        private final static String TEST_NAME = "GNMF";
@@ -76,22 +69,11 @@ public class GNMFTest extends AutomatedTestBase
        int numRegisteredInputs;
        int numRegisteredOutputs;
 
-       private static SparkSession spark;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
-
        public GNMFTest(int in, int out) {
                numRegisteredInputs = in;
                numRegisteredOutputs = out;
        }
 
-       @BeforeClass
-       public static void setUpClass() {
-               spark = createSystemMLSparkSession("GNMFTest", "local");
-               ml = new MLContext(spark);
-               sc = MLContextUtil.getJavaSparkContext(ml);
-       }
-
        @Parameters
         public static Collection<Object[]> data() {
           Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 
2, 1 }, { 2, 0 }, { 3, 0 }};
@@ -256,25 +238,7 @@ public class GNMFTest extends AutomatedTestBase
                        DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
                }
        }
-       
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests 
(otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               sc = null;
-               spark = null;
 
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
-       
        public static class StringToMatrixEntry implements Function<String, 
MatrixEntry> {
 
                private static final long serialVersionUID = 
7456391906436606324L;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
index bab719e..a7d12a5 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
@@ -28,21 +28,17 @@ import java.util.Arrays;
 import java.util.List;
 
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
 import org.apache.sysml.api.mlcontext.MatrixMetadata;
@@ -50,19 +46,14 @@ import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.parser.Expression.ValueType;
 import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import 
org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
 import scala.collection.Iterator;
 
-public class MLContextFrameTest extends AutomatedTestBase {
-       protected final static String TEST_DIR = 
"org/apache/sysml/api/mlcontext";
-       protected final static String TEST_NAME = "MLContextFrame";
+public class MLContextFrameTest extends MLContextTestBase {
 
        public static enum SCRIPT_TYPE {
                DML, PYDML
@@ -72,25 +63,14 @@ public class MLContextFrameTest extends AutomatedTestBase {
                ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, 
RDD_STR_IJV, DATAFRAME
        };
 
-       private static SparkSession spark;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
        private static String CSV_DELIM = ",";
 
        @BeforeClass
        public static void setUpClass() {
-               spark = createSystemMLSparkSession("MLContextFrameTest", 
"local");
-               ml = new MLContext(spark);
-               sc = MLContextUtil.getJavaSparkContext(ml);
+               MLContextTestBase.setUpClass();
                ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS);
        }
 
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_DIR, TEST_NAME);
-               getAndLoadTestConfiguration(TEST_NAME);
-       }
-
        @Test
        public void testFrameJavaRDD_CSV_DML() {
                testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, 
IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
@@ -644,21 +624,4 @@ public class MLContextFrameTest extends AutomatedTestBase {
        // }
        // }
 
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests 
(otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               sc = null;
-               spark = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
index c418a6f..9b58322 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
@@ -80,10 +80,6 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase
                runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true);
        }
 
-       /**
-        * 
-        * @param platform
-        */
        private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, 
boolean wRead) 
        {
                RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
index fbc413b..af6028c 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
@@ -21,10 +21,7 @@ package org.apache.sysml.test.integration.mlcontext;
 
 import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
@@ -36,44 +33,15 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.DataConverter;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
-
-public class MLContextOutputBlocksizeTest extends AutomatedTestBase
+public class MLContextOutputBlocksizeTest extends MLContextTestBase
 {
-       protected final static String TEST_DIR = 
"org/apache/sysml/api/mlcontext";
-       protected final static String TEST_NAME = "MLContext";
-
        private final static int rows = 100;
        private final static int cols = 63;
        private final static double sparsity = 0.7;
 
-       private static SparkConf conf;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               if (conf == null)
-                       conf = SparkExecutionContext.createSystemMLSparkConf()
-                               .setAppName("MLContextTest").setMaster("local");
-               if (sc == null)
-                       sc = new JavaSparkContext(conf);
-               ml = new MLContext(sc);
-       }
-
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_DIR, TEST_NAME);
-               getAndLoadTestConfiguration(TEST_NAME);
-       }
-
-
        @Test
        public void testOutputBlocksizeTextcell() {
                runMLContextOutputBlocksizeTest("text");
@@ -131,21 +99,4 @@ public class MLContextOutputBlocksizeTest extends AutomatedTestBase
                }
        }
 
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop spark context to allow single jvm tests (otherwise the
-               // next test that tries to create a SparkContext would fail)
-               sc.stop();
-               sc = null;
-               conf = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
index 68b1373..0bcecf4 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
@@ -21,18 +21,15 @@ package org.apache.sysml.test.integration.mlcontext;
 
 import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
 import org.apache.sysml.api.mlcontext.MatrixMetadata;
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -41,43 +38,16 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.DataConverter;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class MLContextParforDatasetTest extends AutomatedTestBase 
+public class MLContextParforDatasetTest extends MLContextTestBase
 {
-       protected final static String TEST_DIR = 
"org/apache/sysml/api/mlcontext";
-       protected final static String TEST_NAME = "MLContext";
 
        private final static int rows = 100;
        private final static int cols = 1600;
        private final static double sparsity = 0.7;
-       
-       private static SparkConf conf;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               if (conf == null)
-                       conf = SparkExecutionContext.createSystemMLSparkConf()
-                               .setAppName("MLContextTest").setMaster("local");
-               if (sc == null)
-                       sc = new JavaSparkContext(conf);
-               ml = new MLContext(sc);
-       }
-
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_DIR, TEST_NAME);
-               getAndLoadTestConfiguration(TEST_NAME);
-       }
-
 
        @Test
        public void testParforDatasetVector() {
@@ -174,22 +144,4 @@ public class MLContextParforDatasetTest extends AutomatedTestBase
                        InfrastructureAnalyzer.setLocalMaxMemory(oldmem);       
                }
        }
-
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop spark context to allow single jvm tests (otherwise the
-               // next test that tries to create a SparkContext would fail)
-               sc.stop();
-               sc = null;
-               conf = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
index 6391919..e5e575b 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
@@ -80,10 +80,6 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase
                runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true);
        }
 
-       /**
-        * 
-        * @param platform
-        */
        private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, 
boolean wRead) 
        {
                RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index 8bb09e2..88d1a28 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -45,7 +45,6 @@ import java.util.Map;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
@@ -54,14 +53,11 @@ import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContextConversionUtil;
 import org.apache.sysml.api.mlcontext.MLContextException;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -73,11 +69,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 import scala.Tuple2;
@@ -86,26 +78,7 @@ import scala.collection.Iterator;
 import scala.collection.JavaConversions;
 import scala.collection.Seq;
 
-public class MLContextTest extends AutomatedTestBase {
-       protected final static String TEST_DIR = 
"org/apache/sysml/api/mlcontext";
-       protected final static String TEST_NAME = "MLContext";
-
-       private static SparkSession spark;
-       private static JavaSparkContext sc;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               spark = createSystemMLSparkSession("MLContextTest", "local");
-               ml = new MLContext(spark);
-               sc = MLContextUtil.getJavaSparkContext(ml);
-       }
-
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_DIR, TEST_NAME);
-               getAndLoadTestConfiguration(TEST_NAME);
-       }
+public class MLContextTest extends MLContextTestBase {
 
        @Test
        public void testCreateDMLScriptBasedOnStringAndExecute() {
@@ -710,9 +683,12 @@ public class MLContextTest extends AutomatedTestBase {
                System.out.println("MLContextTest - DataFrame sum DML, mllib 
vector with ID column");
 
                List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list 
= new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0, 
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0, 
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0, 
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
                JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> 
javaRddTuple = sc.parallelize(list);
 
                JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleMllibVectorRow());
@@ -734,9 +710,12 @@ public class MLContextTest extends AutomatedTestBase {
                System.out.println("MLContextTest - DataFrame sum PYDML, mllib 
vector with ID column");
 
                List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list 
= new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0, 
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0, 
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
-               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0, 
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(1.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(2.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+               list.add(new Tuple2<Double, 
org.apache.spark.mllib.linalg.Vector>(3.0,
+                               
org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
                JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> 
javaRddTuple = sc.parallelize(list);
 
                JavaRDD<Row> javaRddRow = javaRddTuple.map(new 
DoubleMllibVectorRow());
@@ -2576,7 +2555,8 @@ public class MLContextTest extends AutomatedTestBase {
        @Test
        public void testPrintFormattingMultipleExpressions() {
                System.out.println("MLContextTest - print formatting multiple 
expressions");
-               Script script = dml("a='hello'; b='goodbye'; c=4; d=3; e=3.0; 
f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
+               Script script = dml(
+                               "a='hello'; b='goodbye'; c=4; d=3; e=3.0; 
f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
                setExpectedStdOut("hellogoodbye 1 15.000000 true");
                ml.execute(script);
        }
@@ -2732,7 +2712,7 @@ public class MLContextTest extends AutomatedTestBase {
        public void testOutputListDML() {
                System.out.println("MLContextTest - output specified as List 
DML");
 
-               List<String> outputs = Arrays.asList("x","y");
+               List<String> outputs = Arrays.asList("x", "y");
                Script script = dml("a=1;x=a+1;y=x+1").out(outputs);
                MLResults results = ml.execute(script);
                Assert.assertEquals(2, results.getLong("x"));
@@ -2743,7 +2723,7 @@ public class MLContextTest extends AutomatedTestBase {
        public void testOutputListPYDML() {
                System.out.println("MLContextTest - output specified as List 
PYDML");
 
-               List<String> outputs = Arrays.asList("x","y");
+               List<String> outputs = Arrays.asList("x", "y");
                Script script = pydml("a=1\nx=a+1\ny=x+1").out(outputs);
                MLResults results = ml.execute(script);
                Assert.assertEquals(2, results.getLong("x"));
@@ -2755,7 +2735,7 @@ public class MLContextTest extends AutomatedTestBase {
        public void testOutputScalaSeqDML() {
                System.out.println("MLContextTest - output specified as Scala 
Seq DML");
 
-               List outputs = Arrays.asList("x","y");
+               List outputs = Arrays.asList("x", "y");
                Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
                Script script = dml("a=1;x=a+1;y=x+1").out(seq);
                MLResults results = ml.execute(script);
@@ -2768,7 +2748,7 @@ public class MLContextTest extends AutomatedTestBase {
        public void testOutputScalaSeqPYDML() {
                System.out.println("MLContextTest - output specified as Scala 
Seq PYDML");
 
-               List outputs = Arrays.asList("x","y");
+               List outputs = Arrays.asList("x", "y");
                Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
                Script script = pydml("a=1\nx=a+1\ny=x+1").out(seq);
                MLResults results = ml.execute(script);
@@ -2776,89 +2756,4 @@ public class MLContextTest extends AutomatedTestBase {
                Assert.assertEquals(3, results.getLong("y"));
        }
 
-       // NOTE: Uncomment these tests once they work
-
-       // @SuppressWarnings({ "rawtypes", "unchecked" })
-       // @Test
-       // public void testInputTupleSeqWithAndWithoutMetadataDML() {
-       // System.out.println("MLContextTest - Tuple sequence with and without
-       // metadata DML");
-       //
-       // List<String> list1 = new ArrayList<String>();
-       // list1.add("1,2");
-       // list1.add("3,4");
-       // JavaRDD<String> javaRDD1 = sc.parallelize(list1);
-       // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
-       //
-       // List<String> list2 = new ArrayList<String>();
-       // list2.add("5,6");
-       // list2.add("7,8");
-       // JavaRDD<String> javaRDD2 = sc.parallelize(list2);
-       // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
-       //
-       // MatrixMetadata mm1 = new MatrixMetadata(2, 2);
-       //
-       // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
-       // Tuple2 tuple2 = new Tuple2("m2", rdd2);
-       // List tupleList = new ArrayList();
-       // tupleList.add(tuple1);
-       // tupleList.add(tuple2);
-       // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
-       //
-       // Script script =
-       // dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
-       // setExpectedStdOut("sums: 10.0 26.0");
-       // ml.execute(script);
-       // }
-       //
-       // @SuppressWarnings({ "rawtypes", "unchecked" })
-       // @Test
-       // public void testInputTupleSeqWithAndWithoutMetadataPYDML() {
-       // System.out.println("MLContextTest - Tuple sequence with and without
-       // metadata PYDML");
-       //
-       // List<String> list1 = new ArrayList<String>();
-       // list1.add("1,2");
-       // list1.add("3,4");
-       // JavaRDD<String> javaRDD1 = sc.parallelize(list1);
-       // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
-       //
-       // List<String> list2 = new ArrayList<String>();
-       // list2.add("5,6");
-       // list2.add("7,8");
-       // JavaRDD<String> javaRDD2 = sc.parallelize(list2);
-       // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
-       //
-       // MatrixMetadata mm1 = new MatrixMetadata(2, 2);
-       //
-       // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
-       // Tuple2 tuple2 = new Tuple2("m2", rdd2);
-       // List tupleList = new ArrayList();
-       // tupleList.add(tuple1);
-       // tupleList.add(tuple2);
-       // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
-       //
-       // Script script =
-       // pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
-       // setExpectedStdOut("sums: 10.0 26.0");
-       // ml.execute(script);
-       // }
-
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests (otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               sc = null;
-               spark = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
new file mode 100644
index 0000000..380fb3f
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+/**
+ * Abstract class that can be used for MLContext tests.
+ * <p>
+ * Note that if using the setUp() method of MLContextTestBase, the test directory
+ * and test name can be specified if needed in the subclass.
+ * <p>
+ * 
+ * Example:
+ * 
+ * <pre>
+ * public MLContextTestExample() {
+ *     testDir = this.getClass().getPackage().getName().replace(".", File.separator);
+ *     testName = this.getClass().getSimpleName();
+ * }
+ * </pre>
+ *
+ */
+public abstract class MLContextTestBase extends AutomatedTestBase {
+
+       protected static SparkSession spark;
+       protected static JavaSparkContext sc;
+       protected static MLContext ml;
+
+       protected String testDir = null;
+       protected String testName = null;
+
+       @Override
+       public void setUp() {
+               Class<? extends MLContextTestBase> clazz = this.getClass();
+               String dir = (testDir == null) ? 
"org/apache/sysml/api/mlcontext" : testDir;
+               String name = (testName == null) ? clazz.getSimpleName() : 
testName;
+
+               addTestConfiguration(dir, name);
+               getAndLoadTestConfiguration(name);
+       }
+
+       @BeforeClass
+       public static void setUpClass() {
+               spark = createSystemMLSparkSession("SystemML MLContext Test", 
"local");
+               ml = new MLContext(spark);
+               sc = MLContextUtil.getJavaSparkContext(ml);
+       }
+
+       @After
+       public void tearDown() {
+               super.tearDown();
+       }
+
+       @AfterClass
+       public static void tearDownClass() {
+               // stop underlying spark context to allow single jvm tests (otherwise
+               // the next test that tries to create a SparkContext would fail)
+               spark.stop();
+               sc = null;
+               spark = null;
+               ml.close();
+               ml = null;
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
index d86b707..92b9f67 100644
--- a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
@@ -19,42 +19,20 @@
 
 package org.apache.sysml.test.integration.scripts.nn;
 
-import org.apache.spark.sql.SparkSession;
-import org.apache.sysml.api.mlcontext.MLContext;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.junit.Test;
 
-import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
-
 /**
  * Test the SystemML deep learning library, `nn`.
  */
-public class NNTest extends AutomatedTestBase {
+public class NNTest extends MLContextTestBase {
 
-       private static final String TEST_NAME = "NNTest";
-       private static final String TEST_DIR = "scripts/";
        private static final String TEST_SCRIPT = 
"scripts/nn/test/run_tests.dml";
        private static final String ERROR_STRING = "ERROR:";
 
-       private static SparkSession spark;
-       private static MLContext ml;
-
-       @BeforeClass
-       public static void setUpClass() {
-               spark = createSystemMLSparkSession("MLContextTest", "local");
-               ml = new MLContext(spark);
-       }
-
-       @Override
-       public void setUp() {
-               addTestConfiguration(TEST_DIR, TEST_NAME);
-               getAndLoadTestConfiguration(TEST_NAME);
-       }
-
        @Test
        public void testNNLibrary() {
                Script script = dmlFromFile(TEST_SCRIPT);
@@ -62,20 +40,4 @@ public class NNTest extends AutomatedTestBase {
                ml.execute(script);
        }
 
-       @After
-       public void tearDown() {
-               super.tearDown();
-       }
-
-       @AfterClass
-       public static void tearDownClass() {
-               // stop underlying spark context to allow single jvm tests (otherwise the
-               // next test that tries to create a SparkContext would fail)
-               spark.stop();
-               spark = null;
-
-               // clear status mlcontext and spark exec context
-               ml.close();
-               ml = null;
-       }
 }

Reply via email to