http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 9f81751..00c59f0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -30,25 +30,26 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
-  .generateLogisticInputAsList;
-
+import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 public class JavaLinearRegressionSuite implements Serializable {
 
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
   private transient Dataset<Row> dataset;
   private transient JavaRDD<LabeledPoint> datasetRDD;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLinearRegressionSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     datasetRDD = jsc.parallelize(points, 2);
-    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+    dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
     dataset.registerTempTable("dataset");
   }
 
@@ -65,7 +66,7 @@ public class JavaLinearRegressionSuite implements Serializable {
     assertEquals("auto", lr.getSolver());
     LinearRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
+    Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
     predictions.collect();
     // Check defaults
     assertEquals("features", model.getFeaturesCol());
@@ -76,8 +77,8 @@ public class JavaLinearRegressionSuite implements Serializable {
   public void linearRegressionWithSetters() {
     // Set params, train, and check as many params as we can.
     LinearRegression lr = new LinearRegression()
-        .setMaxIter(10)
-        .setRegParam(1.0).setSolver("l-bfgs");
+      .setMaxIter(10)
+      .setRegParam(1.0).setSolver("l-bfgs");
     LinearRegressionModel model = lr.fit(dataset);
     LinearRegression parent = (LinearRegression) model.parent();
     assertEquals(10, parent.getMaxIter());
@@ -85,7 +86,7 @@ public class JavaLinearRegressionSuite implements Serializable {
 
     // Call fit() with new params, and check as many params as we can.
     LinearRegressionModel model2 =
-        lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+      lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
     LinearRegression parent2 = (LinearRegression) model2.parent();
     assertEquals(5, parent2.getMaxIter());
     assertEquals(0.1, parent2.getRegParam(), 0.0);
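
Every suite in this commit follows the same migration: the old two-object setup (a JavaSparkContext built from a master/appName pair plus a standalone SQLContext) collapses into a single SparkSession, and a JavaSparkContext is wrapped around the session's SparkContext wherever the RDD-based APIs are still exercised. A minimal, self-contained sketch of the resulting JUnit skeleton, using only calls that appear in the patch (the class and app names here are illustrative, not from the patch):

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Before;

public class ExampleSuiteSkeleton {

  private transient SparkSession spark;
  private transient JavaSparkContext jsc;

  @Before
  public void setUp() {
    // One SparkSession replaces the JavaSparkContext + SQLContext pair.
    spark = SparkSession.builder()
      .master("local")
      .appName("ExampleSuiteSkeleton")  // illustrative app name
      .getOrCreate();
    // Wrap the session's SparkContext for tests that still build JavaRDDs.
    jsc = new JavaSparkContext(spark.sparkContext());
  }

  @After
  public void tearDown() {
    // Stopping the session also stops the underlying SparkContext.
    spark.stop();
    spark = null;
  }
}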

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 38b895f..fdb41ff 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -28,27 +28,33 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.tree.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaRandomForestRegressorSuite implements Serializable {
 
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaRandomForestRegressorSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -57,7 +63,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> data = sc.parallelize(
+    JavaRDD<LabeledPoint> data = jsc.parallelize(
      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     Map<Integer, Integer> categoricalFeatures = new HashMap<>();
     Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
@@ -75,22 +81,22 @@ public class JavaRandomForestRegressorSuite implements Serializable {
       .setSeed(1234)
       .setNumTrees(3)
       .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
-    for (String impurity: RandomForestRegressor.supportedImpurities()) {
+    for (String impurity : RandomForestRegressor.supportedImpurities()) {
       rf.setImpurity(impurity);
     }
-    for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+    for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) {
       rf.setFeatureSubsetStrategy(featureSubsetStrategy);
     }
     String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
-    for (String strategy: realStrategies) {
+    for (String strategy : realStrategies) {
       rf.setFeatureSubsetStrategy(strategy);
     }
     String[] integerStrategies = {"1", "10", "100", "1000", "10000"};
-    for (String strategy: integerStrategies) {
+    for (String strategy : integerStrategies) {
       rf.setFeatureSubsetStrategy(strategy);
     }
     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", 
"0"};
-    for (String strategy: invalidStrategies) {
+    for (String strategy : invalidStrategies) {
       try {
         rf.setFeatureSubsetStrategy(strategy);
         Assert.fail("Expected exception to be thrown for invalid strategies");

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 1c18b2b..058f2dd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -28,12 +28,11 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.DenseVector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
 
@@ -41,16 +40,17 @@ import org.apache.spark.util.Utils;
  * Test LibSVMRelation in Java.
  */
 public class JavaLibSVMRelationSuite {
-  private transient JavaSparkContext jsc;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
 
   private File tempDir;
   private String path;
 
   @Before
   public void setUp() throws IOException {
-    jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite");
-    sqlContext = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLibSVMRelationSuite")
+      .getOrCreate();
 
     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
     File file = new File(tempDir, "part-00000");
@@ -61,14 +61,14 @@ public class JavaLibSVMRelationSuite {
 
   @After
   public void tearDown() {
-    jsc.stop();
-    jsc = null;
+    spark.stop();
+    spark = null;
     Utils.deleteRecursively(tempDir);
   }
 
   @Test
   public void verifyLibSVMDF() {
-    Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+    Dataset<Row> dataset = spark.read().format("libsvm").option("vectorType", "dense")
       .load(path);
     Assert.assertEquals("label", dataset.columns()[0]);
     Assert.assertEquals("features", dataset.columns()[1]);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 24b0097..8b4d034 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -32,21 +32,25 @@ import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 public class JavaCrossValidatorSuite implements Serializable {
 
+  private transient SparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient SQLContext jsql;
   private transient Dataset<Row> dataset;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
-    jsql = new SQLContext(jsc);
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaCrossValidatorSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
+
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
-    dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
+    dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
   }
 
   @After
@@ -59,8 +63,8 @@ public class JavaCrossValidatorSuite implements Serializable {
   public void crossValidationWithLogisticRegression() {
     LogisticRegression lr = new LogisticRegression();
     ParamMap[] lrParamMaps = new ParamGridBuilder()
-      .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
-      .addGrid(lr.maxIter(), new int[] {0, 10})
+      .addGrid(lr.regParam(), new double[]{0.001, 1000.0})
+      .addGrid(lr.maxIter(), new int[]{0, 10})
       .build();
     BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
     CrossValidator cv = new CrossValidator()

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
index 9283015..878bc66 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -37,4 +37,5 @@ object IdentifiableSuite {
   class Test(override val uid: String) extends Identifiable {
     def this() = this(Identifiable.randomUID("test"))
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 01ff1ea..7151e27 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -27,31 +27,34 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
 public class JavaDefaultReadWriteSuite {
 
   JavaSparkContext jsc = null;
-  SQLContext sqlContext = null;
+  SparkSession spark = null;
   File tempDir = null;
 
   @Before
   public void setUp() {
-    jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
     SQLContext.clearActive();
-    sqlContext = new SQLContext(jsc);
-    SQLContext.setActive(sqlContext);
+    spark = SparkSession.builder()
+      .master("local[2]")
+      .appName("JavaDefaultReadWriteSuite")
+      .getOrCreate();
+    SQLContext.setActive(spark.wrapped());
+
     tempDir = Utils.createTempDir(
       System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
   }
 
   @After
   public void tearDown() {
-    sqlContext = null;
     SQLContext.clearActive();
-    if (jsc != null) {
-      jsc.stop();
-      jsc = null;
+    if (spark != null) {
+      spark.stop();
+      spark = null;
     }
     Utils.deleteRecursively(tempDir);
   }
@@ -70,7 +73,7 @@ public class JavaDefaultReadWriteSuite {
     } catch (IOException e) {
       // expected
     }
-    instance.write().context(sqlContext).overwrite().save(outputPath);
+    instance.write().context(spark.wrapped()).overwrite().save(outputPath);
     MyParams newInstance = MyParams.load(outputPath);
     Assert.assertEquals("UID should match.", instance.uid(), 
newInstance.uid());
     Assert.assertEquals("Params should be preserved.",

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index 862221d..2f10d14 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -27,26 +27,31 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-
 import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaLogisticRegressionSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLogisticRegressionSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
     int numAccurate = 0;
-    for (LabeledPoint point: validationData) {
+    for (LabeledPoint point : validationData) {
       Double prediction = model.predict(point.features());
       if (prediction == point.label()) {
         numAccurate++;
@@ -61,16 +66,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
     double A = 2.0;
     double B = -1.5;
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     List<LabeledPoint> validationData =
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
 
     LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
     lrImpl.setIntercept(true);
     lrImpl.optimizer().setStepSize(1.0)
-                      .setRegParam(1.0)
-                      .setNumIterations(100);
+      .setRegParam(1.0)
+      .setNumIterations(100);
     LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
 
     int numAccurate = validatePrediction(validationData, model);
@@ -83,13 +88,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
     double A = 0.0;
     double B = -2.5;
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
     List<LabeledPoint> validationData =
-        LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+      LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
 
     LogisticRegressionModel model = LogisticRegressionWithSGD.train(
-        testRDD.rdd(), 100, 1.0, 1.0);
+      testRDD.rdd(), 100, 1.0, 1.0);
 
     int numAccurate = validatePrediction(validationData, model);
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 3771c0e..5e212e2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -32,20 +32,26 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
 
 
 public class JavaNaiveBayesSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaNaiveBayesSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   private static final List<LabeledPoint> POINTS = Arrays.asList(
@@ -59,7 +65,7 @@ public class JavaNaiveBayesSuite implements Serializable {
 
   private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
     int correct = 0;
-    for (LabeledPoint p: points) {
+    for (LabeledPoint p : points) {
       if (model.predict(p.features()) == p.label()) {
         correct += 1;
       }
@@ -69,7 +75,7 @@ public class JavaNaiveBayesSuite implements Serializable {
 
   @Test
   public void runUsingConstructor() {
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
 
     NaiveBayes nb = new NaiveBayes().setLambda(1.0);
     NaiveBayesModel model = nb.run(testRDD.rdd());
@@ -80,7 +86,7 @@ public class JavaNaiveBayesSuite implements Serializable {
 
   @Test
   public void runUsingStaticMethods() {
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(POINTS, 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(POINTS, 2).cache();
 
     NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd());
     int numAccurate1 = validatePrediction(POINTS, model1);
@@ -93,13 +99,14 @@ public class JavaNaiveBayesSuite implements Serializable {
 
   @Test
   public void testPredictJavaRDD() {
-    JavaRDD<LabeledPoint> examples = sc.parallelize(POINTS, 2).cache();
+    JavaRDD<LabeledPoint> examples = jsc.parallelize(POINTS, 2).cache();
     NaiveBayesModel model = NaiveBayes.train(examples.rdd());
     JavaRDD<Vector> vectors = examples.map(new Function<LabeledPoint, Vector>() {
       @Override
       public Vector call(LabeledPoint v) throws Exception {
         return v.features();
-      }});
+      }
+    });
     JavaRDD<Double> predictions = model.predict(vectors);
     // Should be able to get the first prediction.
     predictions.first();

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 31b9f3e..2a090c0 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -28,24 +28,30 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaSVMSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaSVMSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaSVMSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
     int numAccurate = 0;
-    for (LabeledPoint point: validationData) {
+    for (LabeledPoint point : validationData) {
       Double prediction = model.predict(point.features());
       if (prediction == point.label()) {
         numAccurate++;
@@ -60,16 +66,16 @@ public class JavaSVMSuite implements Serializable {
     double A = 2.0;
     double[] weights = {-1.5, 1.0};
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
-        weights, nPoints, 42), 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+      weights, nPoints, 42), 2).cache();
     List<LabeledPoint> validationData =
-        SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+      SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
 
     SVMWithSGD svmSGDImpl = new SVMWithSGD();
     svmSGDImpl.setIntercept(true);
     svmSGDImpl.optimizer().setStepSize(1.0)
-                          .setRegParam(1.0)
-                          .setNumIterations(100);
+      .setRegParam(1.0)
+      .setNumIterations(100);
     SVMModel model = svmSGDImpl.run(testRDD.rdd());
 
     int numAccurate = validatePrediction(validationData, model);
@@ -82,10 +88,10 @@ public class JavaSVMSuite implements Serializable {
     double A = 0.0;
     double[] weights = {-1.5, 1.0};
 
-    JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
-        weights, nPoints, 42), 2).cache();
+    JavaRDD<LabeledPoint> testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A,
+      weights, nPoints, 42), 2).cache();
     List<LabeledPoint> validationData =
-        SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+      SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
 
     SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
index a714620..7f29b05 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering;
 import java.io.Serializable;
 
 import com.google.common.collect.Lists;
+
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -29,27 +30,33 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaBisectingKMeansSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", this.getClass().getSimpleName());
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaBisectingKMeansSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void twoDimensionalData() {
-    JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
+    JavaRDD<Vector> points = jsc.parallelize(Lists.newArrayList(
       Vectors.dense(4, -1),
       Vectors.dense(4, 1),
-      Vectors.sparse(2, new int[] {0}, new double[] {1.0})
+      Vectors.sparse(2, new int[]{0}, new double[]{1.0})
     ), 2);
 
     BisectingKMeans bkm = new BisectingKMeans()
@@ -58,15 +65,15 @@ public class JavaBisectingKMeansSuite implements Serializable {
       .setSeed(1L);
     BisectingKMeansModel model = bkm.run(points);
     Assert.assertEquals(3, model.k());
-    Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
-    for (ClusteringTreeNode child: model.root().children()) {
+    Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12);
+    for (ClusteringTreeNode child : model.root().children()) {
       double[] center = child.center().toArray();
       if (center[0] > 2) {
         Assert.assertEquals(2, child.size());
-        Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
+        Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12);
       } else {
         Assert.assertEquals(1, child.size());
-        Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
+        Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12);
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
index 123f78d..20edd08 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -21,29 +21,35 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.junit.Assert.assertEquals;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.junit.Assert.assertEquals;
-
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaGaussianMixtureSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaGaussianMixture");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaGaussianMixture")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -54,7 +60,7 @@ public class JavaGaussianMixtureSuite implements Serializable {
       Vectors.dense(1.0, 4.0, 6.0)
     );
 
-    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    JavaRDD<Vector> data = jsc.parallelize(points, 2);
     GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234)
       .run(data);
     assertEquals(model.gaussians().length, 2);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index ad06676..4e5b87f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -21,28 +21,35 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.junit.Assert.assertEquals;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaKMeansSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaKMeans");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaKMeans")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -55,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
 
     Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
 
-    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    JavaRDD<Vector> data = jsc.parallelize(points, 2);
     KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
     assertEquals(1, model.clusterCenters().length);
     assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -74,7 +81,7 @@ public class JavaKMeansSuite implements Serializable {
 
     Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
 
-    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    JavaRDD<Vector> data = jsc.parallelize(points, 2);
     KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
     assertEquals(1, model.clusterCenters().length);
     assertEquals(expectedCenter, model.clusterCenters()[0]);
@@ -94,7 +101,7 @@ public class JavaKMeansSuite implements Serializable {
       Vectors.dense(1.0, 3.0, 0.0),
       Vectors.dense(1.0, 4.0, 6.0)
     );
-    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    JavaRDD<Vector> data = jsc.parallelize(points, 2);
     KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
     JavaRDD<Integer> predictions = model.predict(data);
     // Should be able to get the first prediction.

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index db19b30..f16585a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -27,37 +27,42 @@ import scala.Tuple3;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
 
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaLDASuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaLDA");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaLDASuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
+
     ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
     for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
-      tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(),
-          LDASuite.tinyCorpus()[i]._2()));
+      tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
+        LDASuite.tinyCorpus()[i]._2()));
     }
-    JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+    JavaRDD<Tuple2<Long, Vector>> tmpCorpus = jsc.parallelize(tinyCorpus, 2);
     corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -95,7 +100,7 @@ public class JavaLDASuite implements Serializable {
       .setMaxIterations(5)
       .setSeed(12345);
 
-    DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus);
+    DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus);
 
     // Check: basic parameters
     LocalLDAModel localModel = model.toLocal();
@@ -124,7 +129,7 @@ public class JavaLDASuite implements Serializable {
         public Boolean call(Tuple2<Long, Vector> tuple2) {
           return Vectors.norm(tuple2._2(), 1.0) != 0.0;
         }
-    });
+      });
     assertEquals(topicDistributions.count(), nonEmptyCorpus.count());
 
     // Check: javaTopTopicsPerDocuments
@@ -179,7 +184,7 @@ public class JavaLDASuite implements Serializable {
 
   @Test
   public void localLdaMethods() {
-    JavaRDD<Tuple2<Long, Vector>> docs = sc.parallelize(toyData, 2);
+    JavaRDD<Tuple2<Long, Vector>> docs = jsc.parallelize(toyData, 2);
     JavaPairRDD<Long, Vector> pairedDocs = JavaPairRDD.fromJavaRDD(docs);
 
     // check: topicDistributions
@@ -191,7 +196,7 @@ public class JavaLDASuite implements Serializable {
     // check: logLikelihood.
     ArrayList<Tuple2<Long, Vector>> docsSingleWord = new ArrayList<>();
     docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0)));
-    JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord));
+    JavaPairRDD<Long, Vector> single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord));
     double logLikelihood = toyModel.logLikelihood(single);
   }
 
@@ -199,7 +204,7 @@ public class JavaLDASuite implements Serializable {
   private static int tinyVocabSize = LDASuite.tinyVocabSize();
   private static Matrix tinyTopics = LDASuite.tinyTopics();
   private static Tuple2<int[], double[]>[] tinyTopicDescription =
-      LDASuite.tinyTopicDescription();
+    LDASuite.tinyTopicDescription();
   private JavaPairRDD<Long, Vector> corpus;
   private LocalLDAModel toyModel = LDASuite.toyModel();
   private ArrayList<Tuple2<Long, Vector>> toyData = LDASuite.javaToyData();

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index 62edbd3..d1d618f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -27,8 +27,6 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import static org.apache.spark.streaming.JavaTestUtils.*;
-
 import org.apache.spark.SparkConf;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
@@ -36,6 +34,7 @@ import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import static org.apache.spark.streaming.JavaTestUtils.*;
 
 public class JavaStreamingKMeansSuite implements Serializable {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index fa4d334..6a096d6 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -31,27 +31,34 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaRankingMetricsSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
   private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
-    predictionAndLabels = sc.parallelize(Arrays.asList(
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPCASuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
+
+    predictionAndLabels = jsc.parallelize(Arrays.asList(
       Tuple2$.MODULE$.apply(
         Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
       Tuple2$.MODULE$.apply(
-          Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
+        Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)),
       Tuple2$.MODULE$.apply(
-          Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
+        Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index 8a320af..de50fb8 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -29,19 +29,25 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaTfIdfSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaTfIdfSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPCASuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -49,7 +55,7 @@ public class JavaTfIdfSuite implements Serializable {
     // The tests are to check Java compatibility.
     HashingTF tf = new HashingTF();
     @SuppressWarnings("unchecked")
-    JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+    JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
       Arrays.asList("this is a sentence".split(" ")),
       Arrays.asList("this is another sentence".split(" ")),
       Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -59,7 +65,7 @@ public class JavaTfIdfSuite implements Serializable {
     JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
     List<Vector> localTfIdfs = tfIdfs.collect();
     int indexOfThis = tf.indexOf("this");
-    for (Vector v: localTfIdfs) {
+    for (Vector v : localTfIdfs) {
       Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
     }
   }
@@ -69,7 +75,7 @@ public class JavaTfIdfSuite implements Serializable {
     // The tests are to check Java compatibility.
     HashingTF tf = new HashingTF();
     @SuppressWarnings("unchecked")
-    JavaRDD<List<String>> documents = sc.parallelize(Arrays.asList(
+    JavaRDD<List<String>> documents = jsc.parallelize(Arrays.asList(
       Arrays.asList("this is a sentence".split(" ")),
       Arrays.asList("this is another sentence".split(" ")),
       Arrays.asList("this is still a sentence".split(" "))), 2);
@@ -79,7 +85,7 @@ public class JavaTfIdfSuite implements Serializable {
     JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
     List<Vector> localTfIdfs = tfIdfs.collect();
     int indexOfThis = tf.indexOf("this");
-    for (Vector v: localTfIdfs) {
+    for (Vector v : localTfIdfs) {
       Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
index e13ed07..64885cc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -21,9 +21,10 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
+import com.google.common.base.Strings;
+
 import scala.Tuple2;
 
-import com.google.common.base.Strings;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -31,19 +32,25 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaWord2VecSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaWord2VecSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPCASuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -53,7 +60,7 @@ public class JavaWord2VecSuite implements Serializable {
     String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10);
     List<String> words = Arrays.asList(sentence.split(" "));
     List<List<String>> localDoc = Arrays.asList(words, words);
-    JavaRDD<List<String>> doc = sc.parallelize(localDoc);
+    JavaRDD<List<String>> doc = jsc.parallelize(localDoc);
     Word2Vec word2vec = new Word2Vec()
       .setVectorSize(10)
       .setSeed(42L);

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index 2bef7a8..fdc19a5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -26,32 +26,37 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaAssociationRulesSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaFPGrowth");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaAssociationRulesSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void runAssociationRules() {
 
     @SuppressWarnings("unchecked")
-    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Arrays.asList(
-      new FreqItemset<String>(new String[] {"a"}, 15L),
-      new FreqItemset<String>(new String[] {"b"}, 35L),
-      new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = jsc.parallelize(Arrays.asList(
+      new FreqItemset<String>(new String[]{"a"}, 15L),
+      new FreqItemset<String>(new String[]{"b"}, 35L),
+      new FreqItemset<String>(new String[]{"a", "b"}, 12L)
     ));
 
     JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
   }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 916fff1..f235251 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -22,34 +22,41 @@ import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
+import static org.junit.Assert.assertEquals;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
 public class JavaFPGrowthSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaFPGrowth");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaFPGrowth")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void runFPGrowth() {
 
     @SuppressWarnings("unchecked")
-    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+    JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
       Arrays.asList("r z h k p".split(" ")),
       Arrays.asList("z y x w v u t s".split(" ")),
       Arrays.asList("s x o n r".split(" ")),
@@ -65,7 +72,7 @@ public class JavaFPGrowthSuite implements Serializable {
     List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
     assertEquals(18, freqItemsets.size());
 
-    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+    for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
       // Test return types.
       List<String> items = itemset.javaItems();
       long freq = itemset.freq();
@@ -76,7 +83,7 @@ public class JavaFPGrowthSuite implements Serializable {
   public void runFPGrowthSaveLoad() {
 
     @SuppressWarnings("unchecked")
-    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+    JavaRDD<List<String>> rdd = jsc.parallelize(Arrays.asList(
       Arrays.asList("r z h k p".split(" ")),
       Arrays.asList("z y x w v u t s".split(" ")),
       Arrays.asList("s x o n r".split(" ")),
@@ -94,15 +101,15 @@ public class JavaFPGrowthSuite implements Serializable {
     String outputPath = tempDir.getPath();
 
     try {
-      model.save(sc.sc(), outputPath);
+      model.save(spark.sparkContext(), outputPath);
       @SuppressWarnings("unchecked")
       FPGrowthModel<String> newModel =
-          (FPGrowthModel<String>) FPGrowthModel.load(sc.sc(), outputPath);
+        (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), outputPath);
       List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
         .collect();
       assertEquals(18, freqItemsets.size());
 
-      for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+      for (FPGrowth.FreqItemset<String> itemset : freqItemsets) {
         // Test return types.
         List<String> items = itemset.javaItems();
         long freq = itemset.freq();

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 8a67793..bf7f1fc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -29,25 +29,31 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
 public class JavaPrefixSpanSuite {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaPrefixSpan");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaPrefixSpan")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
   public void runPrefixSpan() {
-    JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+    JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
       Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
      Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
       Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -61,7 +67,7 @@ public class JavaPrefixSpanSuite {
     List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
     Assert.assertEquals(5, localFreqSeqs.size());
     // Check that each frequent sequence could be materialized.
-    for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+    for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
       List<List<Integer>> seq = freqSeq.javaSequence();
       long freq = freqSeq.freq();
     }
@@ -69,7 +75,7 @@ public class JavaPrefixSpanSuite {
 
   @Test
   public void runPrefixSpanSaveLoad() {
-    JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+    JavaRDD<List<List<Integer>>> sequences = jsc.parallelize(Arrays.asList(
       Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
      Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
       Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
@@ -85,13 +91,13 @@ public class JavaPrefixSpanSuite {
     String outputPath = tempDir.getPath();
 
     try {
-      model.save(sc.sc(), outputPath);
-      PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+      model.save(spark.sparkContext(), outputPath);
+      PrefixSpanModel newModel = PrefixSpanModel.load(spark.sparkContext(), outputPath);
       JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
       List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
       Assert.assertEquals(5, localFreqSeqs.size());
       // Check that each frequent sequence could be materialized.
-      for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+      for (PrefixSpan.FreqSequence<Integer> freqSeq : localFreqSeqs) {
         List<List<Integer>> seq = freqSeq.javaSequence();
         long freq = freqSeq.freq();
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
index 8beea10..92fc578 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java
@@ -17,147 +17,149 @@
 
 package org.apache.spark.mllib.linalg;
 
-import static org.junit.Assert.*;
-import org.junit.Test;
-
 import java.io.Serializable;
 import java.util.Random;
 
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import org.junit.Test;
+
 public class JavaMatricesSuite implements Serializable {
 
-    @Test
-    public void randMatrixConstruction() {
-        Random rng = new Random(24);
-        Matrix r = Matrices.rand(3, 4, rng);
-        rng.setSeed(24);
-        DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
-        assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
-
-        rng.setSeed(24);
-        Matrix rn = Matrices.randn(3, 4, rng);
-        rng.setSeed(24);
-        DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
-        assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
-
-        rng.setSeed(24);
-        Matrix s = Matrices.sprand(3, 4, 0.5, rng);
-        rng.setSeed(24);
-        SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
-        assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
-
-        rng.setSeed(24);
-        Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
-        rng.setSeed(24);
-        SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
-        assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
-    }
-
-    @Test
-    public void identityMatrixConstruction() {
-        Matrix r = Matrices.eye(2);
-        DenseMatrix dr = DenseMatrix.eye(2);
-        SparseMatrix sr = SparseMatrix.speye(2);
-        assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
-        assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
-        assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
-    }
-
-    @Test
-    public void diagonalMatrixConstruction() {
-        Vector v = Vectors.dense(1.0, 0.0, 2.0);
-        Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
-
-        Matrix m = Matrices.diag(v);
-        Matrix sm = Matrices.diag(sv);
-        DenseMatrix d = DenseMatrix.diag(v);
-        DenseMatrix sd = DenseMatrix.diag(sv);
-        SparseMatrix s = SparseMatrix.spdiag(v);
-        SparseMatrix ss = SparseMatrix.spdiag(sv);
-
-        assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
-        assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
-        assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
-        assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
-        assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
-        assertArrayEquals(s.values(), ss.values(), 0.0);
-        assertEquals(2, s.values().length);
-        assertEquals(2, ss.values().length);
-        assertEquals(4, s.colPtrs().length);
-        assertEquals(4, ss.colPtrs().length);
-    }
-
-    @Test
-    public void zerosMatrixConstruction() {
-        Matrix z = Matrices.zeros(2, 2);
-        Matrix one = Matrices.ones(2, 2);
-        DenseMatrix dz = DenseMatrix.zeros(2, 2);
-        DenseMatrix done = DenseMatrix.ones(2, 2);
-
-        assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
-        assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
-        assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
-        assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
-    }
-
-    @Test
-    public void sparseDenseConversion() {
-        int m = 3;
-        int n = 2;
-        double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
-        double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
-        int[] colPtrs = new int[]{0, 2, 4};
-        int[] rowIndices = new int[]{0, 1, 1, 2};
-
-        SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
-        DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
-
-        SparseMatrix spMat2 = deMat1.toSparse();
-        DenseMatrix deMat2 = spMat1.toDense();
-
-        assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
-        assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
-    }
-
-    @Test
-    public void concatenateMatrices() {
-        int m = 3;
-        int n = 2;
-
-        Random rng = new Random(42);
-        SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
-        rng.setSeed(42);
-        DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
-        Matrix deMat2 = Matrices.eye(3);
-        Matrix spMat2 = Matrices.speye(3);
-        Matrix deMat3 = Matrices.eye(2);
-        Matrix spMat3 = Matrices.speye(2);
-
-        Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
-        Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
-        Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
-        Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
-
-        assertEquals(3, deHorz1.numRows());
-        assertEquals(3, deHorz2.numRows());
-        assertEquals(3, deHorz3.numRows());
-        assertEquals(3, spHorz.numRows());
-        assertEquals(5, deHorz1.numCols());
-        assertEquals(5, deHorz2.numCols());
-        assertEquals(5, deHorz3.numCols());
-        assertEquals(5, spHorz.numCols());
-
-        Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
-        Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
-        Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
-        Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
-
-        assertEquals(5, deVert1.numRows());
-        assertEquals(5, deVert2.numRows());
-        assertEquals(5, deVert3.numRows());
-        assertEquals(5, spVert.numRows());
-        assertEquals(2, deVert1.numCols());
-        assertEquals(2, deVert2.numCols());
-        assertEquals(2, deVert3.numCols());
-        assertEquals(2, spVert.numCols());
-    }
+  @Test
+  public void randMatrixConstruction() {
+    Random rng = new Random(24);
+    Matrix r = Matrices.rand(3, 4, rng);
+    rng.setSeed(24);
+    DenseMatrix dr = DenseMatrix.rand(3, 4, rng);
+    assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+
+    rng.setSeed(24);
+    Matrix rn = Matrices.randn(3, 4, rng);
+    rng.setSeed(24);
+    DenseMatrix drn = DenseMatrix.randn(3, 4, rng);
+    assertArrayEquals(rn.toArray(), drn.toArray(), 0.0);
+
+    rng.setSeed(24);
+    Matrix s = Matrices.sprand(3, 4, 0.5, rng);
+    rng.setSeed(24);
+    SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng);
+    assertArrayEquals(s.toArray(), sr.toArray(), 0.0);
+
+    rng.setSeed(24);
+    Matrix sn = Matrices.sprandn(3, 4, 0.5, rng);
+    rng.setSeed(24);
+    SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng);
+    assertArrayEquals(sn.toArray(), srn.toArray(), 0.0);
+  }
+
+  @Test
+  public void identityMatrixConstruction() {
+    Matrix r = Matrices.eye(2);
+    DenseMatrix dr = DenseMatrix.eye(2);
+    SparseMatrix sr = SparseMatrix.speye(2);
+    assertArrayEquals(r.toArray(), dr.toArray(), 0.0);
+    assertArrayEquals(sr.toArray(), dr.toArray(), 0.0);
+    assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0);
+  }
+
+  @Test
+  public void diagonalMatrixConstruction() {
+    Vector v = Vectors.dense(1.0, 0.0, 2.0);
+    Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0});
+
+    Matrix m = Matrices.diag(v);
+    Matrix sm = Matrices.diag(sv);
+    DenseMatrix d = DenseMatrix.diag(v);
+    DenseMatrix sd = DenseMatrix.diag(sv);
+    SparseMatrix s = SparseMatrix.spdiag(v);
+    SparseMatrix ss = SparseMatrix.spdiag(sv);
+
+    assertArrayEquals(m.toArray(), sm.toArray(), 0.0);
+    assertArrayEquals(d.toArray(), sm.toArray(), 0.0);
+    assertArrayEquals(d.toArray(), sd.toArray(), 0.0);
+    assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
+    assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
+    assertArrayEquals(s.values(), ss.values(), 0.0);
+    assertEquals(2, s.values().length);
+    assertEquals(2, ss.values().length);
+    assertEquals(4, s.colPtrs().length);
+    assertEquals(4, ss.colPtrs().length);
+  }
+
+  @Test
+  public void zerosMatrixConstruction() {
+    Matrix z = Matrices.zeros(2, 2);
+    Matrix one = Matrices.ones(2, 2);
+    DenseMatrix dz = DenseMatrix.zeros(2, 2);
+    DenseMatrix done = DenseMatrix.ones(2, 2);
+
+    assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+    assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0);
+    assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+    assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0);
+  }
+
+  @Test
+  public void sparseDenseConversion() {
+    int m = 3;
+    int n = 2;
+    double[] values = new double[]{1.0, 2.0, 4.0, 5.0};
+    double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0};
+    int[] colPtrs = new int[]{0, 2, 4};
+    int[] rowIndices = new int[]{0, 1, 1, 2};
+
+    SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values);
+    DenseMatrix deMat1 = new DenseMatrix(m, n, allValues);
+
+    SparseMatrix spMat2 = deMat1.toSparse();
+    DenseMatrix deMat2 = spMat1.toDense();
+
+    assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0);
+    assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0);
+  }
+
+  @Test
+  public void concatenateMatrices() {
+    int m = 3;
+    int n = 2;
+
+    Random rng = new Random(42);
+    SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng);
+    rng.setSeed(42);
+    DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng);
+    Matrix deMat2 = Matrices.eye(3);
+    Matrix spMat2 = Matrices.speye(3);
+    Matrix deMat3 = Matrices.eye(2);
+    Matrix spMat3 = Matrices.speye(2);
+
+    Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2});
+    Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2});
+    Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
+    Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
+
+    assertEquals(3, deHorz1.numRows());
+    assertEquals(3, deHorz2.numRows());
+    assertEquals(3, deHorz3.numRows());
+    assertEquals(3, spHorz.numRows());
+    assertEquals(5, deHorz1.numCols());
+    assertEquals(5, deHorz2.numCols());
+    assertEquals(5, deHorz3.numCols());
+    assertEquals(5, spHorz.numCols());
+
+    Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
+    Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
+    Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
+    Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
+
+    assertEquals(5, deVert1.numRows());
+    assertEquals(5, deVert2.numRows());
+    assertEquals(5, deVert3.numRows());
+    assertEquals(5, spVert.numRows());
+    assertEquals(2, deVert1.numCols());
+    assertEquals(2, deVert2.numCols());
+    assertEquals(2, deVert3.numCols());
+    assertEquals(2, spVert.numCols());
+  }
 }
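
The sparseDenseConversion test above hinges on two facts worth keeping in mind: toArray() is column-major, and SparseMatrix uses CSC layout, where colPtrs delimits each column's slice of rowIndices/values. A small standalone sketch using the same values as the test:

    import java.util.Arrays;

    import org.apache.spark.mllib.linalg.DenseMatrix;
    import org.apache.spark.mllib.linalg.SparseMatrix;

    public class MatrixConversionExample {
      public static void main(String[] args) {
        // 3x2 dense matrix in column-major order:
        // column 0 = (1, 2, 0), column 1 = (0, 4, 5).
        DenseMatrix dense = new DenseMatrix(3, 2,
          new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0});
        // Same matrix in CSC form: colPtrs {0, 2, 4} says column 0 owns
        // entries 0-1 and column 1 owns entries 2-3 of rowIndices/values.
        SparseMatrix sparse = new SparseMatrix(3, 2, new int[]{0, 2, 4},
          new int[]{0, 1, 1, 2}, new double[]{1.0, 2.0, 4.0, 5.0});
        // Both expand to the same column-major array, and the conversions
        // round-trip, which is exactly what the test asserts.
        System.out.println(Arrays.equals(dense.toArray(), sparse.toArray()));            // true
        System.out.println(Arrays.equals(dense.toSparse().toArray(), dense.toArray()));  // true
        System.out.println(Arrays.equals(sparse.toDense().toArray(), sparse.toArray())); // true
      }
    }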

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 4ba8e54..817b962 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -20,10 +20,11 @@ package org.apache.spark.mllib.linalg;
 import java.io.Serializable;
 import java.util.Arrays;
 
+import static org.junit.Assert.assertArrayEquals;
+
 import scala.Tuple2;
 
 import org.junit.Test;
-import static org.junit.Assert.*;
 
 public class JavaVectorsSuite implements Serializable {
 
@@ -37,8 +38,8 @@ public class JavaVectorsSuite implements Serializable {
   public void sparseArrayConstruction() {
     @SuppressWarnings("unchecked")
     Vector v = Vectors.sparse(3, Arrays.asList(
-        new Tuple2<>(0, 2.0),
-        new Tuple2<>(2, 3.0)));
+      new Tuple2<>(0, 2.0),
+      new Tuple2<>(2, 3.0)));
     assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
   }
 }
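
For reference, the Tuple2-based factory exercised above is equivalent to the index/value-array form used elsewhere in these suites; a short sketch:

    import java.util.Arrays;

    import scala.Tuple2;

    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.Vectors;

    public class VectorConstructionExample {
      public static void main(String[] args) {
        // (index, value) pairs, as in sparseArrayConstruction() above.
        @SuppressWarnings("unchecked")
        Vector fromPairs = Vectors.sparse(3, Arrays.asList(
          new Tuple2<>(0, 2.0),
          new Tuple2<>(2, 3.0)));
        // Parallel index and value arrays describing the same vector.
        Vector fromArrays = Vectors.sparse(3, new int[]{0, 2}, new double[]{2.0, 3.0});
        // Both materialize to {2.0, 0.0, 3.0}.
        System.out.println(Arrays.equals(fromPairs.toArray(), fromArrays.toArray())); // true
      }
    }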

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
index be58691..b449108 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java
@@ -20,29 +20,35 @@ package org.apache.spark.mllib.random;
 import java.io.Serializable;
 import java.util.Arrays;
 
-import org.apache.spark.api.java.JavaRDD;
-import org.junit.Assert;
 import org.junit.After;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
 import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.SparkSession;
 import static org.apache.spark.mllib.random.RandomRDDs.*;
 
 public class JavaRandomRDDsSuite {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaRandomRDDsSuite");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaRandomRDDsSuite")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   @Test
@@ -50,10 +56,10 @@ public class JavaRandomRDDsSuite {
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m);
-    JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p);
-    JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
+    JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
+    JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -63,10 +69,10 @@ public class JavaRandomRDDsSuite {
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = normalJavaRDD(sc, m);
-    JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p);
-    JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
+    JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
+    JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -78,10 +84,10 @@ public class JavaRandomRDDsSuite {
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m);
-    JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p);
-    JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
+    JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
+    JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -92,10 +98,10 @@ public class JavaRandomRDDsSuite {
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m);
-    JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p);
-    JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
+    JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
+    JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -106,10 +112,10 @@ public class JavaRandomRDDsSuite {
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m);
-    JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p);
-    JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
+    JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
+    JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -117,14 +123,14 @@ public class JavaRandomRDDsSuite {
   @Test
   public void testGammaRDD() {
     double shape = 1.0;
-    double scale = 2.0;
+    double jscale = 2.0;
     long m = 1000L;
     int p = 2;
     long seed = 1L;
-    JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m);
-    JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p);
-    JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed);
-    for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
+    JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
+    JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
+    for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
     }
   }
@@ -137,10 +143,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(sc, m, n);
-    JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(sc, m, n, p);
-    JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
+    JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
+    JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -153,10 +159,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = normalJavaVectorRDD(sc, m, n);
-    JavaRDD<Vector> rdd2 = normalJavaVectorRDD(sc, m, n, p);
-    JavaRDD<Vector> rdd3 = normalJavaVectorRDD(sc, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
+    JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
+    JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -171,10 +177,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n);
-    JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p);
-    JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
+    JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
+    JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -188,10 +194,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(sc, mean, m, n);
-    JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p);
-    JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
+    JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
+    JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -205,10 +211,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(sc, mean, m, n);
-    JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p);
-    JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
+    JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
+    JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -218,15 +224,15 @@ public class JavaRandomRDDsSuite {
   @SuppressWarnings("unchecked")
   public void testGammaVectorRDD() {
     double shape = 1.0;
-    double scale = 2.0;
+    double jscale = 2.0;
     long m = 100L;
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n);
-    JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p);
-    JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
+    JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
+    JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -238,10 +244,10 @@ public class JavaRandomRDDsSuite {
     long seed = 1L;
     int numPartitions = 0;
     StringGenerator gen = new StringGenerator();
-    JavaRDD<String> rdd1 = randomJavaRDD(sc, gen, size);
-    JavaRDD<String> rdd2 = randomJavaRDD(sc, gen, size, numPartitions);
-    JavaRDD<String> rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed);
-    for (JavaRDD<String> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
+    JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
+    JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
+    for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(size, rdd.count());
       Assert.assertEquals(2, rdd.first().length());
     }
@@ -255,10 +261,10 @@ public class JavaRandomRDDsSuite {
     int n = 10;
     int p = 2;
     long seed = 1L;
-    JavaRDD<Vector> rdd1 = randomJavaVectorRDD(sc, generator, m, n);
-    JavaRDD<Vector> rdd2 = randomJavaVectorRDD(sc, generator, m, n, p);
-    JavaRDD<Vector> rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed);
-    for (JavaRDD<Vector> rdd: Arrays.asList(rdd1, rdd2, rdd3)) {
+    JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
+    JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
+    JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
+    for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
       Assert.assertEquals(m, rdd.count());
       Assert.assertEquals(n, rdd.first().size());
     }
@@ -271,10 +277,12 @@ class StringGenerator implements RandomDataGenerator<String>, Serializable {
   public String nextValue() {
     return "42";
   }
+
   @Override
   public StringGenerator copy() {
     return new StringGenerator();
   }
+
   @Override
   public void setSeed(long seed) {
   }
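
StringGenerator above doubles as a template for plugging any custom generator into RandomRDDs; a hedged sketch under the same interface (the class name and constant value are illustrative):

    import java.io.Serializable;

    import org.apache.spark.mllib.random.RandomDataGenerator;

    class ConstantGenerator implements RandomDataGenerator<String>, Serializable {
      @Override
      public String nextValue() {
        return "42";
      }

      @Override
      public ConstantGenerator copy() {
        // Each partition draws from its own copy of the generator.
        return new ConstantGenerator();
      }

      @Override
      public void setSeed(long seed) {
        // Deterministic generator, so the seed is ignored.
      }
    }

    // Usage, assuming jsc from setUp() above:
    //   JavaRDD<String> rdd = RandomRDDs.randomJavaRDD(jsc, new ConstantGenerator(), 1000L);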

http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index d0bf7f5..aa78405 100644
--- 
a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -32,40 +32,46 @@ import org.junit.Test;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 
 public class JavaALSSuite implements Serializable {
-  private transient JavaSparkContext sc;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;
 
   @Before
   public void setUp() {
-    sc = new JavaSparkContext("local", "JavaALS");
+    spark = SparkSession.builder()
+      .master("local")
+      .appName("JavaALS")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
+    spark.stop();
+    spark = null;
   }
 
   private void validatePrediction(
-      MatrixFactorizationModel model,
-      int users,
-      int products,
-      double[] trueRatings,
-      double matchThreshold,
-      boolean implicitPrefs,
-      double[] truePrefs) {
+    MatrixFactorizationModel model,
+    int users,
+    int products,
+    double[] trueRatings,
+    double matchThreshold,
+    boolean implicitPrefs,
+    double[] truePrefs) {
     List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products);
-    for (int u=0; u < users; ++u) {
-      for (int p=0; p < products; ++p) {
+    for (int u = 0; u < users; ++u) {
+      for (int p = 0; p < products; ++p) {
         localUsersProducts.add(new Tuple2<>(u, p));
       }
     }
-    JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+    JavaPairRDD<Integer, Integer> usersProducts = jsc.parallelizePairs(localUsersProducts);
     List<Rating> predictedRatings = model.predict(usersProducts).collect();
     Assert.assertEquals(users * products, predictedRatings.size());
     if (!implicitPrefs) {
-      for (Rating r: predictedRatings) {
+      for (Rating r : predictedRatings) {
         double prediction = r.rating();
         double correct = trueRatings[r.product() * users + r.user()];
         Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
@@ -76,7 +82,7 @@ public class JavaALSSuite implements Serializable {
       // (ref Mahout's implicit ALS tests)
       double sqErr = 0.0;
       double denom = 0.0;
-      for (Rating r: predictedRatings) {
+      for (Rating r : predictedRatings) {
         double prediction = r.rating();
         double truePref = truePrefs[r.product() * users + r.user()];
         double confidence = 1.0 +
@@ -98,9 +104,9 @@ public class JavaALSSuite implements Serializable {
     int users = 50;
     int products = 100;
     Tuple3<List<Rating>, double[], double[]> testData =
-        ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+      ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
 
-    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    JavaRDD<Rating> data = jsc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
     validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
   }
@@ -112,9 +118,9 @@ public class JavaALSSuite implements Serializable {
     int users = 100;
     int products = 200;
     Tuple3<List<Rating>, double[], double[]> testData =
-        ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
+      ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
 
-    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    JavaRDD<Rating> data = jsc.parallelize(testData._1());
 
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
@@ -129,9 +135,9 @@ public class JavaALSSuite implements Serializable {
     int users = 80;
     int products = 160;
     Tuple3<List<Rating>, double[], double[]> testData =
-        ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+      ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
 
-    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    JavaRDD<Rating> data = jsc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
     validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
@@ -143,9 +149,9 @@ public class JavaALSSuite implements Serializable {
     int users = 100;
     int products = 200;
     Tuple3<List<Rating>, double[], double[]> testData =
-        ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
+      ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
 
-    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    JavaRDD<Rating> data = jsc.parallelize(testData._1());
 
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
@@ -161,9 +167,9 @@ public class JavaALSSuite implements Serializable {
     int users = 80;
     int products = 160;
     Tuple3<List<Rating>, double[], double[]> testData =
-        ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
+      ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
 
-    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    JavaRDD<Rating> data = jsc.parallelize(testData._1());
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
       .setImplicitPrefs(true)
@@ -179,8 +185,8 @@ public class JavaALSSuite implements Serializable {
     int users = 200;
     int products = 50;
     List<Rating> testData = ALSSuite.generateRatingsAsJava(
-        users, products, features, 0.7, true, false)._1();
-    JavaRDD<Rating> data = sc.parallelize(testData);
+      users, products, features, 0.7, true, false)._1();
+    JavaRDD<Rating> data = jsc.parallelize(testData);
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
       .setImplicitPrefs(true)
@@ -193,7 +199,7 @@ public class JavaALSSuite implements Serializable {
   private static void validateRecommendations(Rating[] recommendations, int howMany) {
     Assert.assertEquals(howMany, recommendations.length);
     for (int i = 1; i < recommendations.length; i++) {
-      Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating());
+      Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
     }
     Assert.assertTrue(recommendations[0].rating() > 0.7);
   }
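
To keep the train/predict flow of these tests in view, a minimal end-to-end sketch (the ratings, rank, and iteration count are placeholder values, and jsc is assumed to come from setUp() above):

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.mllib.recommendation.ALS;
    import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
    import org.apache.spark.mllib.recommendation.Rating;

    // Inside a test body:
    // Rating(user, product, rating) triples for a tiny explicit-feedback set.
    List<Rating> ratings = Arrays.asList(
      new Rating(0, 0, 5.0),
      new Rating(0, 1, 1.0),
      new Rating(1, 0, 4.0),
      new Rating(1, 1, 2.0));
    JavaRDD<Rating> data = jsc.parallelize(ratings);

    // Train with rank 2 for 10 iterations, as in testALS() above but smaller.
    MatrixFactorizationModel model = ALS.train(data.rdd(), 2, 10);
    double prediction = model.predict(1, 1);  // predicted rating of product 1 by user 1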

