http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
deleted file mode 100644
index bd9729c..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test._
-
-
-case class TestData(key: Int, value: String)
-
-object TestData {
-  val testData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(i, i.toString))).toDF()
-  testData.registerTempTable("testData")
-
-  val negativeData = TestSQLContext.sparkContext.parallelize(
-    (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
-  negativeData.registerTempTable("negativeData")
-
-  case class LargeAndSmallInts(a: Int, b: Int)
-  val largeAndSmallInts =
-    TestSQLContext.sparkContext.parallelize(
-      LargeAndSmallInts(2147483644, 1) ::
-      LargeAndSmallInts(1, 2) ::
-      LargeAndSmallInts(2147483645, 1) ::
-      LargeAndSmallInts(2, 2) ::
-      LargeAndSmallInts(2147483646, 1) ::
-      LargeAndSmallInts(3, 2) :: Nil).toDF()
-  largeAndSmallInts.registerTempTable("largeAndSmallInts")
-
-  case class TestData2(a: Int, b: Int)
-  val testData2 =
-    TestSQLContext.sparkContext.parallelize(
-      TestData2(1, 1) ::
-      TestData2(1, 2) ::
-      TestData2(2, 1) ::
-      TestData2(2, 2) ::
-      TestData2(3, 1) ::
-      TestData2(3, 2) :: Nil, 2).toDF()
-  testData2.registerTempTable("testData2")
-
-  case class DecimalData(a: BigDecimal, b: BigDecimal)
-
-  val decimalData =
-    TestSQLContext.sparkContext.parallelize(
-      DecimalData(1, 1) ::
-      DecimalData(1, 2) ::
-      DecimalData(2, 1) ::
-      DecimalData(2, 2) ::
-      DecimalData(3, 1) ::
-      DecimalData(3, 2) :: Nil).toDF()
-  decimalData.registerTempTable("decimalData")
-
-  case class BinaryData(a: Array[Byte], b: Int)
-  val binaryData =
-    TestSQLContext.sparkContext.parallelize(
-      BinaryData("12".getBytes(), 1) ::
-      BinaryData("22".getBytes(), 5) ::
-      BinaryData("122".getBytes(), 3) ::
-      BinaryData("121".getBytes(), 2) ::
-      BinaryData("123".getBytes(), 4) :: Nil).toDF()
-  binaryData.registerTempTable("binaryData")
-
-  case class TestData3(a: Int, b: Option[Int])
-  val testData3 =
-    TestSQLContext.sparkContext.parallelize(
-      TestData3(1, None) ::
-      TestData3(2, Some(2)) :: Nil).toDF()
-  testData3.registerTempTable("testData3")
-
-  case class UpperCaseData(N: Int, L: String)
-  val upperCaseData =
-    TestSQLContext.sparkContext.parallelize(
-      UpperCaseData(1, "A") ::
-      UpperCaseData(2, "B") ::
-      UpperCaseData(3, "C") ::
-      UpperCaseData(4, "D") ::
-      UpperCaseData(5, "E") ::
-      UpperCaseData(6, "F") :: Nil).toDF()
-  upperCaseData.registerTempTable("upperCaseData")
-
-  case class LowerCaseData(n: Int, l: String)
-  val lowerCaseData =
-    TestSQLContext.sparkContext.parallelize(
-      LowerCaseData(1, "a") ::
-      LowerCaseData(2, "b") ::
-      LowerCaseData(3, "c") ::
-      LowerCaseData(4, "d") :: Nil).toDF()
-  lowerCaseData.registerTempTable("lowerCaseData")
-
-  case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
-  val arrayData =
-    TestSQLContext.sparkContext.parallelize(
-      ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
-      ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
-  arrayData.toDF().registerTempTable("arrayData")
-
-  case class MapData(data: scala.collection.Map[Int, String])
-  val mapData =
-    TestSQLContext.sparkContext.parallelize(
-      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
-      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
-      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
-      MapData(Map(1 -> "a4", 2 -> "b4")) ::
-      MapData(Map(1 -> "a5")) :: Nil)
-  mapData.toDF().registerTempTable("mapData")
-
-  case class StringData(s: String)
-  val repeatedData =
-    TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
-  repeatedData.toDF().registerTempTable("repeatedData")
-
-  val nullableRepeatedData =
-    TestSQLContext.sparkContext.parallelize(
-      List.fill(2)(StringData(null)) ++
-      List.fill(2)(StringData("test")))
-  nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
-
-  case class NullInts(a: Integer)
-  val nullInts =
-    TestSQLContext.sparkContext.parallelize(
-      NullInts(1) ::
-      NullInts(2) ::
-      NullInts(3) ::
-      NullInts(null) :: Nil
-    ).toDF()
-  nullInts.registerTempTable("nullInts")
-
-  val allNulls =
-    TestSQLContext.sparkContext.parallelize(
-      NullInts(null) ::
-      NullInts(null) ::
-      NullInts(null) ::
-      NullInts(null) :: Nil).toDF()
-  allNulls.registerTempTable("allNulls")
-
-  case class NullStrings(n: Int, s: String)
-  val nullStrings =
-    TestSQLContext.sparkContext.parallelize(
-      NullStrings(1, "abc") ::
-      NullStrings(2, "ABC") ::
-      NullStrings(3, null) :: Nil).toDF()
-  nullStrings.registerTempTable("nullStrings")
-
-  case class TableName(tableName: String)
-  TestSQLContext
-    .sparkContext
-    .parallelize(TableName("test") :: Nil)
-    .toDF()
-    .registerTempTable("tableName")
-
-  val unparsedStrings =
-    TestSQLContext.sparkContext.parallelize(
-      "1, A1, true, null" ::
-      "2, B2, false, null" ::
-      "3, C3, true, null" ::
-      "4, D4, true, 2147483644" :: Nil)
-
-  case class IntField(i: Int)
-  // An RDD with 4 elements and 8 partitions
-  val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
-  withEmptyParts.toDF().registerTempTable("withEmptyParts")
-
-  case class Person(id: Int, name: String, age: Int)
-  case class Salary(personId: Int, salary: Double)
-  val person = TestSQLContext.sparkContext.parallelize(
-    Person(0, "mike", 30) ::
-    Person(1, "jim", 20) :: Nil).toDF()
-  person.registerTempTable("person")
-  val salary = TestSQLContext.sparkContext.parallelize(
-    Salary(0, 2000.0) ::
-    Salary(1, 1000.0) :: Nil).toDF()
-  salary.registerTempTable("salary")
-
-  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
-  val complexData =
-    TestSQLContext.sparkContext.parallelize(
-      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
-        :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
-        :: Nil).toDF()
-  complexData.registerTempTable("complexData")
-}

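The TestData object deleted above registered its temp tables against the global TestSQLContext singleton; the suites converted below obtain the same tables from SQLTestData through a SharedSQLContext. A minimal sketch of that usage pattern, assuming the SQLTestData/SharedSQLContext helpers this patch introduces elsewhere (the suite name is illustrative):

package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._

// Hypothetical suite; the structure mirrors UDFSuite and InMemoryColumnarQuerySuite below.
class ExampleSharedDataSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // Registers the standard temp tables (such as "testData") against the shared context,
  // replacing the side effects of the deleted TestData object.
  setupTestData()

  test("testData is reachable through the plain sql() helper") {
    // The deleted TestData.testData held TestData(i, i.toString) for i = 1 to 100.
    assert(sql("SELECT count(*) FROM testData").head().getLong(0) === 100)
  }
}
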
http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 183dc34..eb275af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 
-case class FunctionResult(f1: String, f2: String)
+private case class FunctionResult(f1: String, f2: String)
 
-class UDFSuite extends QueryTest with SQLTestUtils {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-
-  override def sqlContext(): SQLContext = ctx
+class UDFSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("built-in fixed arity expressions") {
     val df = ctx.emptyDataFrame
@@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
   test("SPARK-8003 spark_partition_id") {
     val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
     df.registerTempTable("tmp_table")
-    checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
+    checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
     ctx.dropTempTable("tmp_table")
   }
 
@@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
       val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
       ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
-      val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+      val answer = sql("select input_file_name() from test_table").head().getString(0)
       assert(answer.contains(dir.getCanonicalPath))
-      assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+      assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
       ctx.dropTempTable("test_table")
     }
   }
@@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
 
   test("Simple UDF") {
     ctx.udf.register("strLenScala", (_: String).length)
-    assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
+    assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
   }
 
   test("ZeroArgument UDF") {
     ctx.udf.register("random0", () => { Math.random()})
-    assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
+    assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
   }
 
   test("TwoArgument UDF") {
     ctx.udf.register("strLenScala", (_: String).length + (_: Int))
-    assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
+    assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
   }
 
   test("UDF in a WHERE") {
@@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("integerData")
 
     val result =
-      ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
+      sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
     assert(result.count() === 20)
   }
 
@@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT g, SUM(v) as s
          | FROM groupData
@@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT SUM(v)
          | FROM groupData
@@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     df.registerTempTable("groupData")
 
     val result =
-      ctx.sql(
+      sql(
         """
          | SELECT timesHundred(SUM(v)) as v100
          | FROM groupData
@@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
     ctx.udf.register("returnStruct", (f1: String, f2: String) => 
FunctionResult(f1, f2))
 
     val result =
-      ctx.sql("SELECT returnStruct('test', 'test2') as ret")
+      sql("SELECT returnStruct('test', 'test2') as ret")
         .select($"ret.f1").head().getString(0)
     assert(result === "test")
   }
@@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
   test("udf that is transformed") {
     ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
     // 1 + 1 is constant folded causing a transformation.
-    assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === 
Row(2, 2))
+    assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 
2))
   }
 
   test("type coercion for udf inputs") {
     ctx.udf.register("intExpected", (x: Int) => x)
     // pass a decimal to intExpected.
-    assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9181222..b6d279a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.OpenHashSet
@@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   private[spark] override def asNullable: MyDenseVectorUDT = this
 }
 
-class UserDefinedTypeSuite extends QueryTest {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   private lazy val pointsRDD = Seq(
     MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
@@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
     ctx.udf.register("testType", (d: MyDenseVector) => 
d.isInstanceOf[MyDenseVector])
     pointsRDD.registerTempTable("points")
     checkAnswer(
-      ctx.sql("SELECT testType(features) from points"),
+      sql("SELECT testType(features) from points"),
       Seq(Row(true), Row(true)))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 9bca4e7..952637c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
 
-class InMemoryColumnarQuerySuite extends QueryTest {
-  // Make sure the tables are loaded.
-  TestData
+class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-  import ctx.{logicalPlanToSparkQuery, sql}
+  setupTestData()
 
   test("simple columnar query") {
     val plan = ctx.executePlan(testData.logicalPlan).executedPlan

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 2c08799..ab2644e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -17,20 +17,19 @@
 
 package org.apache.spark.sql.columnar
 
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
 
-class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
-
-  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
+class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
   private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
   private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
 
   override protected def beforeAll(): Unit = {
+    super.beforeAll()
     // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
     ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
 
@@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
     ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
     // Enable in-memory table scan accumulators
     ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
-  }
-
-  override protected def afterAll(): Unit = {
-    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
-    ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
-  }
-
-  before {
     ctx.cacheTable("pruningData")
   }
 
-  after {
-    ctx.uncacheTable("pruningData")
+  override protected def afterAll(): Unit = {
+    try {
+      ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+      ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+      ctx.uncacheTable("pruningData")
+    } finally {
+      super.afterAll()
+    }
   }
 
   // Comparisons
@@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
       expectedQueryResult: => Seq[Int]): Unit = {
 
     test(query) {
-      val df = ctx.sql(query)
+      val df = sql(query)
       val queryExecution = df.queryExecution
 
       assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {

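The setup/teardown restructuring above is the shape most of the converted suites follow: call super.beforeAll() before touching the shared context, and restore any modified conf inside try/finally so super.afterAll() always runs. A condensed sketch of that pattern, with an illustrative suite name, only the column-batch-size conf shown, and the remembered value held in a plain var:

package org.apache.spark.sql.columnar

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ExampleSharedConfSuite extends SparkFunSuite with SharedSQLContext {
  private var originalColumnBatchSize: Int = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()                                   // let SharedSQLContext create the context
    originalColumnBatchSize = ctx.conf.columnBatchSize  // remember the value to restore later
    ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)          // suite-specific override
  }

  override protected def afterAll(): Unit = {
    try {
      ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)  // undo the override
    } finally {
      super.afterAll()                                  // always release the shared context
    }
  }

  test("the override is visible while the suite runs") {
    assert(ctx.conf.columnBatchSize === 10)
  }
}
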
http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 79e903c..8998f51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ExchangeSuite extends SparkPlanTest {
+class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
   test("shuffling UnsafeRows in exchange") {
     val input = (1 to 1000).map(Tuple1.apply)
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5582caa..937a108 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{execution, Row, SQLConf}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
@@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
 
 
-class PlannerSuite extends SparkFunSuite with SQLTestUtils {
+class PlannerSuite extends SparkFunSuite with SharedSQLContext {
+  import testImplicits._
 
-  override def sqlContext: SQLContext = TestSQLContext
+  setupTestData()
 
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val _ctx = ctx
+    import _ctx.planner._
     val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
     val planned =
       plannedOption.getOrElse(
@@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
   }
 
   test("unions are collapsed") {
+    val _ctx = ctx
+    import _ctx.planner._
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
     val logicalUnions = query collect { case u: logical.Union => u }
@@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
 
   test("sizeInBytes estimation of limit operator for broadcast hash join 
optimization") {
     def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
-      setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
+      ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
       val fields = fieldTypes.zipWithIndex.map {
         case (dataType, index) => StructField(s"c${index}", dataType, true)
       } :+ StructField("key", IntegerType, true)
       val schema = StructType(fields)
       val row = Row.fromSeq(Seq.fill(fields.size)(null))
-      val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
-      createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+      val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
+      ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
 
       val planned = sql(
         """
@@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
       assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
       assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
 
-      dropTempTable("testLimit")
+      ctx.dropTempTable("testLimit")
     }
 
-    val origThreshold = conf.autoBroadcastJoinThreshold
+    val origThreshold = ctx.conf.autoBroadcastJoinThreshold
 
     val simpleTypes =
       NullType ::
@@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
 
     checkPlan(complexTypes, newThreshold = 901617)
 
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
 
   test("InMemoryRelation statistics propagation") {
-    val origThreshold = conf.autoBroadcastJoinThreshold
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
+    val origThreshold = ctx.conf.autoBroadcastJoinThreshold
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
 
     testData.limit(3).registerTempTable("tiny")
     sql("CACHE TABLE tiny")
 
     val a = testData.as("a")
-    val b = table("tiny").as("b")
+    val b = ctx.table("tiny").as("b")
     val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
 
     val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
@@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
     assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
     assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
 
-    setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+    ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
   }
 
   test("efficient limit -> project -> sort") {
     val query = testData.sort('key).select('value).limit(2).logicalPlan
-    val planned = planner.TakeOrderedAndProject(query)
+    val planned = ctx.planner.TakeOrderedAndProject(query)
     assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index dd08e90..ef6ad59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
-class RowFormatConvertersSuite extends SparkPlanTest {
+class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
 
   private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
     case c: ConvertToUnsafe => c
@@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
 
   test("planner should insert unsafe->safe conversions when required") {
     val plan = Limit(10, outputsUnsafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
   }
 
   test("filter can process unsafe rows") {
     val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(getConverters(preparedPlan).size === 1)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("filter can process safe rows") {
     val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(getConverters(preparedPlan).isEmpty)
     assert(!preparedPlan.outputsUnsafeRows)
   }
@@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
   test("union requires all of its input rows' formats to agree") {
     val plan = Union(Seq(outputsSafe, outputsUnsafe))
     assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("union can process safe rows") {
     val plan = Union(Seq(outputsSafe, outputsSafe))
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(!preparedPlan.outputsUnsafeRows)
   }
 
   test("union can process unsafe rows") {
     val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
-    val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+    val preparedPlan = ctx.prepareForExecution.execute(plan)
     assert(preparedPlan.outputsUnsafeRows)
   }
 
   test("round trip with ConvertToUnsafe and ConvertToSafe") {
     val input = Seq(("hello", 1), ("world", 2))
     checkAnswer(
-      TestSQLContext.createDataFrame(input),
+      ctx.createDataFrame(input),
       plan => ConvertToSafe(ConvertToUnsafe(plan)),
       input.map(Row.fromTuple)
     )
   }
 
   test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
-    SparkPlan.currentContext.set(TestSQLContext)
+    SparkPlan.currentContext.set(ctx)
     val schema = ArrayType(StringType)
     val rows = (1 to 100).map { i =>
       InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a2c10fd..8fa77b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.SharedSQLContext
 
-class SortSuite extends SparkPlanTest {
+class SortSuite extends SparkPlanTest with SharedSQLContext {
 
   // This test was originally added as an example of how to use [[SparkPlanTest]];
   // it's not designed to be a comprehensive test of ExternalSort.

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index f46855e..3a87f37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,29 +17,27 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
-
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.util._
+
 /**
  * Base class for writing tests for individual physical operators. For an example of how this
  * class's test helper methods can be used, see [[SortSuite]].
  */
-class SparkPlanTest extends SparkFunSuite {
-
-  protected def sqlContext: SQLContext = TestSQLContext
+private[sql] abstract class SparkPlanTest extends SparkFunSuite {
+  protected def _sqlContext: SQLContext
 
   /**
    * Creates a DataFrame from a local Seq of Product.
    */
   implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
-    sqlContext.implicits.localSeqToDataFrameHolder(data)
+    _sqlContext.implicits.localSeqToDataFrameHolder(data)
   }
 
   /**
@@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean = true): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean = true): Unit = {
     SparkPlanTest.checkAnswer(
-        input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
+        input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -151,13 +149,13 @@ object SparkPlanTest {
       planFunction: SparkPlan => SparkPlan,
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      _sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
 
     val expectedAnswer: Seq[Row] = try {
-      executePlan(expectedOutputPlan, sqlContext)
+      executePlan(expectedOutputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -172,7 +170,7 @@ object SparkPlanTest {
     }
 
     val actualAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -212,12 +210,12 @@ object SparkPlanTest {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean,
-      sqlContext: SQLContext): Option[String] = {
+      _sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
 
     val sparkAnswer: Seq[Row] = try {
-      executePlan(outputPlan, sqlContext)
+      executePlan(outputPlan, _sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -280,10 +278,10 @@ object SparkPlanTest {
     }
   }
 
-  private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
+  private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.prepareForExecution.execute(
+    val resolvedPlan = _sqlContext.prepareForExecution.execute(
       outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap

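With SparkPlanTest now abstract and only declaring _sqlContext, each concrete operator suite supplies the context by also mixing in SharedSQLContext, as ExchangeSuite, SortSuite and RowFormatConvertersSuite do above. A minimal sketch of a suite written against the new base class (the suite name is illustrative; the round-trip check reuses the converters exercised in RowFormatConvertersSuite):

package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ExampleOperatorSuite extends SparkPlanTest with SharedSQLContext {

  test("rows survive a round trip through the unsafe/safe converters") {
    val input = Seq(("hello", 1), ("world", 2))
    checkAnswer(
      ctx.createDataFrame(input),                    // DataFrame input for the physical plan
      plan => ConvertToSafe(ConvertToUnsafe(plan)),  // plan fragment under test
      input.map(Row.fromTuple))                      // expected rows
  }
}
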
http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 88bce0e..3158458 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
 
 import scala.util.Random
 
-import org.scalatest.BeforeAndAfterAll
-
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 /**
  * A test suite that generates randomized data to test the [[TungstenSort]] operator.
  */
-class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
 
   override def beforeAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+    super.beforeAll()
+    ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
   }
 
   override def afterAll(): Unit = {
-    TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+    try {
+      ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("sort followed by limit") {
@@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
   }
 
   test("sorting updates peak execution memory") {
-    val sc = TestSQLContext.sparkContext
+    val sc = ctx.sparkContext
     AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
       checkThatPlansAgree(
         (1 to 100).map(v => Tuple1(v)).toDF("a"),
@@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
   ) {
     test(s"sorting on $dataType with nullable=$nullable, 
sortOrder=$sortOrder") {
       val inputData = Seq.fill(1000)(randomDataGenerator())
-      val inputDf = TestSQLContext.createDataFrame(
-        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+      val inputDf = ctx.createDataFrame(
+        ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
         StructType(StructField("a", dataType, nullable = true) :: Nil)
       )
       assert(TungstenSort.supportsSchema(inputDf.schema))

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index e034730..d1f0b2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.Matchers
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
 import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
@@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
  *
  * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
  */
-class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
+class UnsafeFixedWidthAggregationMapSuite
+  extends SparkFunSuite
+  with Matchers
+  with SharedSQLContext {
 
   import UnsafeFixedWidthAggregationMap._
 
@@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
-
     // Memory consumption in the beginning of the task.
     val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
 
@@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting with an empty map") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
@@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
   }
 
   testWithMemoryLeakDetection("test external sorting with empty records") {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     // Memory consumption in the beginning of the task.
     val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index a9515a0..d3be568 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark._
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
 
 /**
  * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
  */
-class UnsafeKVExternalSorterSuite extends SparkFunSuite {
-
+class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
   private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
   private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
 
@@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
       inputData: Seq[(InternalRow, InternalRow)],
       pageSize: Long,
       spill: Boolean): Unit = {
-    // Calling this make sure we have block manager and everything else setup.
-    TestSQLContext
 
     val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
     val shuffleMemMgr = new TestShuffleMemoryManager

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ac22c2f..5fdb82b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -21,15 +21,12 @@ import org.apache.spark._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 
-class TungstenAggregationIteratorSuite extends SparkFunSuite {
+class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
 
   test("memory acquired on construction") {
-    // set up environment
-    val ctx = TestSQLContext
-
     val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
     val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
     TaskContext.setTaskContext(taskContext)

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 73d5621..1174b27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
 import org.apache.spark.rdd.RDD
 import org.scalactic.Tolerance._
 
-import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
 
-class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
-
-  protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  override def sqlContext: SQLContext = ctx // used by SQLTestUtils
-
-  import ctx.sql
-  import ctx.implicits._
+class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
+  import testImplicits._
 
   test("Type promotion") {
     def checkTypePromotion(expected: Any, actual: Any) {
@@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
 
     val schema = StructType(StructField("a", LongType, true) :: Nil)
     val logicalRelation =
-      ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
+      ctx.read.schema(schema).json(path)
+        .queryExecution.analyzed.asInstanceOf[LogicalRelation]
     val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
     assert(relationWithSchema.paths === Array(path))
     assert(relationWithSchema.schema === schema)
@@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
   }
 
   test("JSONRelation equality test") {
-    val context = org.apache.spark.sql.test.TestSQLContext
-
     val relation0 = new JSONRelation(
       Some(empty),
       1.0,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation0 = LogicalRelation(relation0)
     val relation1 = new JSONRelation(
       Some(singleRow),
       1.0,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation1 = LogicalRelation(relation1)
     val relation2 = new JSONRelation(
       Some(singleRow),
       0.5,
       Some(StructType(StructField("a", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation2 = LogicalRelation(relation2)
     val relation3 = new JSONRelation(
       Some(singleRow),
       1.0,
       Some(StructType(StructField("b", IntegerType, true) :: Nil)),
-      None, None)(context)
+      None, None)(ctx)
     val logicalRelation3 = LogicalRelation(relation3)
 
     assert(relation0 !== relation1)
@@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
         .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
 
       val d1 = ResolvedDataSource(
-        context,
+        ctx,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
         provider = classOf[DefaultSource].getCanonicalName,
         options = Map("path" -> path))
 
       val d2 = ResolvedDataSource(
-        context,
+        ctx,
         userSpecifiedSchema = None,
         partitionColumns = Array.empty[String],
         provider = classOf[DefaultSource].getCanonicalName,
@@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
         "abd")
 
         ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
-        checkAnswer(
-          sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
-        checkAnswer(
-          sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
-        checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
+        checkAnswer(sql(
+          "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 6b62c9a..2864181 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 
-trait TestJsonData {
-
-  protected def ctx: SQLContext
+private[json] trait TestJsonData {
+  protected def _sqlContext: SQLContext
 
   def primitiveFieldAndType: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"string":"this is a simple string.",
           "integer":10,
           "long":21474836470,
@@ -36,7 +35,7 @@ trait TestJsonData {
       }"""  :: Nil)
 
   def primitiveFieldValueTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
           "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
       """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -47,14 +46,14 @@ trait TestJsonData {
           "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" 
:: Nil)
 
   def jsonNullStruct: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
         """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
         """{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
         """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
 
   def complexFieldValueTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"num_struct":11, "str_array":[1, 2, 3],
           "array":[], "struct_array":[], "struct": {}}""" ::
       """{"num_struct":{"field":false}, "str_array":null,
@@ -65,14 +64,14 @@ trait TestJsonData {
           "array":[7], "struct_array":{"field": true}, "struct": {"field": 
"str"}}""" :: Nil)
 
   def arrayElementTypeConflict: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
           "array2": [{"field":214748364700}, {"field":1}]}""" ::
       """{"array3": [{"field":"str"}, {"field":1}]}""" ::
       """{"array3": [1, 2, 3]}""" :: Nil)
 
   def missingFields: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"a":true}""" ::
       """{"b":21474836470}""" ::
       """{"c":[33, 44]}""" ::
@@ -80,7 +79,7 @@ trait TestJsonData {
       """{"e":"str"}""" :: Nil)
 
   def complexFieldAndType1: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"struct":{"field1": true, "field2": 92233720368547758070},
           "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", 
"str2"]},
           "arrayOfString":["str1", "str2"],
@@ -96,7 +95,7 @@ trait TestJsonData {
          }"""  :: Nil)
 
   def complexFieldAndType2: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": 
false}, {"field3": null}],
           "complexArrayOfStruct": [
           {
@@ -150,7 +149,7 @@ trait TestJsonData {
       }""" :: Nil)
 
   def mapType1: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"map": {"a": 1}}""" ::
       """{"map": {"b": 2}}""" ::
       """{"map": {"c": 3}}""" ::
@@ -158,7 +157,7 @@ trait TestJsonData {
       """{"map": {"e": null}}""" :: Nil)
 
   def mapType2: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
       """{"map": {"b": {"field2": 2}}}""" ::
       """{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -167,21 +166,21 @@ trait TestJsonData {
       """{"map": {"f": {"field1": null}}}""" :: Nil)
 
   def nullsInArrays: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{"field1":[[null], [[["Test"]]]]}""" ::
       """{"field2":[null, [{"Test":1}]]}""" ::
       """{"field3":[[null], [{"Test":"2"}]]}""" ::
       """{"field4":[[null, [1,2,3]]]}""" :: Nil)
 
   def jsonArray: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """[{"a":"str_a_1"}]""" ::
       """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
       """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
       """[]""" :: Nil)
 
   def corruptRecords: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{""" ::
       """""" ::
       """{"a":1, b:2}""" ::
@@ -190,7 +189,7 @@ trait TestJsonData {
       """]""" :: Nil)
 
   def emptyRecords: RDD[String] =
-    ctx.sparkContext.parallelize(
+    _sqlContext.sparkContext.parallelize(
       """{""" ::
         """""" ::
         """{"a": {}}""" ::
@@ -198,9 +197,8 @@ trait TestJsonData {
         """{"b": [{"c": {}}]}""" ::
         """]""" :: Nil)
 
-  lazy val singleRow: RDD[String] =
-    ctx.sparkContext.parallelize(
-      """{"a":123}""" :: Nil)
 
-  def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
+  lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+
+  def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
 }

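TestJsonData now only declares an abstract _sqlContext, so any suite that mixes it in together with SharedSQLContext (as JsonSuite does above) supplies the context its RDD helpers need. A small sketch of the same shape, with illustrative trait/suite names and a made-up two-record data set:

package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{QueryTest, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext

private[json] trait ExampleJsonData {
  // Supplied by whichever suite mixes this trait in (here via SharedSQLContext).
  protected def _sqlContext: SQLContext

  def exampleRecords: RDD[String] =
    _sqlContext.sparkContext.parallelize(
      """{"a":1}""" ::
      """{"a":2}""" :: Nil)
}

class ExampleJsonDataSuite extends QueryTest with SharedSQLContext with ExampleJsonData {
  test("the data trait sees the shared context") {
    assert(ctx.read.json(exampleRecords).count() === 2)
  }
}
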
http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
index 866a975..82d40e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
@@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.avro.AvroParquetWriter
 
-import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.parquet.test.avro._
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
   import ParquetCompatibilityTest._
 
-  override val sqlContext: SQLContext = TestSQLContext
-
   private def withWriter[T <: IndexedRecord]
       (path: String, schema: Schema)
-      (f: AvroParquetWriter[T] => Unit) = {
+      (f: AvroParquetWriter[T] => Unit): Unit = {
     val writer = new AvroParquetWriter[T](new Path(path), schema)
     try f(writer) finally writer.close()
   }
@@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
   }
 
   test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
-    import sqlContext.implicits._
+    import testImplicits._
 
     withTempPath { dir =>
       val path = dir.getCanonicalPath

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index 0ea64aa..b340672 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
 import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.parquet.hadoop.ParquetFileReader
 import org.apache.parquet.schema.MessageType
-import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.QueryTest
 
-abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
-  def readParquetSchema(path: String): MessageType = {
+/**
+ * Helper class for testing Parquet compatibility.
+ */
+private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
+  protected def readParquetSchema(path: String): MessageType = {
     readParquetSchema(path, { path => !path.getName.startsWith("_") })
   }
 
-  def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
+  protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
     val fsPath = new Path(path)
     val fs = fsPath.getFileSystem(configuration)
     val parquetFiles = fs.listStatus(fsPath, new PathFilter {

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 7dd9680..5b4e568 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
 import org.apache.parquet.filter2.predicate.Operators._
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
 
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
  * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *    data type is nullable.
  */
-class ParquetFilterSuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
 
   private def checkFilterPredicate(
       df: DataFrame,
@@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
   }
 
   test("SPARK-6554: don't push down predicates which reference partition columns") {
-    import sqlContext.implicits._
+    import testImplicits._
 
     withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       withTempPath { dir =>
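
The same migration also changes where implicits come from: suites import testImplicits from the shared context instead of a suite-local sqlContext.implicits. A minimal sketch of a migrated suite (hypothetical class and test names):

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.test.SharedSQLContext

    // Sketch only: SharedSQLContext supplies the context and testImplicits.
    class ExampleFilterSuite extends QueryTest with SharedSQLContext {
      import testImplicits._

      test("filters a small DataFrame") {
        val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
        checkAnswer(df.filter($"i" > 1), Row(2, "b") :: Nil)
      }
    }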

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index cb16634..d819f3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
 /**
  * A test suite that tests basic Parquet I/O.
  */
-class ParquetIOSuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
+class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
+  import testImplicits._
 
   /**
    * Writes `data` to a Parquet file, reads it back and check file contents.

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 73152de..ed8bafb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.io.Files
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.sql._
 import org.apache.spark.unsafe.types.UTF8String
-import PartitioningUtils._
 
 // The data where the partitioning key exists only in the directory structure.
 case class ParquetData(intField: Int, stringField: String)
@@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
 // The data that also includes the partitioning key
 case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
 
-class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
-
-  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.implicits._
-  import sqlContext.sql
+class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
+  import PartitioningUtils._
+  import testImplicits._
 
   val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
index 981334c..b290429 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
-  override def sqlContext: SQLContext = TestSQLContext
+class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
 
   private def readParquetProtobufFile(name: String): DataFrame = {
     val url = Thread.currentThread().getContextClassLoader.getResource(name)

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 5e6d9c1..e2f2a8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -21,16 +21,15 @@ import java.io.File
 
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 /**
  * A test suite that tests various Parquet queries.
  */
-class ParquetQuerySuite extends QueryTest with ParquetTest {
-  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
-  import sqlContext.sql
+class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
 
   test("simple select queries") {
     withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
 
   test("appending") {
     val data = (0 until 10).map(i => (i, i.toString))
-    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
-      checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
+      checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(Seq("tmp"))
+    ctx.catalog.unregisterTable(Seq("tmp"))
   }
 
   test("overwriting") {
     val data = (0 until 10).map(i => (i, i.toString))
-    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
-      checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
+      checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
     }
-    sqlContext.catalog.unregisterTable(Seq("tmp"))
+    ctx.catalog.unregisterTable(Seq("tmp"))
   }
 
   test("self-join") {
@@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
       StructField("time", TimestampType, false)).toArray)
     withTempPath { file =>
-      val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
+      val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
       df.write.parquet(file.getCanonicalPath)
-      val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+      val df2 = ctx.read.parquet(file.getCanonicalPath)
       checkAnswer(df2, df.collect().toSeq)
     }
   }
@@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     def testSchemaMerging(expectedColumnNumber: Int): Unit = {
       withTempDir { dir =>
         val basePath = dir.getCanonicalPath
-        sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-        sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+        ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+        ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
         // delete summary files, so if we don't merge part-files, one column will not be included.
         Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
         Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
-        assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+        assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
       }
     }
 
@@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
     def testSchemaMerging(expectedColumnNumber: Int): Unit = {
       withTempDir { dir =>
         val basePath = dir.getCanonicalPath
-        sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-        sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
-        assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+        ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+        ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+        assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
       }
     }
 
@@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
   test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
     withTempPath { dir =>
       val basePath = dir.getCanonicalPath
-      sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
-      sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
+      ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+      ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
 
       // Disables the global SQL option for schema merging
       withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
         assertResult(2) {
           // Disables schema merging via data source option
-          sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
+          ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
         }
 
         assertResult(3) {
           // Enables schema merging via data source option
-          sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
+          ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
         }
       }
     }
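
For reference, the data source option exercised in the last hunk can be isolated in a small helper; the per-read "mergeSchema" option takes precedence over the global setting toggled via withSQLConf. A sketch (helper name and path are illustrative only):

    import org.apache.spark.sql.{DataFrame, SQLContext}

    // Sketch only: a per-read option overrides the global schema-merging
    // SQL conf for this one read.
    def readWithMergedSchema(ctx: SQLContext, path: String): DataFrame =
      ctx.read.option("mergeSchema", "true").parquet(path)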

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 971f71e..9dcbc1a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.parquet.schema.MessageTypeParser
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
-abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
-  val sqlContext = TestSQLContext
+abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
 
   /**
    * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 3c6e54d..5dbc7d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -22,9 +22,8 @@ import java.io.File
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
 
 /**
  * A helper trait that provides convenient facilities for Parquet testing.
@@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
  * convenient to use tuples rather than special case classes when writing test cases/suites.
  * Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
  */
-private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
+private[sql] trait ParquetTest extends SQLTestUtils {
+  protected def _sqlContext: SQLContext
+
   /**
    * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
    * returns.
@@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
       (data: Seq[T])
       (f: String => Unit): Unit = {
     withTempPath { file =>
-      sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
+      _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
       f(file.getCanonicalPath)
     }
   }
@@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
   protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
       (data: Seq[T])
       (f: DataFrame => Unit): Unit = {
-    withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
+    withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
   }
 
   /**
@@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
       (data: Seq[T], tableName: String)
       (f: => Unit): Unit = {
     withParquetDataFrame(data) { df =>
-      sqlContext.registerDataFrameAsTable(df, tableName)
+      _sqlContext.registerDataFrameAsTable(df, tableName)
       withTempTable(tableName)(f)
     }
   }
 
   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
       data: Seq[T], path: File): Unit = {
-    sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+    _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
   }
 
   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index 92b1d82..b789c5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.SharedSQLContext
 
-class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
   import ParquetCompatibilityTest._
 
-  override val sqlContext: SQLContext = TestSQLContext
-
   private val parquetFilePath =
     Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 239deb7..2218947 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.execution.debug
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
 
-class DebuggingSuite extends SparkFunSuite {
   test("DataFrame.debug()") {
     testData.debug()
   }
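
After the change, DebuggingSuite gets both its context and the testData fixture from SharedSQLContext rather than from the removed TestData and TestSQLContext imports. A minimal sketch of the resulting shape (hypothetical suite name; it assumes testData is exposed by the shared test fixtures, as the hunk above relies on):

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.execution.debug._  // adds debug() to DataFrame
    import org.apache.spark.sql.test.SharedSQLContext

    // Sketch only: no per-suite context wiring is needed any more.
    class ExampleDebuggingSuite extends SparkFunSuite with SharedSQLContext {
      test("debug a fixture DataFrame") {
        testData.debug()
      }
    }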

http://git-wip-us.apache.org/repos/asf/spark/blob/8187b3ae/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index d33a967..4c9187a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.collection.CompactBuffer
 
 
-class HashedRelationSuite extends SparkFunSuite {
+class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
 
   // Key is simply the record itself
   private val keyProjection = new Projection {
@@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
 
   test("GeneralHashedRelation") {
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
     assert(hashed.isInstanceOf[GeneralHashedRelation])
 
@@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
 
   test("UniqueKeyHashedRelation") {
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
     assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
 
@@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
   test("UnsafeHashedRelation") {
     val schema = StructType(StructField("a", IntegerType, true) :: Nil)
     val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
-    val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+    val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
     val toUnsafe = UnsafeProjection.create(schema)
     val unsafeData = data.map(toUnsafe(_).copy()).toArray
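
HashedRelationSuite follows the same shape: the SparkContext used to create SQL metrics now comes from the shared context (ctx) instead of the TestSQLContext singleton. A minimal sketch (hypothetical suite and test names):

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.execution.metric.SQLMetrics
    import org.apache.spark.sql.test.SharedSQLContext

    // Sketch only: the shared context supplies the SparkContext for metrics.
    class ExampleMetricSuite extends SparkFunSuite with SharedSQLContext {
      test("creates a long metric against the shared SparkContext") {
        val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
        assert(numDataRows != null)
      }
    }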
 

