[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites

## What changes were proposed in this pull request?
Use SparkSession instead of SQLContext in the Scala/Java TestSuites.
Because this PR is already very large, the Python TestSuites are being migrated in a separate PR.

## How was this patch tested?
Existing tests

Author: Sandeep Singh <[email protected]>

Closes #12907 from techaddict/SPARK-15037.

(cherry picked from commit ed0b4070fb50054b1ecf66ff6c32458a4967dfd3)
Signed-off-by: Andrew Or <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5bf74b44
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5bf74b44
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5bf74b44

Branch: refs/heads/branch-2.0
Commit: 5bf74b44d9efcb8b0f0c3e7d129bc5ba31419551
Parents: 19a9c23
Author: Sandeep Singh <[email protected]>
Authored: Tue May 10 11:17:47 2016 -0700
Committer: Andrew Or <[email protected]>
Committed: Tue May 10 11:17:58 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/JavaPipelineSuite.java  |  27 +-
 .../spark/ml/attribute/JavaAttributeSuite.java  |   2 +-
 .../JavaDecisionTreeClassifierSuite.java        |  23 +-
 .../classification/JavaGBTClassifierSuite.java  |  18 +-
 .../JavaLogisticRegressionSuite.java            |  49 ++--
 ...JavaMultilayerPerceptronClassifierSuite.java |  36 +--
 .../ml/classification/JavaNaiveBayesSuite.java  |  18 +-
 .../ml/classification/JavaOneVsRestSuite.java   |  90 +++---
 .../JavaRandomForestClassifierSuite.java        |  26 +-
 .../spark/ml/clustering/JavaKMeansSuite.java    |  26 +-
 .../spark/ml/feature/JavaBucketizerSuite.java   |  20 +-
 .../apache/spark/ml/feature/JavaDCTSuite.java   |  21 +-
 .../spark/ml/feature/JavaHashingTFSuite.java    |  18 +-
 .../spark/ml/feature/JavaNormalizerSuite.java   |  19 +-
 .../apache/spark/ml/feature/JavaPCASuite.java   |  21 +-
 .../feature/JavaPolynomialExpansionSuite.java   |  19 +-
 .../ml/feature/JavaStandardScalerSuite.java     |  17 +-
 .../ml/feature/JavaStopWordsRemoverSuite.java   |  22 +-
 .../ml/feature/JavaStringIndexerSuite.java      |  26 +-
 .../spark/ml/feature/JavaTokenizerSuite.java    |  21 +-
 .../ml/feature/JavaVectorAssemblerSuite.java    |  31 ++-
 .../ml/feature/JavaVectorIndexerSuite.java      |  18 +-
 .../spark/ml/feature/JavaVectorSlicerSuite.java |  18 +-
 .../spark/ml/feature/JavaWord2VecSuite.java     |  22 +-
 .../apache/spark/ml/param/JavaParamsSuite.java  |  14 +-
 .../apache/spark/ml/param/JavaTestParams.java   |  38 ++-
 .../JavaDecisionTreeRegressorSuite.java         |  18 +-
 .../ml/regression/JavaGBTRegressorSuite.java    |  18 +-
 .../regression/JavaLinearRegressionSuite.java   |  25 +-
 .../JavaRandomForestRegressorSuite.java         |  28 +-
 .../source/libsvm/JavaLibSVMRelationSuite.java  |  18 +-
 .../ml/tuning/JavaCrossValidatorSuite.java      |  18 +-
 .../spark/ml/util/IdentifiableSuite.scala       |   1 +
 .../ml/util/JavaDefaultReadWriteSuite.java      |  21 +-
 .../JavaLogisticRegressionSuite.java            |  35 ++-
 .../classification/JavaNaiveBayesSuite.java     |  25 +-
 .../mllib/classification/JavaSVMSuite.java      |  32 ++-
 .../clustering/JavaBisectingKMeansSuite.java    |  27 +-
 .../clustering/JavaGaussianMixtureSuite.java    |  20 +-
 .../spark/mllib/clustering/JavaKMeansSuite.java |  23 +-
 .../spark/mllib/clustering/JavaLDASuite.java    |  37 +--
 .../clustering/JavaStreamingKMeansSuite.java    |   3 +-
 .../evaluation/JavaRankingMetricsSuite.java     |  21 +-
 .../spark/mllib/feature/JavaTfIdfSuite.java     |  22 +-
 .../spark/mllib/feature/JavaWord2VecSuite.java  |  19 +-
 .../mllib/fpm/JavaAssociationRulesSuite.java    |  23 +-
 .../spark/mllib/fpm/JavaFPGrowthSuite.java      |  29 +-
 .../spark/mllib/fpm/JavaPrefixSpanSuite.java    |  26 +-
 .../spark/mllib/linalg/JavaMatricesSuite.java   | 278 ++++++++++---------
 .../spark/mllib/linalg/JavaVectorsSuite.java    |   7 +-
 .../spark/mllib/random/JavaRandomRDDsSuite.java | 136 ++++-----
 .../mllib/recommendation/JavaALSSuite.java      |  64 +++--
 .../regression/JavaIsotonicRegressionSuite.java |  22 +-
 .../spark/mllib/regression/JavaLassoSuite.java  |  32 ++-
 .../regression/JavaLinearRegressionSuite.java   |  42 +--
 .../regression/JavaRidgeRegressionSuite.java    |  22 +-
 .../spark/mllib/stat/JavaStatisticsSuite.java   |  32 ++-
 .../spark/mllib/tree/JavaDecisionTreeSuite.java |  24 +-
 .../org/apache/spark/ml/PipelineSuite.scala     |   2 +-
 .../ml/classification/ClassifierSuite.scala     |   4 +-
 .../DecisionTreeClassifierSuite.scala           |   4 +-
 .../ml/classification/GBTClassifierSuite.scala  |   6 +-
 .../LogisticRegressionSuite.scala               |  12 +-
 .../MultilayerPerceptronClassifierSuite.scala   |   8 +-
 .../ml/classification/NaiveBayesSuite.scala     |  12 +-
 .../ml/classification/OneVsRestSuite.scala      |   4 +-
 .../RandomForestClassifierSuite.scala           |   4 +-
 .../ml/clustering/BisectingKMeansSuite.scala    |   2 +-
 .../ml/clustering/GaussianMixtureSuite.scala    |   2 +-
 .../spark/ml/clustering/KMeansSuite.scala       |  10 +-
 .../apache/spark/ml/clustering/LDASuite.scala   |  16 +-
 .../BinaryClassificationEvaluatorSuite.scala    |   8 +-
 ...MulticlassClassificationEvaluatorSuite.scala |   2 +-
 .../evaluation/RegressionEvaluatorSuite.scala   |   4 +-
 .../spark/ml/feature/BinarizerSuite.scala       |   8 +-
 .../spark/ml/feature/BucketizerSuite.scala      |   8 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala   |   9 +-
 .../spark/ml/feature/CountVectorizerSuite.scala |  14 +-
 .../org/apache/spark/ml/feature/DCTSuite.scala  |   2 +-
 .../spark/ml/feature/HashingTFSuite.scala       |   4 +-
 .../org/apache/spark/ml/feature/IDFSuite.scala  |   4 +-
 .../spark/ml/feature/InteractionSuite.scala     |  12 +-
 .../spark/ml/feature/MaxAbsScalerSuite.scala    |   2 +-
 .../spark/ml/feature/MinMaxScalerSuite.scala    |   4 +-
 .../apache/spark/ml/feature/NGramSuite.scala    |   8 +-
 .../spark/ml/feature/NormalizerSuite.scala      |   2 +-
 .../spark/ml/feature/OneHotEncoderSuite.scala   |   6 +-
 .../org/apache/spark/ml/feature/PCASuite.scala  |   2 +-
 .../ml/feature/PolynomialExpansionSuite.scala   |   6 +-
 .../apache/spark/ml/feature/RFormulaSuite.scala |  44 +--
 .../spark/ml/feature/SQLTransformerSuite.scala  |   6 +-
 .../spark/ml/feature/StandardScalerSuite.scala  |   8 +-
 .../ml/feature/StopWordsRemoverSuite.scala      |  16 +-
 .../spark/ml/feature/StringIndexerSuite.scala   |  18 +-
 .../spark/ml/feature/TokenizerSuite.scala       |   8 +-
 .../spark/ml/feature/VectorAssemblerSuite.scala |   6 +-
 .../spark/ml/feature/VectorIndexerSuite.scala   |  12 +-
 .../spark/ml/feature/VectorSlicerSuite.scala    |   2 +-
 .../apache/spark/ml/feature/Word2VecSuite.scala |  16 +-
 .../spark/ml/recommendation/ALSSuite.scala      |  21 +-
 .../regression/AFTSurvivalRegressionSuite.scala |   8 +-
 .../regression/DecisionTreeRegressorSuite.scala |   2 +-
 .../spark/ml/regression/GBTRegressorSuite.scala |   6 +-
 .../GeneralizedLinearRegressionSuite.scala      |  32 +--
 .../ml/regression/IsotonicRegressionSuite.scala |   8 +-
 .../ml/regression/LinearRegressionSuite.scala   |  18 +-
 .../regression/RandomForestRegressorSuite.scala |   2 +-
 .../ml/source/libsvm/LibSVMRelationSuite.scala  |  14 +-
 .../tree/impl/GradientBoostedTreesSuite.scala   |   4 +-
 .../apache/spark/ml/tree/impl/TreeTests.scala   |  10 +-
 .../spark/ml/tuning/CrossValidatorSuite.scala   |   4 +-
 .../ml/tuning/TrainValidationSplitSuite.scala   |   4 +-
 .../apache/spark/ml/util/MLTestingUtils.scala   |  28 +-
 .../mllib/util/MLlibTestSparkContext.scala      |  24 +-
 .../apache/spark/sql/JavaApplySchemaSuite.java  |  42 +--
 .../apache/spark/sql/JavaDataFrameSuite.java    |  70 +++--
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  89 +++---
 .../test/org/apache/spark/sql/JavaUDFSuite.java |  27 +-
 .../sources/JavaDatasetAggregatorSuiteBase.java |  20 +-
 .../spark/sql/sources/JavaSaveLoadSuite.java    |  33 ++-
 .../org/apache/spark/sql/CachedTableSuite.scala | 216 +++++++-------
 .../spark/sql/ColumnExpressionSuite.scala       |  12 +-
 .../spark/sql/DataFrameAggregateSuite.scala     |  12 +-
 .../apache/spark/sql/DataFrameJoinSuite.scala   |   2 +-
 .../apache/spark/sql/DataFramePivotSuite.scala  |   6 +-
 .../apache/spark/sql/DataFrameStatSuite.scala   |   8 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  82 +++---
 .../spark/sql/DataFrameTimeWindowingSuite.scala |   8 +-
 .../spark/sql/DataFrameTungstenSuite.scala      |   4 +-
 .../org/apache/spark/sql/DatasetBenchmark.scala |  38 +--
 .../apache/spark/sql/DatasetCacheSuite.scala    |  13 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  26 +-
 .../apache/spark/sql/ExtraStrategiesSuite.scala |   4 +-
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  22 +-
 .../org/apache/spark/sql/ListTablesSuite.scala  |  20 +-
 .../apache/spark/sql/LocalSparkSession.scala    |  68 +++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |   8 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  72 ++---
 .../apache/spark/sql/SerializationSuite.scala   |   4 +-
 .../scala/org/apache/spark/sql/StreamTest.scala |   2 +-
 .../apache/spark/sql/StringFunctionsSuite.scala |   2 +-
 .../scala/org/apache/spark/sql/UDFSuite.scala   |  52 ++--
 .../apache/spark/sql/UserDefinedTypeSuite.scala |  10 +-
 .../execution/ExchangeCoordinatorSuite.scala    |  45 +--
 .../spark/sql/execution/ExchangeSuite.scala     |   4 +-
 .../spark/sql/execution/PlannerSuite.scala      |  36 +--
 .../spark/sql/execution/SQLExecutionSuite.scala |  19 +-
 .../spark/sql/execution/SparkPlanTest.scala     |  27 +-
 .../sql/execution/WholeStageCodegenSuite.scala  |  20 +-
 .../columnar/InMemoryColumnarQuerySuite.scala   |  39 +--
 .../columnar/PartitionBatchPruningSuite.scala   |  19 +-
 .../spark/sql/execution/command/DDLSuite.scala  |  39 +--
 .../datasources/FileSourceStrategySuite.scala   |   6 +-
 .../datasources/HadoopFsRelationSuite.scala     |   4 +-
 .../execution/datasources/csv/CSVSuite.scala    |  80 +++---
 .../json/JsonParsingOptionsSuite.scala          |  48 ++--
 .../execution/datasources/json/JsonSuite.scala  | 124 ++++-----
 .../datasources/json/TestJsonData.scala         |  46 +--
 .../parquet/ParquetAvroCompatibilitySuite.scala |  14 +-
 .../parquet/ParquetCompatibilityTest.scala      |   2 +-
 .../parquet/ParquetFilterSuite.scala            |  34 +--
 .../datasources/parquet/ParquetIOSuite.scala    |  49 ++--
 .../parquet/ParquetInteroperabilitySuite.scala  |   2 +-
 .../ParquetPartitionDiscoverySuite.scala        |  34 +--
 .../datasources/parquet/ParquetQuerySuite.scala | 102 +++----
 .../parquet/ParquetReadBenchmark.scala          |  76 ++---
 .../parquet/ParquetSchemaSuite.scala            |  14 +-
 .../datasources/parquet/ParquetTest.scala       |  10 +-
 .../ParquetThriftCompatibilitySuite.scala       |   4 +-
 .../datasources/parquet/TPCDSBenchmark.scala    |  21 +-
 .../execution/datasources/text/TextSuite.scala  |  20 +-
 .../sql/execution/debug/DebuggingSuite.scala    |   2 +-
 .../execution/joins/BroadcastJoinSuite.scala    |  18 +-
 .../sql/execution/joins/InnerJoinSuite.scala    |  10 +-
 .../sql/execution/joins/OuterJoinSuite.scala    |   8 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |  30 +-
 .../streaming/FileStreamSinkLogSuite.scala      |   4 +-
 .../streaming/HDFSMetadataLogSuite.scala        |  33 +--
 .../streaming/state/StateStoreRDDSuite.scala    |  58 ++--
 .../sql/execution/ui/SQLListenerSuite.scala     |  36 +--
 .../spark/sql/internal/CatalogSuite.scala       |  69 +++--
 .../spark/sql/internal/SQLConfSuite.scala       |  74 ++---
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |  34 +--
 .../apache/spark/sql/jdbc/JDBCWriteSuite.scala  |  46 +--
 .../spark/sql/sources/DataSourceTest.scala      |   2 +-
 .../sql/sources/PartitionedWriteSuite.scala     |  10 +-
 .../streaming/ContinuousQueryManagerSuite.scala |  50 ++--
 .../streaming/DataFrameReaderWriterSuite.scala  |  82 +++---
 .../sql/streaming/FileStreamSinkSuite.scala     |  18 +-
 .../sql/streaming/FileStreamSourceSuite.scala   |   6 +-
 .../spark/sql/streaming/FileStressSuite.scala   |   4 +-
 .../spark/sql/streaming/MemorySinkSuite.scala   |   4 +-
 .../spark/sql/streaming/StreamSuite.scala       |  10 +-
 .../org/apache/spark/sql/test/SQLTestData.scala |  56 ++--
 .../apache/spark/sql/test/SQLTestUtils.scala    |  38 +--
 .../spark/sql/test/SharedSQLContext.scala       |  29 +-
 .../apache/spark/sql/test/TestSQLContext.scala  |  48 ++--
 .../sql/util/ContinuousQueryListenerSuite.scala |  20 +-
 .../spark/sql/util/DataFrameCallbackSuite.scala |  18 +-
 .../spark/sql/hive/test/TestHiveSingleton.scala |   4 +-
 .../sql/catalyst/ExpressionToSQLSuite.scala     |   4 +-
 .../sql/catalyst/LogicalPlanToSQLSuite.scala    |  12 +-
 .../spark/sql/catalyst/SQLBuilderTest.scala     |   2 +-
 .../spark/sql/hive/ErrorPositionSuite.scala     |   8 +-
 .../spark/sql/hive/HiveSparkSubmitSuite.scala   |  16 +-
 .../sql/hive/InsertIntoHiveTableSuite.scala     |   6 +-
 .../sql/hive/MetastoreDataSourcesSuite.scala    |  18 +-
 .../spark/sql/hive/MultiDatabaseSuite.scala     |  88 +++---
 .../hive/ParquetHiveCompatibilitySuite.scala    |  12 +-
 .../hive/execution/AggregationQuerySuite.scala  | 110 ++++----
 .../spark/sql/hive/execution/HiveDDLSuite.scala |  18 +-
 .../spark/sql/hive/execution/HiveUDFSuite.scala |   2 +-
 .../sql/hive/execution/SQLQuerySuite.scala      |  32 +--
 .../spark/sql/hive/execution/SQLViewSuite.scala |  14 +-
 .../hive/execution/SQLWindowFunctionSuite.scala |   2 +-
 .../sql/hive/orc/OrcHadoopFsRelationSuite.scala |   8 +-
 .../spark/sql/hive/orc/OrcQuerySuite.scala      |  38 +--
 .../org/apache/spark/sql/hive/orc/OrcTest.scala |   4 +-
 .../apache/spark/sql/hive/parquetSuites.scala   |   2 +-
 .../spark/sql/sources/BucketedWriteSuite.scala  |   2 +-
 .../CommitFailureTestRelationSuite.scala        |   6 +-
 .../sql/sources/HadoopFsRelationTest.scala      |  66 ++---
 .../sources/ParquetHadoopFsRelationSuite.scala  |  20 +-
 .../SimpleTextHadoopFsRelationSuite.scala       |   4 +-
 224 files changed, 2934 insertions(+), 2611 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 60a4a1d..e0c4363 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,18 +17,18 @@
 
 package org.apache.spark.ml;
 
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 /**
@@ -36,23 +36,26 @@ import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.gene
  */
 public class JavaPipelineSuite {
 
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
   private transient Dataset<Row> dataset;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaPipelineSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPipelineSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
     JavaRDD<LabeledPoint> points =
       jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
-    dataset = jsql.createDataFrame(points, LabeledPoint.class);
+    dataset = spark.createDataFrame(points, LabeledPoint.class);
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -63,10 +66,10 @@ public class JavaPipelineSuite {
     LogisticRegression lr = new LogisticRegression()
       .setFeaturesCol("scaledFeatures");
     Pipeline pipeline = new Pipeline()
-      .setStages(new PipelineStage[] {scaler, lr});
+      .setStages(new PipelineStage[]{scaler, lr});
     PipelineModel model = pipeline.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction 
FROM prediction");
+    Dataset<Row> predictions = spark.sql("SELECT label, probability, 
prediction FROM prediction");
     predictions.collectAsList();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
index b74bbed..15cde0d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java
@@ -17,8 +17,8 @@
 
 package org.apache.spark.ml.attribute;
 
-import org.junit.Test;
 import org.junit.Assert;
+import org.junit.Test;
 
 public class JavaAttributeSuite {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 1f23682..8b89991 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,8 +21,6 @@ import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
-
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaDecisionTreeClassifierSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaDecisionTreeClassifierSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements 
Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 
2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
2);
@@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements 
Serializable {
       .setCacheNodeIds(false)
       .setCheckpointInterval(10)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+    for (String impurity : DecisionTreeClassifier.supportedImpurities()) {
       dt.setImpurity(impurity);
     }
     DecisionTreeClassificationModel model = dt.fit(dataFrame);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 7484105..682371e 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -32,21 +32,27 @@ import 
org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaGBTClassifierSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaGBTClassifierSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 
2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
2);
@@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable {
       .setMaxIter(3)
       .setStepSize(0.1)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String lossType: GBTClassifier.supportedLossTypes()) {
+    for (String lossType : GBTClassifier.supportedLossTypes()) {
       rf.setLossType(lossType);
     }
     GBTClassificationModel model = rf.fit(dataFrame);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index e160a5a..e3ff683 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -27,18 +27,17 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
+import org.apache.spark.sql.SparkSession;
+import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 public class JavaLogisticRegressionSuite implements Serializable {
 
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
   private transient Dataset<Row> dataset;
 
   private transient JavaRDD<LabeledPoint> datasetRDD;
@@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLogisticRegressionSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
+
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     datasetRDD = jsc.parallelize(points, 2);
-    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+    dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
     dataset.registerTempTable("dataset");
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
     Assert.assertEquals(lr.getLabelCol(), "label");
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction 
FROM prediction");
+    Dataset<Row> predictions = spark.sql("SELECT label, probability, 
prediction FROM prediction");
     predictions.collectAsList();
     // Check defaults
     Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
     // Modify model params, and check that the params worked.
     model.setThreshold(1.0);
     model.transform(dataset).registerTempTable("predAllZero");
-    Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM 
predAllZero");
-    for (Row r: predAllZero.collectAsList()) {
+    Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability 
FROM predAllZero");
+    for (Row r : predAllZero.collectAsList()) {
       Assert.assertEquals(0.0, r.getDouble(0), eps);
     }
     // Call transform with params, and check that the params worked.
     model.transform(dataset, model.threshold().w(0.0), 
model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
-    Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM 
predNotAllZero");
+    Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM 
predNotAllZero");
     boolean foundNonZero = false;
-    for (Row r: predNotAllZero.collectAsList()) {
+    for (Row r : predNotAllZero.collectAsList()) {
       if (r.getDouble(0) != 0.0) foundNonZero = true;
     }
     Assert.assertTrue(foundNonZero);
 
     // Call fit() with new params, and check as many params as we can.
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), 
lr.regParam().w(0.1),
-        lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+      lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
     LogisticRegression parent2 = (LogisticRegression) model2.parent();
     Assert.assertEquals(5, parent2.getMaxIter());
     Assert.assertEquals(0.1, parent2.getRegParam(), eps);
@@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
     Assert.assertEquals(2, model.numClasses());
 
     model.transform(dataset).registerTempTable("transformed");
-    Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM 
transformed");
-    for (Row row: trans1.collectAsList()) {
-      Vector raw = (Vector)row.get(0);
-      Vector prob = (Vector)row.get(1);
+    Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM 
transformed");
+    for (Row row : trans1.collectAsList()) {
+      Vector raw = (Vector) row.get(0);
+      Vector prob = (Vector) row.get(1);
       Assert.assertEquals(raw.size(), 2);
       Assert.assertEquals(prob.size(), 2);
       double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
@@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements 
Serializable {
       Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), 
eps);
     }
 
-    Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM 
transformed");
-    for (Row row: trans2.collectAsList()) {
+    Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM 
transformed");
+    for (Row row : trans2.collectAsList()) {
       double pred = row.getDouble(0);
-      Vector prob = (Vector)row.get(1);
-      double probOfPred = prob.apply((int)pred);
+      Vector prob = (Vector) row.get(1);
+      double probOfPred = prob.apply((int) pred);
       for (int i = 0; i < prob.size(); ++i) {
         Assert.assertTrue(probOfPred >= prob.apply(i));
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index bc955f3..b0624ce 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -26,49 +26,49 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
 
-  private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
-    sqlContext = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLogisticRegressionSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
-    sqlContext = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void testMLPC() {
-    Dataset<Row> dataFrame = sqlContext.createDataFrame(
-      jsc.parallelize(Arrays.asList(
-        new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
-        new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
-        new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
-        new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))),
-      LabeledPoint.class);
+    List<LabeledPoint> data = Arrays.asList(
+      new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+      new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
+      new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
+      new LabeledPoint(0.0, Vectors.dense(1.0, 1.0))
+    );
+    Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class);
+
     MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
-      .setLayers(new int[] {2, 5, 2})
+      .setLayers(new int[]{2, 5, 2})
       .setBlockSize(1)
       .setSeed(123L)
       .setMaxIter(100);
     MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
     Dataset<Row> result = model.transform(dataFrame);
     List<Row> predictionAndLabels = result.select("prediction", 
"label").collectAsList();
-    for (Row r: predictionAndLabels) {
+    for (Row r : predictionAndLabels) {
       Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 45101f2..3fc3648 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -26,13 +26,12 @@ import org.junit.Before;
 import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
@@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType;
 
 public class JavaNaiveBayesSuite implements Serializable {
 
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLogisticRegressionSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   public void validatePrediction(Dataset<Row> predictionAndLabels) {
@@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable {
       new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
 
-    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
     NaiveBayes nb = new 
NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
     NaiveBayesModel model = nb.fit(dataset);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 00f4476..486fbbd 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,7 +20,6 @@ package org.apache.spark.ml.classification;
 import java.io.Serializable;
 import java.util.List;
 
-import org.apache.spark.sql.Row;
 import scala.collection.JavaConverters;
 
 import org.junit.After;
@@ -30,56 +29,61 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 
 public class JavaOneVsRestSuite implements Serializable {
 
-    private transient JavaSparkContext jsc;
-    private transient SQLContext jsql;
-    private transient Dataset<Row> dataset;
-    private transient JavaRDD<LabeledPoint> datasetRDD;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
+  private transient Dataset<Row> dataset;
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+
+  @Before
+  public void setUp() {
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLOneVsRestSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
 
-    @Before
-    public void setUp() {
-        jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
-        jsql = new SQLContext(jsc);
-        int nPoints = 3;
+    int nPoints = 3;
 
-        // The following coefficients and xMean/xVariance are computed from 
iris dataset with
-        // lambda=0.2.
-        // As a result, we are drawing samples from probability distribution 
of an actual model.
-        double[] coefficients = {
-                -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-                -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+    // The following coefficients and xMean/xVariance are computed from iris 
dataset with
+    // lambda=0.2.
+    // As a result, we are drawing samples from probability distribution of an 
actual model.
+    double[] coefficients = {
+      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
 
-        double[] xMean = {5.843, 3.057, 3.758, 1.199};
-        double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
-        List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
-            generateMultinomialLogisticInput(coefficients, xMean, xVariance, 
true, nPoints, 42)
-        ).asJava();
-        datasetRDD = jsc.parallelize(points, 2);
-        dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
-    }
+    double[] xMean = {5.843, 3.057, 3.758, 1.199};
+    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+    List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
+      generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, 
nPoints, 42)
+    ).asJava();
+    datasetRDD = jsc.parallelize(points, 2);
+    dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
+  }
 
-    @After
-    public void tearDown() {
-        jsc.stop();
-        jsc = null;
-    }
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
 
-    @Test
-    public void oneVsRestDefaultParams() {
-        OneVsRest ova = new OneVsRest();
-        ova.setClassifier(new LogisticRegression());
-        Assert.assertEquals(ova.getLabelCol() , "label");
-        Assert.assertEquals(ova.getPredictionCol() , "prediction");
-        OneVsRestModel ovaModel = ova.fit(dataset);
-        Dataset<Row> predictions = ovaModel.transform(dataset).select("label", 
"prediction");
-        predictions.collectAsList();
-        Assert.assertEquals(ovaModel.getLabelCol(), "label");
-        Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
-    }
+  @Test
+  public void oneVsRestDefaultParams() {
+    OneVsRest ova = new OneVsRest();
+    ova.setClassifier(new LogisticRegression());
+    Assert.assertEquals(ova.getLabelCol(), "label");
+    Assert.assertEquals(ova.getPredictionCol(), "prediction");
+    OneVsRestModel ovaModel = ova.fit(dataset);
+    Dataset<Row> predictions = ovaModel.transform(dataset).select("label", 
"prediction");
+    predictions.collectAsList();
+    Assert.assertEquals(ovaModel.getLabelCol(), "label");
+    Assert.assertEquals(ovaModel.getPredictionCol(), "prediction");
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 4f40fd6..e385566 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaRandomForestClassifierSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaRandomForestClassifierSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements 
Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 
2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
2);
@@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements 
Serializable {
       .setSeed(1234)
       .setNumTrees(3)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String impurity: RandomForestClassifier.supportedImpurities()) {
+    for (String impurity : RandomForestClassifier.supportedImpurities()) {
       rf.setImpurity(impurity);
     }
-    for (String featureSubsetStrategy: 
RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+    for (String featureSubsetStrategy : 
RandomForestClassifier.supportedFeatureSubsetStrategies()) {
       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
     }
     String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
-    for (String strategy: realStrategies) {
+    for (String strategy : realStrategies) {
       rf.setFeatureSubsetStrategy(strategy);
     }
     String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
-    for (String strategy: integerStrategies) {
+    for (String strategy : integerStrategies) {
       rf.setFeatureSubsetStrategy(strategy);
     }
     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", 
"0"};
-    for (String strategy: invalidStrategies) {
+    for (String strategy : invalidStrategies) {
       try {
         rf.setFeatureSubsetStrategy(strategy);
         Assert.fail("Expected exception to be thrown for invalid strategies");

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index a3fcdb5..3ab09ac 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -21,37 +21,37 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaKMeansSuite implements Serializable {
 
   private transient int k = 5;
-  private transient JavaSparkContext sc;
   private transient Dataset<Row> dataset;
-  private transient SQLContext sql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaKMeansSuite");
-    sql = new SQLContext(sc);
-
-    dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaKMeansSuite")
+      .getOrCreate();
+    dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable {
     Dataset<Row> transformed = model.transform(dataset);
     List<String> columns = Arrays.asList(transformed.columns());
     List<String> expectedColumns = Arrays.asList("features", "prediction");
-    for (String column: expectedColumns) {
+    for (String column : expectedColumns) {
       assertTrue(columns.contains(column));
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 77e3a48..a96b43d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -25,40 +25,40 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 public class JavaBucketizerSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaBucketizerSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaBucketizerSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void bucketizerTest() {
     double[] splits = {-0.5, 0.0, 0.5};
 
-    StructType schema = new StructType(new StructField[] {
+    StructType schema = new StructType(new StructField[]{
       new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
     });
-    Dataset<Row> dataset = jsql.createDataFrame(
+    Dataset<Row> dataset = spark.createDataFrame(
       Arrays.asList(
         RowFactory.create(-0.5),
         RowFactory.create(-0.3),

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index ed1ad4c..06482d8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -21,43 +21,44 @@ import java.util.Arrays;
 import java.util.List;
 
 import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 public class JavaDCTSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaDCTSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaDCTSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void javaCompatibilityTest() {
-    double[] input = new double[] {1D, 2D, 3D, 4D};
-    Dataset<Row> dataset = jsql.createDataFrame(
+    double[] input = new double[]{1D, 2D, 3D, 4D};
+    Dataset<Row> dataset = spark.createDataFrame(
       Arrays.asList(RowFactory.create(Vectors.dense(input))),
       new StructType(new StructField[]{
         new StructField("vec", (new VectorUDT()), false, Metadata.empty())

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 6e2cc7e..0e21d4a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -25,12 +25,11 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
@@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType;
 
 
 public class JavaHashingTFSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaHashingTFSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaHashingTFSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -65,7 +65,7 @@ public class JavaHashingTFSuite {
       new StructField("sentence", DataTypes.StringType, false, 
Metadata.empty())
     });
 
-    Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
+    Dataset<Row> sentenceData = spark.createDataFrame(data, schema);
     Tokenizer tokenizer = new Tokenizer()
       .setInputCol("sentence")
       .setOutputCol("words");

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index 5bbd963..04b2897 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -23,27 +23,30 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaNormalizerSuite {
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaNormalizerSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -54,7 +57,7 @@ public class JavaNormalizerSuite {
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
     ));
-    Dataset<Row> dataFrame = jsql.createDataFrame(points, 
VectorIndexerSuite.FeatureData.class);
+    Dataset<Row> dataFrame = spark.createDataFrame(points, 
VectorIndexerSuite.FeatureData.class);
     Normalizer normalizer = new Normalizer()
       .setInputCol("features")
       .setOutputCol("normFeatures");

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index 1389d17..32f6b43 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -28,31 +28,34 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.linalg.distributed.RowMatrix;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.linalg.distributed.RowMatrix;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaPCASuite implements Serializable {
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaPCASuite");
-    sqlContext = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPCASuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   public static class VectorPair implements Serializable {
@@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable {
       }
     );
 
-    Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, 
VectorPair.class);
+    Dataset<Row> df = spark.createDataFrame(featuresExpected, 
VectorPair.class);
     PCAModel pca = new PCA()
       .setInputCol("features")
       .setOutputCol("pca_features")

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index 6a8bb64..8f72607 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 public class JavaPolynomialExpansionSuite {
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPolynomialExpansionSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
@@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite {
       )
     );
 
-    StructType schema = new StructType(new StructField[] {
+    StructType schema = new StructType(new StructField[]{
       new StructField("features", new VectorUDT(), false, Metadata.empty()),
       new StructField("expected", new VectorUDT(), false, Metadata.empty())
     });
 
-    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
 
     List<Row> pairs = polyExpansion.transform(dataset)
       .select("polyFeatures", "expected")
       .collectAsList();
 
     for (Row r : pairs) {
-      double[] polyFeatures = ((Vector)r.get(0)).toArray();
-      double[] expected = ((Vector)r.get(1)).toArray();
+      double[] polyFeatures = ((Vector) r.get(0)).toArray();
+      double[] expected = ((Vector) r.get(1)).toArray();
       Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index 3f6fc33..c7397bd 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaStandardScalerSuite {
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaStandardScalerSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -54,7 +57,7 @@ public class JavaStandardScalerSuite {
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
       new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
     );
-    Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+    Dataset<Row> dataFrame = spark.createDataFrame(jsc.parallelize(points, 2),
       VectorIndexerSuite.FeatureData.class);
     StandardScaler scaler = new StandardScaler()
       .setInputCol("features")

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index bdcbde5..2b156f3 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -24,11 +24,10 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
@@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType;
 
 public class JavaStopWordsRemoverSuite {
 
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaStopWordsRemoverSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite {
       RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
       RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
     );
-    StructType schema = new StructType(new StructField[] {
+    StructType schema = new StructType(new StructField[]{
       new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), 
false,
-                      Metadata.empty())
+        Metadata.empty())
     });
-    Dataset<Row> dataset = jsql.createDataFrame(data, schema);
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
 
     remover.transform(dataset).collect();
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 431779c..52c0bde 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -25,40 +25,42 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import static org.apache.spark.sql.types.DataTypes.*;
 
 public class JavaStringIndexerSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaStringIndexerSuite");
-    sqlContext = new SQLContext(jsc);
+    SparkConf sparkConf = new SparkConf();
+    sparkConf.setMaster("local");
+    sparkConf.setAppName("JavaStringIndexerSuite");
+
+    spark = SparkSession.builder().config(sparkConf).getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    sqlContext = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void testStringIndexer() {
-    StructType schema = createStructType(new StructField[] {
+    StructType schema = createStructType(new StructField[]{
       createStructField("id", IntegerType, false),
       createStructField("label", StringType, false)
     });
     List<Row> data = Arrays.asList(
       cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
-    Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
+    Dataset<Row> dataset = spark.createDataFrame(data, schema);
 
     StringIndexer indexer = new StringIndexer()
       .setInputCol("label")
@@ -70,7 +72,9 @@ public class JavaStringIndexerSuite {
       output.orderBy("id").select("id", "labelIndex").collectAsList());
   }
 
-  /** An alias for RowFactory.create. */
+  /**
+   * An alias for RowFactory.create.
+   */
   private Row cr(Object... values) {
     return RowFactory.create(values);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 83d16cb..0bac283 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaTokenizerSuite {
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaTokenizerSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -59,10 +62,10 @@ public class JavaTokenizerSuite {
 
 
     JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList(
-      new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
-      new TokenizerTestData("Te,st.  punct", new String[] {"Te,st.", "punct"})
+      new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}),
+      new TokenizerTestData("Te,st.  punct", new String[]{"Te,st.", "punct"})
     ));
-    Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+    Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class);
 
     List<Row> pairs = myRegExTokenizer.transform(dataset)
       .select("tokens", "wantedTokens")

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index e45e198..8774cd0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -24,36 +24,39 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SparkConf;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import static org.apache.spark.sql.types.DataTypes.*;
 
 public class JavaVectorAssemblerSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite");
-    sqlContext = new SQLContext(jsc);
+    SparkConf sparkConf = new SparkConf();
+    sparkConf.setMaster("local");
+    sparkConf.setAppName("JavaVectorAssemblerSuite");
+
+    spark = SparkSession.builder().config(sparkConf).getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void testVectorAssembler() {
-    StructType schema = createStructType(new StructField[] {
+    StructType schema = createStructType(new StructField[]{
       createStructField("id", IntegerType, false),
       createStructField("x", DoubleType, false),
       createStructField("y", new VectorUDT(), false),
@@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite {
     });
     Row row = RowFactory.create(
       0, 0.0, Vectors.dense(1.0, 2.0), "a",
-      Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
-    Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+      Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L);
+    Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema);
     VectorAssembler assembler = new VectorAssembler()
-      .setInputCols(new String[] {"x", "y", "z", "n"})
+      .setInputCols(new String[]{"x", "y", "z", "n"})
       .setOutputCol("features");
     Dataset<Row> output = assembler.transform(dataset);
     Assert.assertEquals(
-      Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
+      Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}),
       output.select("features").first().<Vector>getAs(0));
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index fec6cac..c386c9a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaVectorIndexerSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaVectorIndexerSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaVectorIndexerSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable {
       new FeatureData(Vectors.dense(1.0, 3.0)),
       new FeatureData(Vectors.dense(1.0, 4.0))
     );
-    SQLContext sqlContext = new SQLContext(sc);
-    Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+    Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class);
     VectorIndexer indexer = new VectorIndexer()
       .setInputCol("features")
       .setOutputCol("indexed")

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index e2da111..59ad3c2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -25,7 +25,6 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.attribute.Attribute;
 import org.apache.spark.ml.attribute.AttributeGroup;
 import org.apache.spark.ml.attribute.NumericAttribute;
@@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructType;
 
 
 public class JavaVectorSlicerSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaVectorSlicerSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite {
     );
 
     Dataset<Row> dataset =
-        jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+      spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
 
     VectorSlicer vectorSlicer = new VectorSlicer()
       .setInputCol("userFeatures").setOutputCol("features");

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index 7517b70..392aabc 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -24,28 +24,28 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.*;
 
 public class JavaWord2VecSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
-    sqlContext = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaWord2VecSuite")
+      .getOrCreate();
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
     StructType schema = new StructType(new StructField[]{
       new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
     });
-    Dataset<Row> documentDF = sqlContext.createDataFrame(
+    Dataset<Row> documentDF = spark.createDataFrame(
       Arrays.asList(
         RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
         RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -68,8 +68,8 @@ public class JavaWord2VecSuite {
     Word2VecModel model = word2Vec.fit(documentDF);
     Dataset<Row> result = model.transform(documentDF);
 
-    for (Row r: result.select("result").collectAsList()) {
-      double[] polyFeatures = ((Vector)r.get(0)).toArray();
+    for (Row r : result.select("result").collectAsList()) {
+      double[] polyFeatures = ((Vector) r.get(0)).toArray();
       Assert.assertEquals(polyFeatures.length, 3);
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index fa777f3..a5b5dd4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -25,23 +25,29 @@ import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 
 /**
  * Test Param and related classes in Java
  */
 public class JavaParamsSuite {
 
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaParamsSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaParamsSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -51,7 +57,7 @@ public class JavaParamsSuite {
     testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
     Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
     Assert.assertEquals(testParams.getMyStringParam(), "a");
-    Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0);
+    Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0);
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 06f7fbb..1ad5f7a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams {
   }
 
   private IntParam myIntParam_;
-  public IntParam myIntParam() { return myIntParam_; }
 
-  public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
+  public IntParam myIntParam() {
+    return myIntParam_;
+  }
+
+  public int getMyIntParam() {
+    return (Integer) getOrDefault(myIntParam_);
+  }
 
   public JavaTestParams setMyIntParam(int value) {
     set(myIntParam_, value);
@@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams {
   }
 
   private DoubleParam myDoubleParam_;
-  public DoubleParam myDoubleParam() { return myDoubleParam_; }
 
-  public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
+  public DoubleParam myDoubleParam() {
+    return myDoubleParam_;
+  }
+
+  public double getMyDoubleParam() {
+    return (Double) getOrDefault(myDoubleParam_);
+  }
 
   public JavaTestParams setMyDoubleParam(double value) {
     set(myDoubleParam_, value);
@@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams {
   }
 
   private Param<String> myStringParam_;
-  public Param<String> myStringParam() { return myStringParam_; }
 
-  public String getMyStringParam() { return getOrDefault(myStringParam_); }
+  public Param<String> myStringParam() {
+    return myStringParam_;
+  }
+
+  public String getMyStringParam() {
+    return getOrDefault(myStringParam_);
+  }
 
   public JavaTestParams setMyStringParam(String value) {
     set(myStringParam_, value);
@@ -75,9 +90,14 @@ public class JavaTestParams extends JavaParams {
   }
 
   private DoubleArrayParam myDoubleArrayParam_;
-  public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; }
 
-  public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); }
+  public DoubleArrayParam myDoubleArrayParam() {
+    return myDoubleArrayParam_;
+  }
+
+  public double[] getMyDoubleArrayParam() {
+    return getOrDefault(myDoubleArrayParam_);
+  }
 
   public JavaTestParams setMyDoubleArrayParam(double[] value) {
     set(myDoubleArrayParam_, value);
@@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams {
 
     setDefault(myIntParam(), 1);
     setDefault(myDoubleParam(), 0.5);
-    setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0});
+    setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0});
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index fa3b28e..bbd59a0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaDecisionTreeRegressorSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaDecisionTreeRegressorSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
       .setCacheNodeIds(false)
       .setCheckpointInterval(10)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+    for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
       dt.setImpurity(impurity);
     }
     DecisionTreeRegressionModel model = dt.fit(dataFrame);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 8413ea0..5370b58 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaGBTRegressorSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaGBTRegressorSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
       LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable {
       .setMaxIter(3)
       .setStepSize(0.1)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String lossType: GBTRegressor.supportedLossTypes()) {
+    for (String lossType : GBTRegressor.supportedLossTypes()) {
       rf.setLossType(lossType);
     }
     GBTRegressionModel model = rf.fit(dataFrame);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to