Repository: systemml Updated Branches: refs/heads/master 0ae2b4f77 -> ec38b3790
[SYSTEMML-1777] MLContextTestBase class for MLContext testing Create abstract MLContextTestBase class that contains setup and shutdown code for MLContext tests. This removes boilerplate code from MLContext test classes that extend MLContextTestBase. Closes #580. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ec38b379 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ec38b379 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ec38b379 Branch: refs/heads/master Commit: ec38b3790f11792d3337c35439954d422d6eb60b Parents: 0ae2b4f Author: Deron Eriksson <de...@apache.org> Authored: Wed Jul 19 11:13:51 2017 -0700 Committer: Deron Eriksson <de...@apache.org> Committed: Wed Jul 19 11:13:51 2017 -0700 ---------------------------------------------------------------------- .../mlcontext/DataFrameVectorScriptTest.java | 29 +--- .../functions/mlcontext/FrameTest.java | 40 +----- .../functions/mlcontext/GNMFTest.java | 40 +----- .../mlcontext/MLContextFrameTest.java | 41 +----- .../mlcontext/MLContextMultipleScriptsTest.java | 4 - .../mlcontext/MLContextOutputBlocksizeTest.java | 51 +------ .../mlcontext/MLContextParforDatasetTest.java | 52 +------ .../mlcontext/MLContextScratchCleanupTest.java | 4 - .../integration/mlcontext/MLContextTest.java | 143 +++---------------- .../mlcontext/MLContextTestBase.java | 89 ++++++++++++ .../test/integration/scripts/nn/NNTest.java | 46 +----- 11 files changed, 123 insertions(+), 416 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java index 65aee8e..55b8371 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java @@ -38,7 +38,6 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.mlcontext.FrameFormat; import org.apache.sysml.api.mlcontext.FrameMetadata; -import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.conf.ConfigurationManager; @@ -49,15 +48,13 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.runtime.util.UtilFunctions; -import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.integration.mlcontext.MLContextTestBase; import org.apache.sysml.test.utils.TestUtils; -import org.junit.AfterClass; -import org.junit.BeforeClass; import org.junit.Test; -public class DataFrameVectorScriptTest extends AutomatedTestBase +public class DataFrameVectorScriptTest extends MLContextTestBase { private final static String TEST_DIR = "functions/mlcontext/"; private final static String TEST_NAME = "DataFrameConversion"; @@ -75,16 +72,6 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase private final static double sparsity2 = 0.1; private final static double eps=0.0000000001; - private static SparkSession spark; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - spark = createSystemMLSparkSession("DataFrameVectorScriptTest", "local"); - ml = new MLContext(spark); - ml.setExplain(true); - } - @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"})); @@ -343,16 +330,4 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase JavaRDD<Row> rowRDD = sc.parallelize(list); return sparkSession.createDataFrame(rowRDD, dfSchema); } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - spark = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java index c93968c..382f433 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java @@ -29,10 +29,8 @@ import java.util.List; import org.apache.hadoop.io.LongWritable; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; @@ -40,8 +38,6 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.mlcontext.FrameFormat; import org.apache.sysml.api.mlcontext.FrameMetadata; import org.apache.sysml.api.mlcontext.FrameSchema; -import org.apache.sysml.api.mlcontext.MLContext; -import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.api.mlcontext.ScriptFactory; @@ -57,17 +53,14 @@ import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.runtime.util.UtilFunctions; -import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.integration.mlcontext.MLContextTestBase; import org.apache.sysml.test.utils.TestUtils; -import org.junit.After; -import org.junit.AfterClass; import org.junit.Assert; -import org.junit.BeforeClass; import org.junit.Test; -public class FrameTest extends AutomatedTestBase +public class FrameTest extends MLContextTestBase { private final static String TEST_DIR = "functions/frame/"; private final static String TEST_NAME = "FrameGeneral"; @@ -98,17 +91,6 @@ public class FrameTest extends AutomatedTestBase schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge); } - private static SparkSession spark; - private static JavaSparkContext sc; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - spark = createSystemMLSparkSession("FrameTest", "local"); - ml = new MLContext(spark); - sc = MLContextUtil.getJavaSparkContext(ml); - } - @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, @@ -373,22 +355,4 @@ public class FrameTest extends AutomatedTestBase ", not same as the R value " + val2); } } - - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - sc = null; - spark = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java index 76deec5..44f1f15 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java @@ -28,17 +28,13 @@ import java.util.List; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.MatrixEntry; import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.api.mlcontext.MLContext; -import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.MatrixFormat; @@ -55,19 +51,16 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysml.runtime.util.MapReduceTool; -import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.mlcontext.MLContextTestBase; import org.apache.sysml.test.utils.TestUtils; -import org.junit.After; -import org.junit.AfterClass; import org.junit.Assert; -import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; @RunWith(value = Parameterized.class) -public class GNMFTest extends AutomatedTestBase +public class GNMFTest extends MLContextTestBase { private final static String TEST_DIR = "applications/gnmf/"; private final static String TEST_NAME = "GNMF"; @@ -76,22 +69,11 @@ public class GNMFTest extends AutomatedTestBase int numRegisteredInputs; int numRegisteredOutputs; - private static SparkSession spark; - private static JavaSparkContext sc; - private static MLContext ml; - public GNMFTest(int in, int out) { numRegisteredInputs = in; numRegisteredOutputs = out; } - @BeforeClass - public static void setUpClass() { - spark = createSystemMLSparkSession("GNMFTest", "local"); - ml = new MLContext(spark); - sc = MLContextUtil.getJavaSparkContext(ml); - } - @Parameters public static Collection<Object[]> data() { Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 2, 1 }, { 2, 0 }, { 3, 0 }}; @@ -256,25 +238,7 @@ public class GNMFTest extends AutomatedTestBase DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; } } - - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - sc = null; - spark = null; - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } - public static class StringToMatrixEntry implements Function<String, MatrixEntry> { private static final long serialVersionUID = 7456391906436606324L; http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java index bab719e..a7d12a5 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java @@ -28,21 +28,17 @@ import java.util.Arrays; import java.util.List; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.sysml.api.mlcontext.FrameFormat; import org.apache.sysml.api.mlcontext.FrameMetadata; import org.apache.sysml.api.mlcontext.FrameSchema; -import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; -import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.MatrixFormat; import org.apache.sysml.api.mlcontext.MatrixMetadata; @@ -50,19 +46,14 @@ import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; -import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow; -import org.junit.After; -import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import scala.collection.Iterator; -public class MLContextFrameTest extends AutomatedTestBase { - protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; - protected final static String TEST_NAME = "MLContextFrame"; +public class MLContextFrameTest extends MLContextTestBase { public static enum SCRIPT_TYPE { DML, PYDML @@ -72,25 +63,14 @@ public class MLContextFrameTest extends AutomatedTestBase { ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME }; - private static SparkSession spark; - private static JavaSparkContext sc; - private static MLContext ml; private static String CSV_DELIM = ","; @BeforeClass public static void setUpClass() { - spark = createSystemMLSparkSession("MLContextFrameTest", "local"); - ml = new MLContext(spark); - sc = MLContextUtil.getJavaSparkContext(ml); + MLContextTestBase.setUpClass(); ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS); } - @Override - public void setUp() { - addTestConfiguration(TEST_DIR, TEST_NAME); - getAndLoadTestConfiguration(TEST_NAME); - } - @Test public void testFrameJavaRDD_CSV_DML() { testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); @@ -644,21 +624,4 @@ public class MLContextFrameTest extends AutomatedTestBase { // } // } - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - sc = null; - spark = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java index c418a6f..9b58322 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java @@ -80,10 +80,6 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true); } - /** - * - * @param platform - */ private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) { RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform; http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java index fbc413b..af6028c 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java @@ -21,10 +21,7 @@ package org.apache.sysml.test.integration.mlcontext; import static org.apache.sysml.api.mlcontext.ScriptFactory.dml; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Matrix; @@ -36,44 +33,15 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.util.DataConverter; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.junit.After; -import org.junit.AfterClass; import org.junit.Assert; -import org.junit.BeforeClass; import org.junit.Test; - -public class MLContextOutputBlocksizeTest extends AutomatedTestBase +public class MLContextOutputBlocksizeTest extends MLContextTestBase { - protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; - protected final static String TEST_NAME = "MLContext"; - private final static int rows = 100; private final static int cols = 63; private final static double sparsity = 0.7; - private static SparkConf conf; - private static JavaSparkContext sc; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - if (conf == null) - conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextTest").setMaster("local"); - if (sc == null) - sc = new JavaSparkContext(conf); - ml = new MLContext(sc); - } - - @Override - public void setUp() { - addTestConfiguration(TEST_DIR, TEST_NAME); - getAndLoadTestConfiguration(TEST_NAME); - } - - @Test public void testOutputBlocksizeTextcell() { runMLContextOutputBlocksizeTest("text"); @@ -131,21 +99,4 @@ public class MLContextOutputBlocksizeTest extends AutomatedTestBase } } - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - sc.stop(); - sc = null; - conf = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java index 68b1373..0bcecf4 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java @@ -21,18 +21,15 @@ package org.apache.sysml.test.integration.mlcontext; import static org.apache.sysml.api.mlcontext.ScriptFactory.dml; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; -import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.MatrixFormat; import org.apache.sysml.api.mlcontext.MatrixMetadata; import org.apache.sysml.api.mlcontext.Script; -import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; @@ -41,43 +38,16 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.util.DataConverter; -import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.utils.TestUtils; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; import org.junit.Test; -public class MLContextParforDatasetTest extends AutomatedTestBase +public class MLContextParforDatasetTest extends MLContextTestBase { - protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; - protected final static String TEST_NAME = "MLContext"; private final static int rows = 100; private final static int cols = 1600; private final static double sparsity = 0.7; - - private static SparkConf conf; - private static JavaSparkContext sc; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - if (conf == null) - conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextTest").setMaster("local"); - if (sc == null) - sc = new JavaSparkContext(conf); - ml = new MLContext(sc); - } - - @Override - public void setUp() { - addTestConfiguration(TEST_DIR, TEST_NAME); - getAndLoadTestConfiguration(TEST_NAME); - } - @Test public void testParforDatasetVector() { @@ -174,22 +144,4 @@ public class MLContextParforDatasetTest extends AutomatedTestBase InfrastructureAnalyzer.setLocalMaxMemory(oldmem); } } - - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - sc.stop(); - sc = null; - conf = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java index 6391919..e5e575b 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java @@ -80,10 +80,6 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true); } - /** - * - * @param platform - */ private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) { RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform; http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index 8bb09e2..88d1a28 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -45,7 +45,6 @@ import java.util.Map; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; @@ -54,14 +53,11 @@ import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContextConversionUtil; import org.apache.sysml.api.mlcontext.MLContextException; -import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.MatrixFormat; @@ -73,11 +69,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.junit.After; -import org.junit.AfterClass; import org.junit.Assert; -import org.junit.BeforeClass; import org.junit.Test; import scala.Tuple2; @@ -86,26 +78,7 @@ import scala.collection.Iterator; import scala.collection.JavaConversions; import scala.collection.Seq; -public class MLContextTest extends AutomatedTestBase { - protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; - protected final static String TEST_NAME = "MLContext"; - - private static SparkSession spark; - private static JavaSparkContext sc; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - spark = createSystemMLSparkSession("MLContextTest", "local"); - ml = new MLContext(spark); - sc = MLContextUtil.getJavaSparkContext(ml); - } - - @Override - public void setUp() { - addTestConfiguration(TEST_DIR, TEST_NAME); - getAndLoadTestConfiguration(TEST_NAME); - } +public class MLContextTest extends MLContextTestBase { @Test public void testCreateDMLScriptBasedOnStringAndExecute() { @@ -710,9 +683,12 @@ public class MLContextTest extends AutomatedTestBase { System.out.println("MLContextTest - DataFrame sum DML, mllib vector with ID column"); List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>(); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0))); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0))); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, + org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, + org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, + org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0))); JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow()); @@ -734,9 +710,12 @@ public class MLContextTest extends AutomatedTestBase { System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with ID column"); List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>(); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0))); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0))); - list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, + org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, + org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0))); + list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, + org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0))); JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow()); @@ -2576,7 +2555,8 @@ public class MLContextTest extends AutomatedTestBase { @Test public void testPrintFormattingMultipleExpressions() { System.out.println("MLContextTest - print formatting multiple expressions"); - Script script = dml("a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);"); + Script script = dml( + "a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);"); setExpectedStdOut("hellogoodbye 1 15.000000 true"); ml.execute(script); } @@ -2732,7 +2712,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputListDML() { System.out.println("MLContextTest - output specified as List DML"); - List<String> outputs = Arrays.asList("x","y"); + List<String> outputs = Arrays.asList("x", "y"); Script script = dml("a=1;x=a+1;y=x+1").out(outputs); MLResults results = ml.execute(script); Assert.assertEquals(2, results.getLong("x")); @@ -2743,7 +2723,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputListPYDML() { System.out.println("MLContextTest - output specified as List PYDML"); - List<String> outputs = Arrays.asList("x","y"); + List<String> outputs = Arrays.asList("x", "y"); Script script = pydml("a=1\nx=a+1\ny=x+1").out(outputs); MLResults results = ml.execute(script); Assert.assertEquals(2, results.getLong("x")); @@ -2755,7 +2735,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputScalaSeqDML() { System.out.println("MLContextTest - output specified as Scala Seq DML"); - List outputs = Arrays.asList("x","y"); + List outputs = Arrays.asList("x", "y"); Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq(); Script script = dml("a=1;x=a+1;y=x+1").out(seq); MLResults results = ml.execute(script); @@ -2768,7 +2748,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputScalaSeqPYDML() { System.out.println("MLContextTest - output specified as Scala Seq PYDML"); - List outputs = Arrays.asList("x","y"); + List outputs = Arrays.asList("x", "y"); Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq(); Script script = pydml("a=1\nx=a+1\ny=x+1").out(seq); MLResults results = ml.execute(script); @@ -2776,89 +2756,4 @@ public class MLContextTest extends AutomatedTestBase { Assert.assertEquals(3, results.getLong("y")); } - // NOTE: Uncomment these tests once they work - - // @SuppressWarnings({ "rawtypes", "unchecked" }) - // @Test - // public void testInputTupleSeqWithAndWithoutMetadataDML() { - // System.out.println("MLContextTest - Tuple sequence with and without - // metadata DML"); - // - // List<String> list1 = new ArrayList<String>(); - // list1.add("1,2"); - // list1.add("3,4"); - // JavaRDD<String> javaRDD1 = sc.parallelize(list1); - // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1); - // - // List<String> list2 = new ArrayList<String>(); - // list2.add("5,6"); - // list2.add("7,8"); - // JavaRDD<String> javaRDD2 = sc.parallelize(list2); - // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2); - // - // MatrixMetadata mm1 = new MatrixMetadata(2, 2); - // - // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1); - // Tuple2 tuple2 = new Tuple2("m2", rdd2); - // List tupleList = new ArrayList(); - // tupleList.add(tuple1); - // tupleList.add(tuple2); - // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq(); - // - // Script script = - // dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq); - // setExpectedStdOut("sums: 10.0 26.0"); - // ml.execute(script); - // } - // - // @SuppressWarnings({ "rawtypes", "unchecked" }) - // @Test - // public void testInputTupleSeqWithAndWithoutMetadataPYDML() { - // System.out.println("MLContextTest - Tuple sequence with and without - // metadata PYDML"); - // - // List<String> list1 = new ArrayList<String>(); - // list1.add("1,2"); - // list1.add("3,4"); - // JavaRDD<String> javaRDD1 = sc.parallelize(list1); - // RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1); - // - // List<String> list2 = new ArrayList<String>(); - // list2.add("5,6"); - // list2.add("7,8"); - // JavaRDD<String> javaRDD2 = sc.parallelize(list2); - // RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2); - // - // MatrixMetadata mm1 = new MatrixMetadata(2, 2); - // - // Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1); - // Tuple2 tuple2 = new Tuple2("m2", rdd2); - // List tupleList = new ArrayList(); - // tupleList.add(tuple1); - // tupleList.add(tuple2); - // Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq(); - // - // Script script = - // pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq); - // setExpectedStdOut("sums: 10.0 26.0"); - // ml.execute(script); - // } - - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - sc = null; - spark = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java new file mode 100644 index 0000000..380fb3f --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.mlcontext; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLContextUtil; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +/** + * Abstract class that can be used for MLContext tests. + * <p> + * Note that if using the setUp() method of MLContextTestBase, the test directory + * and test name can be specified if needed in the subclass. + * <p> + * + * Example: + * + * <pre> + * public MLContextTestExample() { + * testDir = this.getClass().getPackage().getName().replace(".", File.separator); + * testName = this.getClass().getSimpleName(); + * } + * </pre> + * + */ +public abstract class MLContextTestBase extends AutomatedTestBase { + + protected static SparkSession spark; + protected static JavaSparkContext sc; + protected static MLContext ml; + + protected String testDir = null; + protected String testName = null; + + @Override + public void setUp() { + Class<? extends MLContextTestBase> clazz = this.getClass(); + String dir = (testDir == null) ? "org/apache/sysml/api/mlcontext" : testDir; + String name = (testName == null) ? clazz.getSimpleName() : testName; + + addTestConfiguration(dir, name); + getAndLoadTestConfiguration(name); + } + + @BeforeClass + public static void setUpClass() { + spark = createSystemMLSparkSession("SystemML MLContext Test", "local"); + ml = new MLContext(spark); + sc = MLContextUtil.getJavaSparkContext(ml); + } + + @After + public void tearDown() { + super.tearDown(); + } + + @AfterClass + public static void tearDownClass() { + // stop underlying spark context to allow single jvm tests (otherwise + // the next test that tries to create a SparkContext would fail) + spark.stop(); + sc = null; + spark = null; + ml.close(); + ml = null; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java index d86b707..92b9f67 100644 --- a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java +++ b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java @@ -19,42 +19,20 @@ package org.apache.sysml.test.integration.scripts.nn; -import org.apache.spark.sql.SparkSession; -import org.apache.sysml.api.mlcontext.MLContext; +import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; + import org.apache.sysml.api.mlcontext.Script; -import org.apache.sysml.test.integration.AutomatedTestBase; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; +import org.apache.sysml.test.integration.mlcontext.MLContextTestBase; import org.junit.Test; -import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; - /** * Test the SystemML deep learning library, `nn`. */ -public class NNTest extends AutomatedTestBase { +public class NNTest extends MLContextTestBase { - private static final String TEST_NAME = "NNTest"; - private static final String TEST_DIR = "scripts/"; private static final String TEST_SCRIPT = "scripts/nn/test/run_tests.dml"; private static final String ERROR_STRING = "ERROR:"; - private static SparkSession spark; - private static MLContext ml; - - @BeforeClass - public static void setUpClass() { - spark = createSystemMLSparkSession("MLContextTest", "local"); - ml = new MLContext(spark); - } - - @Override - public void setUp() { - addTestConfiguration(TEST_DIR, TEST_NAME); - getAndLoadTestConfiguration(TEST_NAME); - } - @Test public void testNNLibrary() { Script script = dmlFromFile(TEST_SCRIPT); @@ -62,20 +40,4 @@ public class NNTest extends AutomatedTestBase { ml.execute(script); } - @After - public void tearDown() { - super.tearDown(); - } - - @AfterClass - public static void tearDownClass() { - // stop underlying spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - spark.stop(); - spark = null; - - // clear status mlcontext and spark exec context - ml.close(); - ml = null; - } }