[SPARK-15296][MLLIB] Refactor All Java Tests that use SparkSession

## What changes were proposed in this pull request?
Refactor all Java tests that use SparkSession to extend SharedSparkSession.

## How was this patch tested?
Existing tests.

Author: Sandeep Singh <sand...@techaddict.me>

Closes #13101 from techaddict/SPARK-15296.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01cf649c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01cf649c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01cf649c

Branch: refs/heads/master
Commit: 01cf649c4f96f64fb4bd09e0e1811cabcc5ead2e
Parents: 16ba71a
Author: Sandeep Singh <sand...@techaddict.me>
Authored: Thu May 19 20:38:44 2016 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu May 19 20:38:44 2016 -0700

----------------------------------------------------------------------
 .../examples/ml/JavaGaussianMixtureExample.java |  2 +-
 .../org/apache/spark/SharedSparkSession.java    | 48 ++++++++++++++++++++
 .../org/apache/spark/ml/JavaPipelineSuite.java  | 27 +++--------
 .../JavaDecisionTreeClassifierSuite.java        | 27 +----------
 .../classification/JavaGBTClassifierSuite.java  | 28 +-----------
 .../JavaLogisticRegressionSuite.java            | 28 +++---------
 ...JavaMultilayerPerceptronClassifierSuite.java | 23 +---------
 .../ml/classification/JavaNaiveBayesSuite.java  | 23 +---------
 .../ml/classification/JavaOneVsRestSuite.java   | 30 +++---------
 .../JavaRandomForestClassifierSuite.java        | 28 +-----------
 .../spark/ml/clustering/JavaKMeansSuite.java    | 27 +++--------
 .../spark/ml/feature/JavaBucketizerSuite.java   | 21 +--------
 .../apache/spark/ml/feature/JavaDCTSuite.java   | 21 +--------
 .../spark/ml/feature/JavaHashingTFSuite.java    | 21 +--------
 .../spark/ml/feature/JavaNormalizerSuite.java   | 24 +---------
 .../apache/spark/ml/feature/JavaPCASuite.java   | 26 ++---------
 .../feature/JavaPolynomialExpansionSuite.java   | 24 +---------
 .../ml/feature/JavaStandardScalerSuite.java     | 24 +---------
 .../ml/feature/JavaStopWordsRemoverSuite.java   | 22 +--------
 .../ml/feature/JavaStringIndexerSuite.java      | 26 ++---------
 .../spark/ml/feature/JavaTokenizerSuite.java    | 24 +---------
 .../ml/feature/JavaVectorAssemblerSuite.java    | 26 ++---------
 .../ml/feature/JavaVectorIndexerSuite.java      | 25 +---------
 .../spark/ml/feature/JavaVectorSlicerSuite.java | 21 +--------
 .../spark/ml/feature/JavaWord2VecSuite.java     | 21 +--------
 .../apache/spark/ml/param/JavaParamsSuite.java  | 23 ----------
 .../JavaDecisionTreeRegressorSuite.java         | 26 +----------
 .../ml/regression/JavaGBTRegressorSuite.java    | 26 +----------
 .../regression/JavaLinearRegressionSuite.java   | 28 +++---------
 .../JavaRandomForestRegressorSuite.java         | 26 +----------
 .../source/libsvm/JavaLibSVMRelationSuite.java  | 20 +++-----
 .../ml/tuning/JavaCrossValidatorSuite.java      | 33 ++++----------
 .../ml/util/JavaDefaultReadWriteSuite.java      | 31 +++----------
 .../JavaLogisticRegressionSuite.java            | 25 +---------
 .../classification/JavaNaiveBayesSuite.java     | 25 +---------
 .../mllib/classification/JavaSVMSuite.java      | 25 +---------
 .../JavaStreamingLogisticRegressionSuite.java   |  3 +-
 .../clustering/JavaBisectingKMeansSuite.java    | 26 +----------
 .../clustering/JavaGaussianMixtureSuite.java    | 25 +---------
 .../spark/mllib/clustering/JavaKMeansSuite.java | 25 +---------
 .../spark/mllib/clustering/JavaLDASuite.java    | 29 +++---------
 .../clustering/JavaStreamingKMeansSuite.java    |  3 +-
 .../evaluation/JavaRankingMetricsSuite.java     | 28 +++---------
 .../spark/mllib/feature/JavaTfIdfSuite.java     | 25 +---------
 .../spark/mllib/feature/JavaWord2VecSuite.java  | 25 +---------
 .../mllib/fpm/JavaAssociationRulesSuite.java    | 25 +---------
 .../spark/mllib/fpm/JavaFPGrowthSuite.java      | 25 +---------
 .../spark/mllib/fpm/JavaPrefixSpanSuite.java    | 24 +---------
 .../spark/mllib/linalg/JavaMatricesSuite.java   |  3 +-
 .../spark/mllib/linalg/JavaVectorsSuite.java    |  3 +-
 .../spark/mllib/random/JavaRandomRDDsSuite.java | 24 +---------
 .../mllib/recommendation/JavaALSSuite.java      | 25 +---------
 .../regression/JavaIsotonicRegressionSuite.java | 25 +---------
 .../spark/mllib/regression/JavaLassoSuite.java  | 25 +---------
 .../regression/JavaLinearRegressionSuite.java   | 25 +---------
 .../regression/JavaRidgeRegressionSuite.java    | 25 +---------
 .../JavaStreamingLinearRegressionSuite.java     |  3 +-
 .../spark/mllib/stat/JavaStatisticsSuite.java   |  3 +-
 .../spark/mllib/tree/JavaDecisionTreeSuite.java | 26 +----------
 59 files changed, 207 insertions(+), 1148 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
----------------------------------------------------------------------
diff --git 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
 
b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
index 79b9909..526bed9 100644
--- 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
+++ 
b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java
@@ -37,7 +37,7 @@ public class JavaGaussianMixtureExample {
 
   public static void main(String[] args) {
 
-    // Creates a SparkSession 
+    // Creates a SparkSession
     SparkSession spark = SparkSession
             .builder()
             .appName("JavaGaussianMixtureExample")

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java 
b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
new file mode 100644
index 0000000..4377987
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.IOException;
+import java.io.Serializable;
+
+import org.junit.After;
+import org.junit.Before;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+public abstract class SharedSparkSession implements Serializable {
+
+  protected transient SparkSession spark;
+  protected transient JavaSparkContext jsc;
+
+  @Before
+  public void setUp() throws IOException {
+    spark = SparkSession.builder()
+      .master("local[2]")
+      .appName(getClass().getSimpleName())
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index a81a36d..9b20900 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,47 +17,34 @@
 
 package org.apache.spark.ml;
 
-import org.junit.After;
-import org.junit.Before;
+import java.io.IOException;
+
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.LogisticRegression;
 import static 
org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.feature.StandardScaler;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
 /**
  * Test Pipeline construction and fitting in Java.
  */
-public class JavaPipelineSuite {
+public class JavaPipelineSuite extends SharedSparkSession {
 
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
   private transient Dataset<Row> dataset;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPipelineSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     JavaRDD<LabeledPoint> points =
       jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
     dataset = spark.createDataFrame(points, LabeledPoint.class);
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void pipeline() {
     StandardScaler scaler = new StandardScaler()

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index c76a194..5aba4e8 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,42 +17,19 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaDecisionTreeClassifierSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaDecisionTreeClassifierSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaDecisionTreeClassifierSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 4648926..74bb46b 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -17,43 +17,19 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-
-public class JavaGBTClassifierSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaGBTClassifierSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaGBTClassifierSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index b8da04c..0041021 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -17,52 +17,36 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import static 
org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaLogisticRegressionSuite implements Serializable {
+public class JavaLogisticRegressionSuite extends SharedSparkSession {
 
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
   private transient Dataset<Row> dataset;
 
   private transient JavaRDD<LabeledPoint> datasetRDD;
   private double eps = 1e-5;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLogisticRegressionSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     datasetRDD = jsc.parallelize(points, 2);
     dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
     dataset.createOrReplaceTempView("dataset");
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void logisticRegressionDefaultParams() {
     LogisticRegression lr = new LogisticRegression();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index 48edbc8..6d0604d 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -17,38 +17,19 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
-
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLogisticRegressionSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaMultilayerPerceptronClassifierSuite extends 
SharedSparkSession {
 
   @Test
   public void testMLPC() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 7879098..c2a9e7b 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -17,43 +17,24 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-public class JavaNaiveBayesSuite implements Serializable {
-
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLogisticRegressionSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaNaiveBayesSuite extends SharedSparkSession {
 
   public void validatePrediction(Dataset<Row> predictionAndLabels) {
     for (Row r : predictionAndLabels.collectAsList()) {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index 58bc5a4..6194167 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -17,39 +17,29 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.List;
 
 import scala.collection.JavaConverters;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import static 
org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
+import static 
org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 
-public class JavaOneVsRestSuite implements Serializable {
+public class JavaOneVsRestSuite extends SharedSparkSession {
 
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
   private transient Dataset<Row> dataset;
   private transient JavaRDD<LabeledPoint> datasetRDD;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLOneVsRestSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     int nPoints = 3;
 
     // The following coefficients and xMean/xVariance are computed from iris 
dataset with
@@ -68,12 +58,6 @@ public class JavaOneVsRestSuite implements Serializable {
     dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void oneVsRestDefaultParams() {
     OneVsRest ova = new OneVsRest();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 1ed20b1..dd98513 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -17,45 +17,21 @@
 
 package org.apache.spark.ml.classification;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-
-public class JavaRandomForestClassifierSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaRandomForestClassifierSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaRandomForestClassifierSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index 9d07170..1be6f96 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -17,43 +17,30 @@
 
 package org.apache.spark.ml.clustering;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 
+import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaKMeansSuite implements Serializable {
+public class JavaKMeansSuite extends SharedSparkSession {
 
   private transient int k = 5;
   private transient Dataset<Row> dataset;
-  private transient SparkSession spark;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaKMeansSuite")
-      .getOrCreate();
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k);
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void fitAndTransform() {
     KMeans kmeans = new KMeans().setK(k).setSeed(1);

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index a96b43d..8763938 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -20,36 +20,19 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-public class JavaBucketizerSuite {
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaBucketizerSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaBucketizerSuite extends SharedSparkSession {
 
   @Test
   public void bucketizerTest() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 9d8c09b..b7956b6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -22,38 +22,21 @@ import java.util.List;
 
 import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-public class JavaDCTSuite {
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaDCTSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaDCTSuite extends SharedSparkSession {
 
   @Test
   public void javaCompatibilityTest() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 3c37441..57696d0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -20,38 +20,21 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 
-public class JavaHashingTFSuite {
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaHashingTFSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaHashingTFSuite extends SharedSparkSession {
 
   @Test
   public void hashingTF() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index b3e213a..6f877b5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -19,35 +19,15 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaNormalizerSuite {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaNormalizerSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaNormalizerSuite extends SharedSparkSession {
 
   @Test
   public void normalizer() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index a4bce22..ac479c0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -23,13 +23,11 @@ import java.util.List;
 
 import scala.Tuple2;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.Vectors;
@@ -37,26 +35,8 @@ import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.distributed.RowMatrix;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaPCASuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPCASuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaPCASuite extends SharedSparkSession {
 
   public static class VectorPair implements Serializable {
     private Vector features = Vectors.dense(0.0);
@@ -95,7 +75,7 @@ public class JavaPCASuite implements Serializable {
               }
             }
     ).rdd());
-    
+
     Matrix pc = mat.computePrincipalComponents(3);
 
     mat.multiply(pc).rows().toJavaRDD();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index a28f73f..df5d34f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -20,41 +20,21 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-public class JavaPolynomialExpansionSuite {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPolynomialExpansionSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    jsc.stop();
-    jsc = null;
-  }
+public class JavaPolynomialExpansionSuite extends SharedSparkSession {
 
   @Test
   public void polynomialExpansionTest() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index 8415fdb..dbc0b1d 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -20,34 +20,14 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaStandardScalerSuite {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaStandardScalerSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaStandardScalerSuite extends SharedSparkSession {
 
   @Test
   public void standardScaler() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 2b156f3..6480b57 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -20,37 +20,19 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 
-public class JavaStopWordsRemoverSuite {
-
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaStopWordsRemoverSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaStopWordsRemoverSuite extends SharedSparkSession {
 
   @Test
   public void javaCompatibilityTest() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 52c0bde..c1928a2 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -20,37 +20,19 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
+import static org.apache.spark.sql.types.DataTypes.*;
+
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkConf;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import static org.apache.spark.sql.types.DataTypes.*;
-
-public class JavaStringIndexerSuite {
-  private transient SparkSession spark;
 
-  @Before
-  public void setUp() {
-    SparkConf sparkConf = new SparkConf();
-    sparkConf.setMaster("local");
-    sparkConf.setAppName("JavaStringIndexerSuite");
-
-    spark = SparkSession.builder().config(sparkConf).getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaStringIndexerSuite extends SharedSparkSession {
 
   @Test
   public void testStringIndexer() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index 0bac283..27550a3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -20,35 +20,15 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaTokenizerSuite {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaTokenizerSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaTokenizerSuite extends SharedSparkSession {
 
   @Test
   public void regexTokenizer() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index fedaa77..583652b 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -19,40 +19,22 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
 
-import org.junit.After;
+import static org.apache.spark.sql.types.DataTypes.*;
+
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.SparkConf;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import static org.apache.spark.sql.types.DataTypes.*;
-
-public class JavaVectorAssemblerSuite {
-  private transient SparkSession spark;
 
-  @Before
-  public void setUp() {
-    SparkConf sparkConf = new SparkConf();
-    sparkConf.setMaster("local");
-    sparkConf.setAppName("JavaVectorAssemblerSuite");
-
-    spark = SparkSession.builder().config(sparkConf).getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaVectorAssemblerSuite extends SharedSparkSession {
 
   @Test
   public void testVectorAssembler() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index a8dd446..ca8fae3 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -17,42 +17,21 @@
 
 package org.apache.spark.ml.feature;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
 
-public class JavaVectorIndexerSuite implements Serializable {
-  private transient SparkSession spark;
-  private JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaVectorIndexerSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaVectorIndexerSuite extends SharedSparkSession {
 
   @Test
   public void vectorIndexerAPI() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index a565c77..3dc2e1f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -20,11 +20,10 @@ package org.apache.spark.ml.feature;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.attribute.Attribute;
 import org.apache.spark.ml.attribute.AttributeGroup;
 import org.apache.spark.ml.attribute.NumericAttribute;
@@ -33,26 +32,10 @@ import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructType;
 
 
-public class JavaVectorSlicerSuite {
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaVectorSlicerSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaVectorSlicerSuite extends SharedSparkSession {
 
   @Test
   public void vectorSlice() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index bef7eb0..d0a849f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -19,34 +19,17 @@ package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.*;
 
-public class JavaWord2VecSuite {
-  private transient SparkSession spark;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaWord2VecSuite")
-      .getOrCreate();
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaWord2VecSuite extends SharedSparkSession {
 
   @Test
   public void testJavaWord2Vec() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
index a5b5dd4..1077e10 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -19,37 +19,14 @@ package org.apache.spark.ml.param;
 
 import java.util.Arrays;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-
 /**
  * Test Param and related classes in Java
  */
 public class JavaParamsSuite {
 
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaParamsSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void testParams() {
     JavaTestParams testParams = new JavaTestParams();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index 4ea3f22..1da85ed 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,43 +17,21 @@
 
 package org.apache.spark.ml.regression;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
 
-public class JavaDecisionTreeRegressorSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaDecisionTreeRegressorSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaDecisionTreeRegressorSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 3b5edf1..7fd9b1f 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -17,43 +17,21 @@
 
 package org.apache.spark.ml.regression;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
 
-public class JavaGBTRegressorSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaGBTRegressorSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaGBTRegressorSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 126aa62..6cdcdda 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -17,48 +17,32 @@
 
 package org.apache.spark.ml.regression;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.List;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaLinearRegressionSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
+public class JavaLinearRegressionSuite extends SharedSparkSession {
   private transient Dataset<Row> dataset;
   private transient JavaRDD<LabeledPoint> datasetRDD;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLinearRegressionSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     datasetRDD = jsc.parallelize(points, 2);
     dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
     dataset.createOrReplaceTempView("dataset");
   }
 
-  @After
-  public void tearDown() {
-    jsc.stop();
-    jsc = null;
-  }
-
   @Test
   public void linearRegressionDefaultParams() {
     LinearRegression lr = new LinearRegression();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index d601e7c..4ba13e2 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -17,45 +17,23 @@
 
 package org.apache.spark.ml.regression;
 
-import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.classification.LogisticRegressionSuite;
 import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.tree.impl.TreeTests;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 
 
-public class JavaRandomForestRegressorSuite implements Serializable {
-
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaRandomForestRegressorSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaRandomForestRegressorSuite extends SharedSparkSession {
 
   @Test
   public void runDT() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
 
b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 022dcf9..fa39f45 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -23,35 +23,28 @@ import java.nio.charset.StandardCharsets;
 
 import com.google.common.io.Files;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.linalg.DenseVector;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
 
 /**
  * Test LibSVMRelation in Java.
  */
-public class JavaLibSVMRelationSuite {
-  private transient SparkSession spark;
+public class JavaLibSVMRelationSuite extends SharedSparkSession {
 
   private File tempDir;
   private String path;
 
-  @Before
+  @Override
   public void setUp() throws IOException {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLibSVMRelationSuite")
-      .getOrCreate();
-
+    super.setUp();
     tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource");
     File file = new File(tempDir, "part-00000");
     String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0";
@@ -59,10 +52,9 @@ public class JavaLibSVMRelationSuite {
     path = tempDir.toURI().toString();
   }
 
-  @After
+  @Override
   public void tearDown() {
-    spark.stop();
-    spark = null;
+    super.tearDown();
     Utils.deleteRecursively(tempDir);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index b874ccd..692d5ad 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -17,48 +17,33 @@
 
 package org.apache.spark.ml.tuning;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.ml.classification.LogisticRegression;
-import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
-import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.LabeledPoint;
 import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
+import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
-public class JavaCrossValidatorSuite implements Serializable {
 
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-  private transient Dataset<Row> dataset;
+public class JavaCrossValidatorSuite extends SharedSparkSession {
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaCrossValidatorSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
+  private transient Dataset<Row> dataset;
 
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
   }
 
-  @After
-  public void tearDown() {
-    jsc.stop();
-    jsc = null;
-  }
-
   @Test
   public void crossValidationWithLogisticRegression() {
     LogisticRegression lr = new LogisticRegression();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index 7151e27..da623d1 100644
--- 
a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -20,42 +20,25 @@ package org.apache.spark.ml.util;
 import java.io.File;
 import java.io.IOException;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.util.Utils;
 
-public class JavaDefaultReadWriteSuite {
-
-  JavaSparkContext jsc = null;
-  SparkSession spark = null;
+public class JavaDefaultReadWriteSuite extends SharedSparkSession {
   File tempDir = null;
 
-  @Before
-  public void setUp() {
-    SQLContext.clearActive();
-    spark = SparkSession.builder()
-      .master("local[2]")
-      .appName("JavaDefaultReadWriteSuite")
-      .getOrCreate();
-    SQLContext.setActive(spark.wrapped());
-
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     tempDir = Utils.createTempDir(
       System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
   }
 
-  @After
+  @Override
   public void tearDown() {
-    SQLContext.clearActive();
-    if (spark != null) {
-      spark.stop();
-      spark = null;
-    }
+    super.tearDown();
     Utils.deleteRecursively(tempDir);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
 
b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
index 2f10d14..c04e2e6 100644
--- 
a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
+++ 
b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -17,37 +17,16 @@
 
 package org.apache.spark.mllib.classification;
 
-import java.io.Serializable;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaLogisticRegressionSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLogisticRegressionSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaLogisticRegressionSuite extends SharedSparkSession {
 
   int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
     int numAccurate = 0;

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 5e212e2..6ded42e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -17,42 +17,21 @@
 
 package org.apache.spark.mllib.classification;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
 
 
-public class JavaNaiveBayesSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaNaiveBayesSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaNaiveBayesSuite extends SharedSparkSession {
 
   private static final List<LabeledPoint> POINTS = Arrays.asList(
     new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 2a090c0..0f54e68 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,37 +17,16 @@
 
 package org.apache.spark.mllib.classification;
 
-import java.io.Serializable;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaSVMSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaSVMSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaSVMSuite extends SharedSparkSession {
 
   int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
     int numAccurate = 0;

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index 62c6d9b..8c6bced 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.classification;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
@@ -37,7 +36,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import static org.apache.spark.streaming.JavaTestUtils.*;
 
-public class JavaStreamingLogisticRegressionSuite implements Serializable {
+public class JavaStreamingLogisticRegressionSuite {
 
   protected transient JavaStreamingContext ssc;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
index 7f29b05..3d62b27 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -17,39 +17,17 @@
 
 package org.apache.spark.mllib.clustering;
 
-import java.io.Serializable;
-
 import com.google.common.collect.Lists;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaBisectingKMeansSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaBisectingKMeansSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaBisectingKMeansSuite extends SharedSparkSession {
 
   @Test
   public void twoDimensionalData() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
index 20edd08..bf76719 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java
@@ -17,40 +17,19 @@
 
 package org.apache.spark.mllib.clustering;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaGaussianMixtureSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaGaussianMixture")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaGaussianMixtureSuite extends SharedSparkSession {
 
   @Test
   public void runGaussianMixture() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 4e5b87f..270e636 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -17,40 +17,19 @@
 
 package org.apache.spark.mllib.clustering;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaKMeansSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaKMeans")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaKMeansSuite extends SharedSparkSession {
 
   @Test
   public void runKMeansUsingStaticMethods() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
index f16585a..08d6713 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -17,39 +17,28 @@
 
 package org.apache.spark.mllib.clustering;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 
 import scala.Tuple2;
 import scala.Tuple3;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 import static org.junit.Assert.*;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Matrix;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.SparkSession;
-
-public class JavaLDASuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaLDASuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
 
+public class JavaLDASuite extends SharedSparkSession {
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<>();
     for (int i = 0; i < LDASuite.tinyCorpus().length; i++) {
       tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(),
@@ -59,12 +48,6 @@ public class JavaLDASuite implements Serializable {
     corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void localLDAModel() {
     Matrix topics = LDASuite.tinyTopics();

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index d1d618f..d41fc0e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.clustering;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
@@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.api.java.JavaStreamingContext;
 import static org.apache.spark.streaming.JavaTestUtils.*;
 
-public class JavaStreamingKMeansSuite implements Serializable {
+public class JavaStreamingKMeansSuite {
 
   protected transient JavaStreamingContext ssc;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
index 6a096d6..e9d7e4f 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -17,35 +17,25 @@
 
 package org.apache.spark.mllib.evaluation;
 
-import java.io.Serializable;
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 
 import scala.Tuple2;
 import scala.Tuple2$;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaRankingMetricsSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
+public class JavaRankingMetricsSuite extends SharedSparkSession {
+  private transient JavaRDD<Tuple2<List<Integer>, List<Integer>>> predictionAndLabels;
 
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPCASuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
     predictionAndLabels = jsc.parallelize(Arrays.asList(
       Tuple2$.MODULE$.apply(
         Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)),
@@ -55,12 +45,6 @@ public class JavaRankingMetricsSuite implements Serializable {
         Arrays.asList(1, 2, 3, 4, 5), Arrays.<Integer>asList())), 2);
   }
 
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
-
   @Test
   public void rankingMetrics() {
     @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index de50fb8..05128ea 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -17,38 +17,17 @@
 
 package org.apache.spark.mllib.feature;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaTfIdfSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPCASuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaTfIdfSuite extends SharedSparkSession {
 
   @Test
   public void tfIdf() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
index 64885cc..3e3abdd 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.feature;
 
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
@@ -25,33 +24,13 @@ import com.google.common.base.Strings;
 
 import scala.Tuple2;
 
-import org.junit.After;
 import org.junit.Assert;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaWord2VecSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaPCASuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaWord2VecSuite extends SharedSparkSession {
 
   @Test
   @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
index fdc19a5..3451e07 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -16,36 +16,15 @@
  */
 package org.apache.spark.mllib.fpm;
 
-import java.io.Serializable;
 import java.util.Arrays;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
-import org.apache.spark.sql.SparkSession;
 
-public class JavaAssociationRulesSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaAssociationRulesSuite")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaAssociationRulesSuite extends SharedSparkSession {
 
   @Test
   public void runAssociationRules() {

http://git-wip-us.apache.org/repos/asf/spark/blob/01cf649c/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index f235251..46e9dd8 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -18,39 +18,18 @@
 package org.apache.spark.mllib.fpm;
 
 import java.io.File;
-import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
 
-import org.junit.After;
-import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.SharedSparkSession;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.Utils;
 
-public class JavaFPGrowthSuite implements Serializable {
-  private transient SparkSession spark;
-  private transient JavaSparkContext jsc;
-
-  @Before
-  public void setUp() {
-    spark = SparkSession.builder()
-      .master("local")
-      .appName("JavaFPGrowth")
-      .getOrCreate();
-    jsc = new JavaSparkContext(spark.sparkContext());
-  }
-
-  @After
-  public void tearDown() {
-    spark.stop();
-    spark = null;
-  }
+public class JavaFPGrowthSuite extends SharedSparkSession {
 
   @Test
   public void runFPGrowth() {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to