Repository: incubator-systemml Updated Branches: refs/heads/master 201238fd3 -> 80ab57bda
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/80ab57bd/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index e15f8dd..abea5be 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -53,7 +53,7 @@ import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -513,13 +513,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES); @@ -539,13 +539,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES); @@ -565,14 +565,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX); @@ -592,14 +592,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX); @@ -619,14 +619,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX); @@ -646,14 +646,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_DOUBLES_WITH_INDEX); @@ -673,12 +673,12 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX); @@ -698,12 +698,12 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX); @@ -723,11 +723,11 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Vector> javaRddVector = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR); @@ -747,11 +747,11 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Vector> javaRddVector = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR); @@ -1559,13 +1559,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(dataFrame); Script script = dml("avg = avg(M);").in("M", binaryBlockMatrix).out("avg"); @@ -1584,13 +1584,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(dataFrame); Script script = pydml("avg = avg(M)").in("M", binaryBlockMatrix).out("avg"); @@ -1853,13 +1853,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(3, 3, 9); @@ -1879,13 +1879,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); MatrixMetadata mm = new MatrixMetadata(3, 3, 9); @@ -2069,13 +2069,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame); setExpectedStdOut("sum: 27.0"); @@ -2093,13 +2093,13 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame); setExpectedStdOut("sum: 27.0"); @@ -2117,14 +2117,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame); setExpectedStdOut("sum: 27.0"); @@ -2142,14 +2142,14 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<String> javaRddString = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddString.map(new CommaSeparatedValueStringToDoubleArrayRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C2", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C3", DataTypes.DoubleType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = pydml("print('sum: ' + sum(M))").in("M", dataFrame); setExpectedStdOut("sum: 27.0"); @@ -2167,12 +2167,12 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame); setExpectedStdOut("sum: 45.0"); @@ -2190,12 +2190,12 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Tuple2<Double, Vector>> javaRddTuple = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleVectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField(RDDConverterUtils.DF_ID_COLUMN, DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame); setExpectedStdOut("sum: 45.0"); @@ -2213,11 +2213,11 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Vector> javaRddVector = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M));").in("M", dataFrame); setExpectedStdOut("sum: 45.0"); @@ -2235,11 +2235,11 @@ public class MLContextTest extends AutomatedTestBase { JavaRDD<Vector> javaRddVector = sc.parallelize(list); JavaRDD<Row> javaRddRow = javaRddVector.map(new VectorRow()); - SQLContext sqlContext = new SQLContext(sc); + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); List<StructField> fields = new ArrayList<StructField>(); fields.add(DataTypes.createStructField("C1", new VectorUDT(), true)); StructType schema = DataTypes.createStructType(fields); - Dataset<Row> dataFrame = sqlContext.createDataFrame(javaRddRow, schema); + Dataset<Row> dataFrame = sparkSession.createDataFrame(javaRddRow, schema); Script script = dml("print('sum: ' + sum(M))").in("M", dataFrame); setExpectedStdOut("sum: 45.0"); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/80ab57bd/src/test/scala/org/apache/sysml/api/ml/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/src/test/scala/org/apache/sysml/api/ml/LogisticRegressionSuite.scala b/src/test/scala/org/apache/sysml/api/ml/LogisticRegressionSuite.scala index 555d0a2..689bf82 100644 --- a/src/test/scala/org/apache/sysml/api/ml/LogisticRegressionSuite.scala +++ b/src/test/scala/org/apache/sysml/api/ml/LogisticRegressionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql._ import scala.reflect.runtime.universe._ case class LabeledDocument[T:TypeTag](id: Long, text: String, label: Double) @@ -40,9 +41,9 @@ class LogisticRegressionSuite extends FunSuite with WrapperSparkContext with Mat test("run logistic regression with default") { //Make sure system ml home set when run wrapper - val newsqlContext = new org.apache.spark.sql.SQLContext(sc); + val newSparkSession = SparkSession.builder().master("local").appName("TestLocal").getOrCreate(); - import newsqlContext.implicits._ + import newSparkSession.implicits._ val training = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)), LabeledPoint(1.0, Vectors.dense(1.0, 0.4, 2.1)), @@ -62,8 +63,8 @@ class LogisticRegressionSuite extends FunSuite with WrapperSparkContext with Mat test("test logistic regression with mlpipeline"){ //Make sure system ml home set when run wrapper - val newsqlContext = new org.apache.spark.sql.SQLContext(sc); - import newsqlContext.implicits._ + val newSparkSession = SparkSession.builder().master("local").appName("TestLocal").getOrCreate(); + import newSparkSession.implicits._ val training = sc.parallelize(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 2.0), http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/80ab57bd/src/test/scala/org/apache/sysml/api/ml/WrapperSparkContext.scala ---------------------------------------------------------------------- diff --git a/src/test/scala/org/apache/sysml/api/ml/WrapperSparkContext.scala b/src/test/scala/org/apache/sysml/api/ml/WrapperSparkContext.scala index 205a1a9..0bd6f27 100644 --- a/src/test/scala/org/apache/sysml/api/ml/WrapperSparkContext.scala +++ b/src/test/scala/org/apache/sysml/api/ml/WrapperSparkContext.scala @@ -21,27 +21,21 @@ package org.apache.sysml.api.ml import org.scalatest.{ BeforeAndAfterAll, Suite } import org.apache.spark.{ SparkConf, SparkContext } -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession trait WrapperSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ - @transient var sqlContext: SQLContext = _ + @transient var sparkSession: SparkSession = _ override def beforeAll() { super.beforeAll() - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName("MLlibUnitTest") - sc = new SparkContext(conf) - //SQLContext.clearActive() - sqlContext = new SQLContext(sc) - //SQLContext.setActive(sqlContext) + sparkSession = SparkSession.builder().master("local[2]").appName("MLlibUnitTest").getOrCreate(); + sc = sparkSession.sparkContext; } override def afterAll() { try { - sqlContext = null - //SQLContext.clearActive() + sparkSession = null if (sc != null) { sc.stop() }
