[SYSTEMML-1347] Accept SparkSession in Java/Scala MLContext API

Add an MLContext constructor that accepts a SparkSession. In MLContext, store a SparkSession reference instead of a SparkContext. Remove the unused monitoring parameter from MLContext. Simplify MLContextUtil and MLContextConversionUtil. Add a method for creating a SparkSession in AutomatedTestBase. Update tests to use SparkSession. Document the MLContext SparkSession constructor in the MLContext programming guide.
Closes #405. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9c19b477 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9c19b477 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9c19b477 Branch: refs/heads/master Commit: 9c19b4771caa96af4e959dda363d41e32818fb56 Parents: 9820f4c Author: Deron Eriksson <de...@us.ibm.com> Authored: Fri Apr 7 10:35:55 2017 -0700 Committer: Deron Eriksson <de...@us.ibm.com> Committed: Fri Apr 7 10:35:55 2017 -0700 ---------------------------------------------------------------------- docs/spark-mlcontext-programming-guide.md | 14 +-- .../apache/sysml/api/mlcontext/MLContext.java | 54 +++++------ .../api/mlcontext/MLContextConversionUtil.java | 92 +++++++++--------- .../sysml/api/mlcontext/MLContextUtil.java | 45 +++++++-- .../context/SparkExecutionContext.java | 16 ++-- .../test/integration/AutomatedTestBase.java | 26 ++++++ .../DataFrameMatrixConversionTest.java | 43 +++++---- .../DataFrameRowFrameConversionTest.java | 49 +++++----- .../DataFrameVectorFrameConversionTest.java | 42 +++++---- .../mlcontext/DataFrameVectorScriptTest.java | 62 ++++++------- .../functions/mlcontext/FrameTest.java | 25 ++--- .../functions/mlcontext/GNMFTest.java | 21 ++--- .../mlcontext/MLContextFrameTest.java | 43 +++------ .../mlcontext/MLContextMultipleScriptsTest.java | 16 ++-- .../mlcontext/MLContextScratchCleanupTest.java | 16 ++-- .../integration/mlcontext/MLContextTest.java | 98 +++++++------------- 16 files changed, 327 insertions(+), 335 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/docs/spark-mlcontext-programming-guide.md ---------------------------------------------------------------------- diff --git a/docs/spark-mlcontext-programming-guide.md b/docs/spark-mlcontext-programming-guide.md index 
c28eaf5..3b7bfc8 100644 --- a/docs/spark-mlcontext-programming-guide.md +++ b/docs/spark-mlcontext-programming-guide.md @@ -47,10 +47,10 @@ spark-shell --executor-memory 4G --driver-memory 4G --jars SystemML.jar ## Create MLContext -All primary classes that a user interacts with are located in the `org.apache.sysml.api.mlcontext package`. -For convenience, we can additionally add a static import of ScriptFactory to shorten the syntax for creating Script objects. -An `MLContext` object can be created by passing its constructor a reference to the `SparkContext`. If successful, you -should see a "`Welcome to Apache SystemML!`" message. +All primary classes that a user interacts with are located in the `org.apache.sysml.api.mlcontext` package. +For convenience, we can additionally add a static import of `ScriptFactory` to shorten the syntax for creating `Script` objects. +An `MLContext` object can be created by passing its constructor a reference to the `SparkSession` (`spark`) or `SparkContext` (`sc`). +If successful, you should see a "`Welcome to Apache SystemML!`" message. <div class="codetabs"> @@ -58,7 +58,7 @@ should see a "`Welcome to Apache SystemML!`" message. {% highlight scala %} import org.apache.sysml.api.mlcontext._ import org.apache.sysml.api.mlcontext.ScriptFactory._ -val ml = new MLContext(sc) +val ml = new MLContext(spark) {% endhighlight %} </div> @@ -70,7 +70,7 @@ import org.apache.sysml.api.mlcontext._ scala> import org.apache.sysml.api.mlcontext.ScriptFactory._ import org.apache.sysml.api.mlcontext.ScriptFactory._ -scala> val ml = new MLContext(sc) +scala> val ml = new MLContext(spark) Welcome to Apache SystemML! 
@@ -1753,7 +1753,7 @@ Archiver-Version: Plexus Archiver Artifact-Id: systemml Build-Jdk: 1.8.0_60 Build-Time: 2017-02-03 22:32:43 UTC -Built-By: deroneriksson +Built-By: sparkuser Created-By: Apache Maven 3.3.9 Group-Id: org.apache.systemml Main-Class: org.apache.sysml.api.DMLScript http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index cb98083..41df7fd 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -29,6 +29,7 @@ import java.util.Set; import org.apache.log4j.Logger; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.MLContextProxy; @@ -62,9 +63,9 @@ public class MLContext { public static Logger log = Logger.getLogger(MLContext.class); /** - * SparkContext object. + * SparkSession object. */ - private SparkContext sc = null; + private SparkSession spark = null; /** * Reference to the currently executing script. @@ -164,6 +165,16 @@ public class MLContext { } /** + * Create an MLContext based on a SparkSession for interaction with SystemML + * on Spark. + * + * @param spark SparkSession + */ + public MLContext(SparkSession spark) { + initMLContext(spark); + } + + /** * Create an MLContext based on a SparkContext for interaction with SystemML * on Spark. 
* @@ -171,7 +182,7 @@ public class MLContext { * SparkContext */ public MLContext(SparkContext sparkContext) { - this(sparkContext, false); + initMLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate()); } /** @@ -182,38 +193,21 @@ public class MLContext { * JavaSparkContext */ public MLContext(JavaSparkContext javaSparkContext) { - this(javaSparkContext.sc(), false); - } - - /** - * Create an MLContext based on a SparkContext for interaction with SystemML - * on Spark, optionally monitor performance. - * - * @param sc - * SparkContext object. - * @param monitorPerformance - * {@code true} if performance should be monitored, {@code false} - * otherwise - */ - public MLContext(SparkContext sc, boolean monitorPerformance) { - initMLContext(sc, monitorPerformance); + initMLContext(SparkSession.builder().sparkContext(javaSparkContext.sc()).getOrCreate()); } /** * Initialize MLContext. Verify Spark version supported, set default * execution mode, set MLContextProxy, set default config, set compiler - * config, and configure monitoring if needed. + * config. * * @param sc * SparkContext object. 
- * @param monitorPerformance - * {@code true} if performance should be monitored, {@code false} - * otherwise */ - private void initMLContext(SparkContext sc, boolean monitorPerformance) { + private void initMLContext(SparkSession spark) { try { - MLContextUtil.verifySparkVersionSupported(sc); + MLContextUtil.verifySparkVersionSupported(spark); } catch (MLContextException e) { if (info() != null) { log.warn("Apache Spark " + this.info().minimumRecommendedSparkVersion() + " or above is recommended for SystemML " + this.info().version()); @@ -231,7 +225,7 @@ public class MLContext { System.out.println(MLContextUtil.welcomeMessage()); } - this.sc = sc; + this.spark = spark; // by default, run in hybrid Spark mode for optimal performance DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; @@ -329,12 +323,12 @@ public class MLContext { } /** - * Obtain the SparkContext associated with this MLContext. + * Obtain the SparkSession associated with this MLContext. * - * @return the SparkContext associated with this MLContext. + * @return the SparkSession associated with this MLContext. 
*/ - public SparkContext getSparkContext() { - return sc; + public SparkSession getSparkSession() { + return spark; } /** @@ -641,7 +635,7 @@ public class MLContext { scripts.clear(); scriptHistoryStrings.clear(); resetConfig(); - sc = null; + spark = null; } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index c496325..dc20108 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -40,7 +40,6 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.parser.Expression.ValueType; @@ -137,11 +136,7 @@ public class MLContextConversionUtil { try { InputStream is = url.openStream(); List<String> lines = IOUtils.readLines(is); - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); - SparkContext sparkContext = activeMLContext.getSparkContext(); - @SuppressWarnings("resource") - JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext); - JavaRDD<String> javaRDD = javaSparkContext.parallelize(lines); + JavaRDD<String> javaRDD = jsc().parallelize(lines); if ((matrixMetadata == null) || (matrixMetadata.getMatrixFormat() == MatrixFormat.CSV)) { return javaRDDStringCSVToMatrixObject(variableName, javaRDD, matrixMetadata); } else if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) { @@ -370,8 +365,6 @@ public 
class MLContextConversionUtil { frameMetadata = new FrameMetadata(); determineFrameFormatIfNeeded(dataFrame, frameMetadata); boolean containsID = isDataFrameWithIDColumn(frameMetadata); - JavaSparkContext javaSparkContext = MLContextUtil - .getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); MatrixCharacteristics mc = frameMetadata.asMatrixCharacteristics(); if( mc == null ) mc = new MatrixCharacteristics(); @@ -380,7 +373,7 @@ public class MLContextConversionUtil { //TODO extend frame schema by column names (right now dropped) Pair<String[], ValueType[]> ret = new Pair<String[], ValueType[]>(); JavaPairRDD<Long, FrameBlock> binaryBlock = FrameRDDConverterUtils - .dataFrameToBinaryBlock(javaSparkContext, dataFrame, mc, containsID, ret); + .dataFrameToBinaryBlock(jsc(), dataFrame, mc, containsID, ret); frameMetadata.setFrameSchema(new FrameSchema(Arrays.asList(ret.getValue()))); frameMetadata.setMatrixCharacteristics(mc); //required due to meta data copy @@ -426,13 +419,10 @@ public class MLContextConversionUtil { matrixMetadata.asMatrixCharacteristics() : new MatrixCharacteristics(); boolean containsID = isDataFrameWithIDColumn(matrixMetadata); boolean isVector = isVectorBasedDataFrame(matrixMetadata); - - //get spark context - JavaSparkContext sc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); //convert data frame to binary block matrix JavaPairRDD<MatrixIndexes,MatrixBlock> out = RDDConverterUtils - .dataFrameToBinaryBlock(sc, dataFrame, mc, containsID, isVector); + .dataFrameToBinaryBlock(jsc(), dataFrame, mc, containsID, isVector); //update determined matrix characteristics if( matrixMetadata != null ) @@ -639,14 +629,12 @@ public class MLContextConversionUtil { frameMetadata.asMatrixCharacteristics() : new MatrixCharacteristics(); JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction()); - JavaSparkContext jsc = 
MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); - FrameObject frameObject = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), frameMetadata.getFrameSchema().getSchema().toArray(new ValueType[0])); JavaPairRDD<Long, FrameBlock> rdd; try { - rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc, javaPairRDDText, mc, + rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc(), javaPairRDDText, mc, frameObject.getSchema(), false, ",", false, -1); } catch (DMLRuntimeException e) { e.printStackTrace(); @@ -701,8 +689,6 @@ public class MLContextConversionUtil { JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction()); - JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); - FrameObject frameObject = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), frameMetadata.getFrameSchema().getSchema().toArray(new ValueType[0])); @@ -711,7 +697,7 @@ public class MLContextConversionUtil { ValueType[] lschema = null; if (lschema == null) lschema = UtilFunctions.nCopies((int) mc.getCols(), ValueType.STRING); - rdd = FrameRDDConverterUtils.textCellToBinaryBlock(jsc, javaPairRDDText, mc, lschema); + rdd = FrameRDDConverterUtils.textCellToBinaryBlock(jsc(), javaPairRDDText, mc, lschema); } catch (DMLRuntimeException e) { e.printStackTrace(); return null; @@ -859,11 +845,7 @@ public class MLContextConversionUtil { public static JavaRDD<String> matrixObjectToJavaRDDStringCSV(MatrixObject matrixObject) { List<String> list = matrixObjectToListStringCSV(matrixObject); - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); - SparkContext sc = activeMLContext.getSparkContext(); - @SuppressWarnings("resource") - JavaSparkContext jsc = 
new JavaSparkContext(sc); - return jsc.parallelize(list); + return jsc().parallelize(list); } /** @@ -877,8 +859,7 @@ public class MLContextConversionUtil { public static JavaRDD<String> frameObjectToJavaRDDStringCSV(FrameObject frameObject, String delimiter) { List<String> list = frameObjectToListStringCSV(frameObject, delimiter); - JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); - return jsc.parallelize(list); + return jsc().parallelize(list); } /** @@ -892,11 +873,7 @@ public class MLContextConversionUtil { public static JavaRDD<String> matrixObjectToJavaRDDStringIJV(MatrixObject matrixObject) { List<String> list = matrixObjectToListStringIJV(matrixObject); - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); - SparkContext sc = activeMLContext.getSparkContext(); - @SuppressWarnings("resource") - JavaSparkContext jsc = new JavaSparkContext(sc); - return jsc.parallelize(list); + return jsc().parallelize(list); } /** @@ -909,8 +886,7 @@ public class MLContextConversionUtil { public static JavaRDD<String> frameObjectToJavaRDDStringIJV(FrameObject frameObject) { List<String> list = frameObjectToListStringIJV(frameObject); - JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); - return jsc.parallelize(list); + return jsc().parallelize(list); } /** @@ -934,10 +910,8 @@ public class MLContextConversionUtil { List<String> list = matrixObjectToListStringIJV(matrixObject); - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); - SparkContext sc = activeMLContext.getSparkContext(); ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); - return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag); } /** @@ -961,9 +935,8 @@ public class 
MLContextConversionUtil { List<String> list = frameObjectToListStringIJV(frameObject); - SparkContext sc = MLContextUtil.getSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); - return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag); } /** @@ -987,10 +960,8 @@ public class MLContextConversionUtil { List<String> list = matrixObjectToListStringCSV(matrixObject); - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); - SparkContext sc = activeMLContext.getSparkContext(); ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); - return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag); } /** @@ -1015,9 +986,8 @@ public class MLContextConversionUtil { List<String> list = frameObjectToListStringCSV(frameObject, delimiter); - SparkContext sc = MLContextUtil.getSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); - return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag); } /** @@ -1247,10 +1217,7 @@ public class MLContextConversionUtil { .getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo); MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics(); - SparkContext sc = ((MLContext) MLContextProxy.getActiveMLContextForAPI()).getSparkContext(); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc).getOrCreate(); - - return RDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryBlockMatrix, mc, isVectorDF); + 
return RDDConverterUtils.binaryBlockToDataFrame(spark(), binaryBlockMatrix, mc, isVectorDF); } catch (DMLRuntimeException e) { throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e); @@ -1274,9 +1241,7 @@ public class MLContextConversionUtil { .getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo); MatrixCharacteristics mc = frameObject.getMatrixCharacteristics(); - JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI()); - SparkSession sparkSession = SparkSession.builder().sparkContext(jsc.sc()).getOrCreate(); - return FrameRDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryBlockFrame, mc, frameObject.getSchema()); + return FrameRDDConverterUtils.binaryBlockToDataFrame(spark(), binaryBlockFrame, mc, frameObject.getSchema()); } catch (DMLRuntimeException e) { throw new MLContextException("DMLRuntimeException while converting frame object to DataFrame", e); @@ -1348,4 +1313,31 @@ public class MLContextConversionUtil { throw new MLContextException("DMLRuntimeException while converting frame object to 2D string array", e); } } + + /** + * Obtain JavaSparkContext from MLContextProxy. + * + * @return the Java Spark Context + */ + public static JavaSparkContext jsc() { + return MLContextUtil.getJavaSparkContextFromProxy(); + } + + /** + * Obtain SparkContext from MLContextProxy. + * + * @return the Spark Context + */ + public static SparkContext sc() { + return MLContextUtil.getSparkContextFromProxy(); + } + + /** + * Obtain SparkSession from MLContextProxy. 
+ * + * @return the Spark Session + */ + public static SparkSession spark() { + return MLContextUtil.getSparkSessionFromProxy(); + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java index 4cd95d4..c4314bf 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java @@ -47,10 +47,12 @@ import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.CompilerConfig.ConfigType; import org.apache.sysml.conf.ConfigurationManager; @@ -170,12 +172,12 @@ public final class MLContextUtil { * Check that the Spark version is supported. If it isn't supported, throw * an MLContextException. 
* - * @param sc - * SparkContext + * @param spark + * SparkSession * @throws MLContextException * thrown if Spark version isn't supported */ - public static void verifySparkVersionSupported(SparkContext sc) { + public static void verifySparkVersionSupported(SparkSession spark) { String minimumRecommendedSparkVersion = null; try { // If this is being called using the SystemML jar file, @@ -192,7 +194,7 @@ public final class MLContextUtil { throw new MLContextException("Minimum recommended Spark version could not be determined from SystemML jar file manifest or pom.xml"); } } - String sparkVersion = sc.version(); + String sparkVersion = spark.version(); if (!MLContextUtil.isSparkVersionSupported(sparkVersion, minimumRecommendedSparkVersion)) { throw new MLContextException( "Spark " + sparkVersion + " or greater is recommended for this version of SystemML."); @@ -1027,7 +1029,7 @@ public final class MLContextUtil { * @return the Spark Context */ public static SparkContext getSparkContext(MLContext mlContext) { - return mlContext.getSparkContext(); + return mlContext.getSparkSession().sparkContext(); } /** @@ -1038,7 +1040,38 @@ public final class MLContextUtil { * @return the Java Spark Context */ public static JavaSparkContext getJavaSparkContext(MLContext mlContext) { - return new JavaSparkContext(mlContext.getSparkContext()); + return new JavaSparkContext(mlContext.getSparkSession().sparkContext()); + } + + /** + * Obtain the Spark Context from the MLContextProxy + * + * @return the Spark Context + */ + public static SparkContext getSparkContextFromProxy() { + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); + SparkContext sc = getSparkContext(activeMLContext); + return sc; + } + + /** + * Obtain the Java Spark Context from the MLContextProxy + * + * @return the Java Spark Context + */ + public static JavaSparkContext getJavaSparkContextFromProxy() { + MLContext activeMLContext = (MLContext) 
MLContextProxy.getActiveMLContextForAPI(); + JavaSparkContext jsc = getJavaSparkContext(activeMLContext); + return jsc; + } + + /** + * Obtain the Spark Session from the MLContextProxy + * + * @return the Spark Session + */ + public static SparkSession getSparkSessionFromProxy() { + return ((MLContext) MLContextProxy.getActiveMLContextForAPI()).getSparkSession(); } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index c2e3dd0..92946ff 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.MLContextProxy; +import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.lops.Checkpoint; @@ -85,10 +86,11 @@ public class SparkExecutionContext extends ExecutionContext private static final boolean LDEBUG = false; //local debug flag //internal configurations - private static final boolean LAZY_SPARKCTX_CREATION = true; - private static final boolean ASYNCHRONOUS_VAR_DESTROY = true; - private static final boolean FAIR_SCHEDULER_MODE = true; - + private static boolean LAZY_SPARKCTX_CREATION = true; + private static boolean ASYNCHRONOUS_VAR_DESTROY = true; + + public static boolean FAIR_SCHEDULER_MODE = true; + //executor memory and relative fractions as obtained from the spark 
configuration private static SparkClusterConfig _sconf = null; @@ -198,7 +200,7 @@ public class SparkExecutionContext extends ExecutionContext _spctx = new JavaSparkContext(mlCtx.getSparkContext()); } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) { org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj; - _spctx = new JavaSparkContext(mlCtx.getSparkContext()); + _spctx = MLContextUtil.getJavaSparkContext(mlCtx); } } else @@ -267,12 +269,12 @@ public class SparkExecutionContext extends ExecutionContext return conf; } - + /** * Spark instructions should call this for all matrix inputs except broadcast * variables. * - * @param varname varible name + * @param varname variable name * @return JavaPairRDD of MatrixIndexes-MatrixBlocks * @throws DMLRuntimeException if DMLRuntimeException occurs */ http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java index 318a1c8..f3ede65 100644 --- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java @@ -35,6 +35,8 @@ import java.util.HashMap; import org.apache.sysml.lops.Lop; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.SparkSession.Builder; import org.apache.wink.json4j.JSONObject; import org.junit.After; import org.junit.Assert; @@ -1788,4 +1790,28 @@ public abstract class AutomatedTestBase return true; return false; } + + /** + * Create a SystemML-preferred Spark Session. 
+ * + * @param appName the application name + * @param master the master value (ie, "local", etc) + * @return Spark Session + */ + public static SparkSession createSystemMLSparkSession(String appName, String master) { + Builder builder = SparkSession.builder(); + if (appName != null) { + builder.appName(appName); + } + if (master != null) { + builder.master(master); + } + builder.config("spark.driver.maxResultSize", "0"); + if (SparkExecutionContext.FAIR_SCHEDULER_MODE) { + builder.config("spark.scheduler.mode", "FAIR"); + } + builder.config("spark.locality.wait", "5s"); + SparkSession spark = builder.getOrCreate(); + return spark; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java index bf5d33d..a6b6811 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java @@ -27,7 +27,6 @@ import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -37,6 +36,8 @@ import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.test.integration.AutomatedTestBase; 
import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; @@ -55,7 +56,15 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase private final static double sparsity2 = 0.1; private final static double eps=0.0000000001; - + private static SparkSession spark; + private static JavaSparkContext sc; + + @BeforeClass + public static void setUpClass() { + spark = createSystemMLSparkSession("DataFrameMatrixConversionTest", "local"); + sc = new JavaSparkContext(spark.sparkContext()); + } + @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"})); @@ -160,20 +169,11 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase public void testVectorConversionWideSparseUnknown() { testDataFrameConversion(true, cols3, false, true); } - - /** - * - * @param vector - * @param singleColBlock - * @param dense - * @param unknownDims - */ + private void testDataFrameConversion(boolean vector, int cols, boolean dense, boolean unknownDims) { boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform; - SparkExecutionContext sec = null; - try { DMLScript.USE_LOCAL_SPARK_CONFIG = true; @@ -187,17 +187,12 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase int blksz = ConfigurationManager.getBlocksize(); MatrixCharacteristics mc1 = new MatrixCharacteristics(rows, cols, blksz, blksz, mbA.getNonZeros()); MatrixCharacteristics mc2 = unknownDims ? 
new MatrixCharacteristics() : new MatrixCharacteristics(mc1); - - //setup spark context - sec = (SparkExecutionContext) ExecutionContextFactory.createContext(); - JavaSparkContext sc = sec.getSparkContext(); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - + //get binary block input rdd JavaPairRDD<MatrixIndexes,MatrixBlock> in = SparkExecutionContext.toMatrixJavaPairRDD(sc, mbA, blksz, blksz); //matrix - dataframe - matrix conversion - Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, vector); + Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(spark, in, mc1, vector); df = ( rows==rows3 ) ? df.repartition(rows) : df; JavaPairRDD<MatrixIndexes,MatrixBlock> out = RDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, true, vector); @@ -212,9 +207,17 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase throw new RuntimeException(ex); } finally { - sec.close(); DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; DMLScript.rtplatform = oldPlatform; } } + + @AfterClass + public static void tearDownClass() { + // stop underlying spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + spark.stop(); + sc = null; + spark = null; + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java index 09628e5..452c1e1 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java +++ 
b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java @@ -28,7 +28,6 @@ import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -39,6 +38,8 @@ import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; @@ -55,7 +56,20 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase private final static double sparsity2 = 0.1; private final static double eps=0.0000000001; - + private static SparkSession spark; + private static JavaSparkContext sc; + + @BeforeClass + public static void setUpClass() { + spark = SparkSession.builder() + .appName("DataFrameRowFrameConversionTest") + .master("local") + .config("spark.memory.offHeap.enabled", "false") + .config("spark.sql.codegen.wholeStage", "false") + .getOrCreate(); + sc = new JavaSparkContext(spark.sparkContext()); + } + @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"})); @@ -182,20 +196,11 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase public void testRowLongConversionMultiSparseUnknown() { testDataFrameConversion(ValueType.INT, false, false, true); } - - /** - * - * @param vector - * @param singleColBlock - * @param dense - * @param unknownDims - */ + private void 
testDataFrameConversion(ValueType vt, boolean singleColBlock, boolean dense, boolean unknownDims) { boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform; - SparkExecutionContext sec = null; - try { DMLScript.USE_LOCAL_SPARK_CONFIG = true; @@ -212,20 +217,12 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros()); MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1); ValueType[] schema = UtilFunctions.nCopies(cols, vt); - - //setup spark context - sec = (SparkExecutionContext) ExecutionContextFactory.createContext(); - JavaSparkContext sc = sec.getSparkContext(); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - - sc.getConf().set("spark.memory.offHeap.enabled", "false"); - sparkSession.conf().set("spark.sql.codegen.wholeStage", "false"); //get binary block input rdd JavaPairRDD<Long,FrameBlock> in = SparkExecutionContext.toFrameJavaPairRDD(sc, fbA); //frame - dataframe - frame conversion - Dataset<Row> df = FrameRDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, schema); + Dataset<Row> df = FrameRDDConverterUtils.binaryBlockToDataFrame(spark, in, mc1, schema); JavaPairRDD<Long,FrameBlock> out = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, true); //get output frame block @@ -240,9 +237,17 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase throw new RuntimeException(ex); } finally { - sec.close(); DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; DMLScript.rtplatform = oldPlatform; } } + + @AfterClass + public static void tearDownClass() { + // stop underlying spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + spark.stop(); + sc = null; + spark = null; + } } \ No newline at end of file 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java index 4a73376..e68eee9 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java @@ -40,7 +40,6 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; @@ -52,6 +51,8 @@ import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; @@ -73,6 +74,15 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase private final static double sparsity2 = 0.1; private final static double eps=0.0000000001; + private static SparkSession spark; + private static JavaSparkContext sc; + + @BeforeClass + public static void setUpClass() { + spark = createSystemMLSparkSession("DataFrameVectorFrameConversionTest", 
"local"); + sc = new JavaSparkContext(spark.sparkContext()); + } + @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"})); @@ -237,20 +247,11 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase public void testVectorMixed2ConversionSparse() { testDataFrameConversion(schemaMixed2, false, true, false); } - - /** - * - * @param vector - * @param singleColBlock - * @param dense - * @param unknownDims - */ + private void testDataFrameConversion(ValueType[] schema, boolean containsID, boolean dense, boolean unknownDims) { boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform; - SparkExecutionContext sec = null; - try { DMLScript.USE_LOCAL_SPARK_CONFIG = true; @@ -264,14 +265,9 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase int blksz = ConfigurationManager.getBlocksize(); MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros()); MatrixCharacteristics mc2 = unknownDims ? 
new MatrixCharacteristics() : new MatrixCharacteristics(mc1); - - //setup spark context - sec = (SparkExecutionContext) ExecutionContextFactory.createContext(); - JavaSparkContext sc = sec.getSparkContext(); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - + //create input data frame - Dataset<Row> df = createDataFrame(sparkSession, mbA, containsID, schema); + Dataset<Row> df = createDataFrame(spark, mbA, containsID, schema); //dataframe - frame conversion JavaPairRDD<Long,FrameBlock> out = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, containsID); @@ -289,7 +285,6 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase throw new RuntimeException(ex); } finally { - sec.close(); DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig; DMLScript.rtplatform = oldPlatform; } @@ -346,4 +341,13 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase JavaRDD<Row> rowRDD = sc.parallelize(list); return sparkSession.createDataFrame(rowRDD, dfSchema); } + + @AfterClass + public static void tearDownClass() { + // stop underlying spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + spark.stop(); + sc = null; + spark = null; + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java index 92677b8..0f3d3b2 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java +++ 
b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java @@ -24,7 +24,6 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dml; import java.util.ArrayList; import java.util.List; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.linalg.DenseVector; @@ -45,7 +44,6 @@ import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -54,6 +52,8 @@ import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.junit.Test; @@ -75,6 +75,16 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase private final static double sparsity2 = 0.1; private final static double eps=0.0000000001; + private static SparkSession spark; + private static MLContext ml; + + @BeforeClass + public static void setUpClass() { + spark = createSystemMLSparkSession("DataFrameVectorScriptTest", "local"); + ml = new MLContext(spark); + ml.setExplain(true); + } + @Override public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"})); @@ -239,21 +249,10 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase public void testVectorMixed2ConversionSparse() { testDataFrameScriptInput(schemaMixed2, false, true, false); } - - 
/** - * - * @param schema - * @param containsID - * @param dense - * @param unknownDims - */ + private void testDataFrameScriptInput(ValueType[] schema, boolean containsID, boolean dense, boolean unknownDims) { //TODO fix inconsistency ml context vs jmlc register Xf - - JavaSparkContext sc = null; - MLContext ml = null; - try { //generate input data and setup metadata @@ -264,25 +263,15 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase int blksz = ConfigurationManager.getBlocksize(); MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros()); MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1); - - //setup spark context - SparkConf conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextFrameTest").setMaster("local"); - sc = new JavaSparkContext(conf); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - + //create input data frame - Dataset<Row> df = createDataFrame(sparkSession, mbA, containsID, schema); + Dataset<Row> df = createDataFrame(spark, mbA, containsID, schema); // Create full frame metadata, and empty frame metadata FrameMetadata meta = new FrameMetadata(containsID ? 
FrameFormat.DF_WITH_INDEX : FrameFormat.DF, mc2.getRows(), mc2.getCols()); FrameMetadata metaEmpty = new FrameMetadata(); - //create mlcontext - ml = new MLContext(sc); - ml.setExplain(true); - //run scripts and obtain result Script script1 = dml( "Xm = as.matrix(Xf);") @@ -305,15 +294,6 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase ex.printStackTrace(); throw new RuntimeException(ex); } - finally { - // stop spark context to allow single jvm tests (otherwise the - // next test that tries to create a SparkContext would fail) - if( sc != null ) - sc.stop(); - // clear status mlcontext and spark exec context - if( ml != null ) - ml.close(); - } } @SuppressWarnings("resource") @@ -367,4 +347,16 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase JavaRDD<Row> rowRDD = sc.parallelize(list); return sparkSession.createDataFrame(rowRDD, dfSchema); } + + @AfterClass + public static void tearDownClass() { + // stop underlying spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + spark.stop(); + spark = null; + + // clear status mlcontext and spark exec context + ml.close(); + ml = null; + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java index d485c48..c93968c 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java @@ -27,7 +27,6 @@ import java.util.HashMap; import java.util.List; import org.apache.hadoop.io.LongWritable; -import org.apache.spark.SparkConf; import 
org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -42,6 +41,7 @@ import org.apache.sysml.api.mlcontext.FrameFormat; import org.apache.sysml.api.mlcontext.FrameMetadata; import org.apache.sysml.api.mlcontext.FrameSchema; import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.api.mlcontext.ScriptFactory; @@ -49,7 +49,6 @@ import org.apache.sysml.parser.DataExpression; import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.ParseException; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -99,18 +98,15 @@ public class FrameTest extends AutomatedTestBase schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge); } - private static SparkConf conf; + private static SparkSession spark; private static JavaSparkContext sc; private static MLContext ml; @BeforeClass public static void setUpClass() { - if (conf == null) - conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("FrameTest").setMaster("local"); - if (sc == null) - sc = new JavaSparkContext(conf); - ml = new MLContext(sc); + spark = createSystemMLSparkSession("FrameTest", "local"); + ml = new MLContext(spark); + sc = MLContextUtil.getJavaSparkContext(ml); } @Override @@ -237,16 +233,15 @@ public class FrameTest extends AutomatedTestBase if(bFromDataFrame) { //Create DataFrame for input A - SparkSession sparkSession = 
SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schema, false); JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(sc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema); - dfA = sparkSession.createDataFrame(rowRDDA, dfSchemaA); + dfA = spark.createDataFrame(rowRDDA, dfSchemaA); //Create DataFrame for input B StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false); JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(sc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB); - dfB = sparkSession.createDataFrame(rowRDDB, dfSchemaB); + dfB = spark.createDataFrame(rowRDDB, dfSchemaB); } try @@ -386,11 +381,11 @@ public class FrameTest extends AutomatedTestBase @AfterClass public static void tearDownClass() { - // stop spark context to allow single jvm tests (otherwise the + // stop underlying spark context to allow single jvm tests (otherwise the // next test that tries to create a SparkContext would fail) - sc.stop(); + spark.stop(); sc = null; - conf = null; + spark = null; // clear status mlcontext and spark exec context ml.close(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java index eeeb925..f9f5fbd 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java @@ -26,7 +26,6 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import 
org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -34,10 +33,12 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.MatrixEntry; import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.MatrixFormat; @@ -47,7 +48,6 @@ import org.apache.sysml.api.mlcontext.ScriptFactory; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.parser.ParseException; import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -76,7 +76,7 @@ public class GNMFTest extends AutomatedTestBase int numRegisteredInputs; int numRegisteredOutputs; - private static SparkConf conf; + private static SparkSession spark; private static JavaSparkContext sc; private static MLContext ml; @@ -87,12 +87,9 @@ public class GNMFTest extends AutomatedTestBase @BeforeClass public static void setUpClass() { - if (conf == null) - conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("GNMFTest").setMaster("local"); - if (sc == null) - sc = new JavaSparkContext(conf); - ml = new MLContext(sc); + spark = createSystemMLSparkSession("GNMFTest", "local"); + ml = new MLContext(spark); + sc = MLContextUtil.getJavaSparkContext(ml); } @Parameters @@ -267,11 
+264,11 @@ public class GNMFTest extends AutomatedTestBase @AfterClass public static void tearDownClass() { - // stop spark context to allow single jvm tests (otherwise the + // stop underlying spark context to allow single jvm tests (otherwise the // next test that tries to create a SparkContext would fail) - sc.stop(); + spark.stop(); sc = null; - conf = null; + spark = null; // clear status mlcontext and spark exec context ml.close(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java index 6dd74d3..bab719e 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java @@ -27,7 +27,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.rdd.RDD; @@ -43,12 +42,12 @@ import org.apache.sysml.api.mlcontext.FrameMetadata; import org.apache.sysml.api.mlcontext.FrameSchema; import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel; +import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.api.mlcontext.MLResults; import org.apache.sysml.api.mlcontext.MatrixFormat; import org.apache.sysml.api.mlcontext.MatrixMetadata; import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.test.integration.AutomatedTestBase; @@ -73,19 +72,16 @@ public class MLContextFrameTest extends AutomatedTestBase { ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME }; - private static SparkConf conf; + private static SparkSession spark; private static JavaSparkContext sc; private static MLContext ml; private static String CSV_DELIM = ","; @BeforeClass public static void setUpClass() { - if (conf == null) - conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextFrameTest").setMaster("local"); - if (sc == null) - sc = new JavaSparkContext(conf); - ml = new MLContext(sc); + spark = createSystemMLSparkSession("MLContextFrameTest", "local"); + ml = new MLContext(spark); + sc = MLContextUtil.getJavaSparkContext(ml); ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS); } @@ -238,11 +234,10 @@ public class MLContextFrameTest extends AutomatedTestBase { JavaRDD<Row> javaRddRowB = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDB, CSV_DELIM, schemaB); // Create DataFrame - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaA, false); - Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, dfSchemaA); + Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, dfSchemaA); StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false); - Dataset<Row> dataFrameB = sparkSession.createDataFrame(javaRddRowB, dfSchemaB); + Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, dfSchemaB); if (script_type == SCRIPT_TYPE.DML) script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A") .out("C"); @@ -492,18 +487,16 @@ public class MLContextFrameTest 
extends AutomatedTestBase { JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRddStringA, CSV_DELIM, schema); JavaRDD<Row> javaRddRowB = javaRddStringB.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - List<StructField> fieldsA = new ArrayList<StructField>(); fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, true)); fieldsA.add(DataTypes.createStructField("2", DataTypes.DoubleType, true)); StructType schemaA = DataTypes.createStructType(fieldsA); - Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA); + Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); List<StructField> fieldsB = new ArrayList<StructField>(); fieldsB.add(DataTypes.createStructField("1", DataTypes.DoubleType, true)); StructType schemaB = DataTypes.createStructType(fieldsB); - Dataset<Row> dataFrameB = sparkSession.createDataFrame(javaRddRowB, schemaB); + Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, schemaB); String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: true ,recode: [ 1, 2 ]}\");\n" + "C = tA %*% B;\n" + "M = s * C;"; @@ -529,14 +522,12 @@ public class MLContextFrameTest extends AutomatedTestBase { JavaRDD<Row> javaRddRowA = sc. 
parallelize( Arrays.asList(rowsA)); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - List<StructField> fieldsA = new ArrayList<StructField>(); fieldsA.add(DataTypes.createStructField("myID", DataTypes.StringType, true)); fieldsA.add(DataTypes.createStructField("FeatureName", DataTypes.StringType, true)); fieldsA.add(DataTypes.createStructField("FeatureValue", DataTypes.IntegerType, true)); StructType schemaA = DataTypes.createStructType(fieldsA); - Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA); + Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ myID, FeatureName ]}\");"; @@ -571,14 +562,12 @@ public class MLContextFrameTest extends AutomatedTestBase { JavaRDD<Row> javaRddRowA = sc. parallelize( Arrays.asList(rowsA)); - SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - List<StructField> fieldsA = new ArrayList<StructField>(); fieldsA.add(DataTypes.createStructField("featureName", DataTypes.StringType, true)); fieldsA.add(DataTypes.createStructField("featureValue", DataTypes.IntegerType, true)); fieldsA.add(DataTypes.createStructField("id", DataTypes.StringType, true)); StructType schemaA = DataTypes.createStructType(fieldsA); - Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA); + Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ featureName, id ]}\");"; @@ -621,15 +610,13 @@ public class MLContextFrameTest extends AutomatedTestBase { // JavaRDD<Row> javaRddRowA = javaRddStringA.map(new // CommaSeparatedValueStringToRow()); // - // SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); - // // List<StructField> fieldsA = new ArrayList<StructField>(); // 
fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, // true)); // fieldsA.add(DataTypes.createStructField("2", DataTypes.StringType, // true)); // StructType schemaA = DataTypes.createStructType(fieldsA); - // DataFrame dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA); + // DataFrame dataFrameA = spark.createDataFrame(javaRddRowA, schemaA); // // String dmlString = "[tA, tAM] = transformencode (target = A, spec = // \"{ids: true ,recode: [ 1, 2 ]}\");\n"; @@ -664,11 +651,11 @@ public class MLContextFrameTest extends AutomatedTestBase { @AfterClass public static void tearDownClass() { - // stop spark context to allow single jvm tests (otherwise the + // stop underlying spark context to allow single jvm tests (otherwise the // next test that tries to create a SparkContext would fail) - sc.stop(); + spark.stop(); sc = null; - conf = null; + spark = null; // clear status mlcontext and spark exec context ml.close(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java index de46c2a..c418a6f 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java @@ -23,14 +23,12 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; import java.io.File; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import 
org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.Script; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.utils.TestUtils; import org.junit.After; @@ -92,12 +90,10 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase DMLScript.rtplatform = platform; //create mlcontext - SparkConf conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextFrameTest").setMaster("local"); - JavaSparkContext sc = new JavaSparkContext(conf); - MLContext ml = new MLContext(sc); + SparkSession spark = createSystemMLSparkSession("MLContextMultipleScriptsTest", "local"); + MLContext ml = new MLContext(spark); ml.setExplain(true); - + String dml1 = baseDirectory + File.separator + "MultiScript1.dml"; String dml2 = baseDirectory + File.separator + (wRead?"MultiScript2b.dml":"MultiScript2.dml"); String dml3 = baseDirectory + File.separator + (wRead?"MultiScript3b.dml":"MultiScript3.dml"); @@ -119,9 +115,9 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase finally { DMLScript.rtplatform = oldplatform; - // stop spark context to allow single jvm tests (otherwise the + // stop underlying spark context to allow single jvm tests (otherwise the // next test that tries to create a SparkContext would fail) - sc.stop(); + spark.stop(); // clear status mlcontext and spark exec context ml.close(); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java index c9a3dbc..6391919 100644 --- 
a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java @@ -23,14 +23,12 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile; import java.io.File; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.Matrix; import org.apache.sysml.api.mlcontext.Script; -import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.utils.TestUtils; import org.junit.After; @@ -92,12 +90,10 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase DMLScript.rtplatform = platform; //create mlcontext - SparkConf conf = SparkExecutionContext.createSystemMLSparkConf() - .setAppName("MLContextFrameTest").setMaster("local"); - JavaSparkContext sc = new JavaSparkContext(conf); - MLContext ml = new MLContext(sc); + SparkSession spark = createSystemMLSparkSession("MLContextScratchCleanupTest", "local"); + MLContext ml = new MLContext(spark); ml.setExplain(true); - + String dml1 = baseDirectory + File.separator + "ScratchCleanup1.dml"; String dml2 = baseDirectory + File.separator + (wRead?"ScratchCleanup2b.dml":"ScratchCleanup2.dml"); @@ -120,9 +116,9 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase finally { DMLScript.rtplatform = oldplatform; - // stop spark context to allow single jvm tests (otherwise the + // stop underlying spark context to allow single jvm tests (otherwise the // next test that tries to create a SparkContext would fail) - sc.stop(); + spark.stop(); // clear status mlcontext and spark exec context ml.close(); }