[SPARK-9580] [SQL] Replace singletons in SQL tests

A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure.
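To make the problem and the fix concrete: under the old scheme every suite grabbed the singleton (`private lazy val ctx = org.apache.spark.sql.test.TestSQLContext` plus `import ctx.implicits._`), so nothing suite-specific could be configured. Under the per-suite scheme described below, a suite looks roughly like the following minimal sketch. The suite name and query are invented for illustration; `SharedSQLContext`, `testImplicits`, `sql(...)` and `setupTestData()` are the test helpers introduced or refactored by this patch.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical suite sketching the new per-suite pattern.
class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._   // implicits backed by this suite's own SQLContext

  setupTestData()          // registers the test tables formerly held by the TestData singleton
                           // (not needed for the query below, shown for completeness)

  test("query through the per-suite context") {
    // Build a DataFrame through the per-suite implicits rather than a global object.
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.registerTempTable("kv")

    // sql(...) and checkAnswer(...) come from the shared test helpers.
    checkAnswer(sql("SELECT key FROM kv WHERE value = 'a'"), Row(1))
  }
}
```

Because the context now belongs to the suite, a suite with special requirements (such as `BroadcastJoinSuite`) can instead construct its own `SparkContext` with whatever master or configuration it needs.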
This patch removes the singletons `TestSQLContext` and `TestData`, and instead introduces a `SharedSQLContext` that starts a context per suite. Unfortunately the singletons were so ingrained in the SQL tests that this patch necessarily needed to touch *all* the SQL test files.

Author: Andrew Or <[email protected]>

Closes #8111 from andrewor14/sql-tests-refactor.

(cherry picked from commit 8187b3ae477e2b2987ae9acc5368d57b1d5653b2)
Signed-off-by: Reynold Xin <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9df2a2d7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9df2a2d7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9df2a2d7
Branch: refs/heads/branch-1.5
Commit: 9df2a2d76083bea849c106cdd87dd5c489ac262d
Parents: b318b11
Author: Andrew Or <[email protected]>
Authored: Thu Aug 13 17:42:01 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Aug 13 17:42:11 2015 -0700

---------------------------------------------------------------------- project/MimaExcludes.scala | 10 + project/SparkBuild.scala | 16 +- .../catalyst/analysis/AnalysisErrorSuite.scala | 6 +- .../scala/org/apache/spark/sql/SQLContext.scala | 97 +------ .../org/apache/spark/sql/SQLImplicits.scala | 123 ++++++++ .../apache/spark/sql/test/TestSQLContext.scala | 56 ---- .../apache/spark/sql/JavaApplySchemaSuite.java | 10 +- .../apache/spark/sql/JavaDataFrameSuite.java | 39 +-- .../test/org/apache/spark/sql/JavaUDFSuite.java | 10 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 15 +- .../org/apache/spark/sql/CachedTableSuite.scala | 14 +- .../spark/sql/ColumnExpressionSuite.scala | 37 ++- .../spark/sql/DataFrameAggregateSuite.scala | 10 +- .../spark/sql/DataFrameFunctionsSuite.scala | 18 +- .../spark/sql/DataFrameImplicitsSuite.scala | 6 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 10 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 6 +- .../apache/spark/sql/DataFrameStatSuite.scala | 19 +- .../org/apache/spark/sql/DataFrameSuite.scala | 12 +- .../spark/sql/DataFrameTungstenSuite.scala | 8 +- .../apache/spark/sql/DateFunctionsSuite.scala | 13 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 47 ++- .../apache/spark/sql/JsonFunctionsSuite.scala | 6 +- .../org/apache/spark/sql/ListTablesSuite.scala | 15 +- .../apache/spark/sql/MathExpressionsSuite.scala | 28 +- .../scala/org/apache/spark/sql/QueryTest.scala | 6 - .../scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../org/apache/spark/sql/SQLConfSuite.scala | 19 +- .../org/apache/spark/sql/SQLContextSuite.scala | 13 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +- .../sql/ScalaReflectionRelationSuite.scala | 17 +- .../apache/spark/sql/SerializationSuite.scala | 9 +- .../apache/spark/sql/StringFunctionsSuite.scala | 7 +- .../scala/org/apache/spark/sql/TestData.scala | 197 ------------- .../scala/org/apache/spark/sql/UDFSuite.scala | 39 ++- .../apache/spark/sql/UserDefinedTypeSuite.scala | 9 +- .../columnar/InMemoryColumnarQuerySuite.scala | 14 +- .../columnar/PartitionBatchPruningSuite.scala | 31 +- .../spark/sql/execution/ExchangeSuite.scala | 3 +- .../spark/sql/execution/PlannerSuite.scala | 39 +-- .../execution/RowFormatConvertersSuite.scala | 22 +- .../apache/spark/sql/execution/SortSuite.scala | 3 +- .../spark/sql/execution/SparkPlanTest.scala | 36 ++-
.../spark/sql/execution/TungstenSortSuite.scala | 21 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 14 +- .../execution/UnsafeKVExternalSorterSuite.scala | 7 +- .../TungstenAggregationIteratorSuite.scala | 7 +- .../execution/datasources/json/JsonSuite.scala | 42 ++- .../datasources/json/TestJsonData.scala | 40 ++- .../parquet/ParquetAvroCompatibilitySuite.scala | 14 +- .../parquet/ParquetCompatibilityTest.scala | 10 +- .../parquet/ParquetFilterSuite.scala | 8 +- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../ParquetPartitionDiscoverySuite.scala | 12 +- .../ParquetProtobufCompatibilitySuite.scala | 7 +- .../datasources/parquet/ParquetQuerySuite.scala | 43 ++- .../parquet/ParquetSchemaSuite.scala | 6 +- .../datasources/parquet/ParquetTest.scala | 15 +- .../ParquetThriftCompatibilitySuite.scala | 8 +- .../sql/execution/debug/DebuggingSuite.scala | 6 +- .../execution/joins/HashedRelationSuite.scala | 10 +- .../sql/execution/joins/InnerJoinSuite.scala | 248 +++++++++------- .../sql/execution/joins/OuterJoinSuite.scala | 125 ++++---- .../sql/execution/joins/SemiJoinSuite.scala | 113 ++++---- .../sql/execution/metric/SQLMetricsSuite.scala | 62 ++-- .../sql/execution/ui/SQLListenerSuite.scala | 14 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +- .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 20 +- .../sql/sources/CreateTableAsSelectSuite.scala | 18 +- .../spark/sql/sources/DDLSourceLoadSuite.scala | 3 +- .../apache/spark/sql/sources/DDLTestSuite.scala | 11 +- .../spark/sql/sources/DataSourceTest.scala | 17 +- .../spark/sql/sources/FilteredScanSuite.scala | 9 +- .../apache/spark/sql/sources/InsertSuite.scala | 30 +- .../sql/sources/PartitionedWriteSuite.scala | 14 +- .../spark/sql/sources/PrunedScanSuite.scala | 11 +- .../spark/sql/sources/SaveLoadSuite.scala | 26 +- .../spark/sql/sources/TableScanSuite.scala | 15 +- .../org/apache/spark/sql/test/SQLTestData.scala | 290 +++++++++++++++++++ .../apache/spark/sql/test/SQLTestUtils.scala | 92 +++++- .../spark/sql/test/SharedSQLContext.scala | 68 +++++ .../apache/spark/sql/test/TestSQLContext.scala | 52 ++++ .../sql/hive/thriftserver/UISeleniumSuite.scala | 2 - .../sql/hive/HiveMetastoreCatalogSuite.scala | 5 +- .../spark/sql/hive/HiveParquetSuite.scala | 11 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 4 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 5 +- .../hive/ParquetHiveCompatibilitySuite.scala | 3 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 +- .../hive/execution/AggregationQuerySuite.scala | 9 +- .../sql/hive/execution/HiveExplainSuite.scala | 4 +- .../sql/hive/execution/SQLQuerySuite.scala | 3 +- .../execution/ScriptTransformationSuite.scala | 3 +- .../org/apache/spark/sql/hive/orc/OrcTest.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 3 +- .../CommitFailureTestRelationSuite.scala | 6 +- .../sql/sources/hadoopFsRelationSuites.scala | 5 +- 97 files changed, 1491 insertions(+), 1234 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 784f83c..88745dc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -179,6 +179,16 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.SparkContext.supportDynamicAllocation") ) ++ Seq( + // SPARK-9580: Remove SQL test singletons + 
ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.LocalSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.test.TestSQLContext$") + ) ++ Seq( // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.mllib.linalg.VectorUDT.serialize") http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/project/SparkBuild.scala ---------------------------------------------------------------------- diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 74f815f..04e0d49 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -319,6 +319,8 @@ object SQL { lazy val settings = Seq( initialCommands in console := """ + |import org.apache.spark.SparkContext + |import org.apache.spark.sql.SQLContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ @@ -328,9 +330,14 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.types._""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + |import org.apache.spark.sql.types._ + | + |val sc = new SparkContext("local[*]", "dev-shell") + |val sqlContext = new SQLContext(sc) + |import sqlContext.implicits._ + |import sqlContext._ + """.stripMargin, + cleanupCommands in console := "sc.stop()" ) } @@ -340,8 +347,6 @@ object Hive { javaOptions += "-XX:MaxPermSize=256m", // Specially disable assertions since some Hive tests fail them javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"), - // Multiple queries rely on the TestHive singleton. See comments there for more details. - parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. 
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => @@ -349,6 +354,7 @@ object Hive { }, initialCommands in console := """ + |import org.apache.spark.SparkContext |import org.apache.spark.sql.catalyst.analysis._ |import org.apache.spark.sql.catalyst.dsl._ |import org.apache.spark.sql.catalyst.errors._ http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 63b475b..f60d11c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types._ @@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode { override def output: Seq[Attribute] = Nil } -class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter { +class AnalysisErrorSuite extends AnalysisTest { import TestRelations._ def errorTest( http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bf00b3..53de10d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable -import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -334,97 +332,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ @Experimental - object implicits extends Serializable { - // scalastyle:on - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. 
- * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** - * Creates a DataFrame from an RDD of case classes or tuples. - * @since 1.3.0 - */ - implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { - DataFrameHolder(self.createDataFrame(rdd)) - } - - /** - * Creates a DataFrame from a local Seq of Product. - * @since 1.3.0 - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = - { - DataFrameHolder(self.createDataFrame(data)) - } - - // Do NOT add more implicit conversions. They are likely to break source compatibility by - // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous - // because of [[DoubleRDDFunctions]]. - - /** - * Creates a single column DataFrame from an RDD[Int]. - * @since 1.3.0 - */ - implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { - val dataType = IntegerType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setInt(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[Long]. - * @since 1.3.0 - */ - implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { - val dataType = LongType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.setLong(0, v) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } - - /** - * Creates a single column DataFrame from an RDD[String]. - * @since 1.3.0 - */ - implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { - val dataType = StringType - val rows = data.mapPartitions { iter => - val row = new SpecificMutableRow(dataType :: Nil) - iter.map { v => - row.update(0, UTF8String.fromString(v)) - row: InternalRow - } - } - DataFrameHolder( - self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) - } + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = self } + // scalastyle:on /** * :: Experimental :: http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala new file mode 100644 index 0000000..5f82372 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types.StructField +import org.apache.spark.unsafe.types.UTF8String + +/** + * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + */ +private[sql] abstract class SQLImplicits { + protected def _sqlContext: SQLContext + + /** + * Converts $"col name" into an [[Column]]. + * @since 1.3.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + /** + * Creates a DataFrame from an RDD of case classes or tuples. + * @since 1.3.0 + */ + implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { + DataFrameHolder(_sqlContext.createDataFrame(rdd)) + } + + /** + * Creates a DataFrame from a local Seq of Product. + * @since 1.3.0 + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = + { + DataFrameHolder(_sqlContext.createDataFrame(data)) + } + + // Do NOT add more implicit conversions. They are likely to break source compatibility by + // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous + // because of [[DoubleRDDFunctions]]. + + /** + * Creates a single column DataFrame from an RDD[Int]. + * @since 1.3.0 + */ + implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = { + val dataType = IntegerType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setInt(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[Long]. + * @since 1.3.0 + */ + implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = { + val dataType = LongType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.setLong(0, v) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } + + /** + * Creates a single column DataFrame from an RDD[String]. 
+ * @since 1.3.0 + */ + implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = { + val dataType = StringType + val rows = data.mapPartitions { iter => + val row = new SpecificMutableRow(dataType :: Nil) + iter.map { v => + row.update(0, UTF8String.fromString(v)) + row: InternalRow + } + } + DataFrameHolder( + _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala deleted file mode 100644 index b3a4231..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.test - -import scala.language.implicitConversions - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - -/** A SQLContext that can be used for local testing. */ -class LocalSQLContext - extends SQLContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf() - .set("spark.sql.testkey", "true") - // SPARK-8910 - .set("spark.ui.enabled", "false"))) { - - override protected[sql] def createSession(): SQLSession = { - new this.SQLSession() - } - - protected[sql] class SQLSession extends super.SQLSession { - protected[sql] override lazy val conf: SQLConf = new SQLConf { - /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) - } - } - - /** - * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to - * construct [[DataFrame]] directly out of local data without relying on implicits. 
- */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(this, plan) - } - -} - -object TestSQLContext extends LocalSQLContext - http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index e912eb8..bf693c7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -27,6 +27,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; @@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext context = new SparkContext("local[*]", "testing"); + javaCtx = new JavaSparkContext(context); + sqlContext = new SQLContext(context); } @After public void tearDown() { - javaCtx = null; + sqlContext.sparkContext().stop(); sqlContext = null; + javaCtx = null; } public static class Person implements Serializable { http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7302361..7abdd3d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,44 +17,45 @@ package test.org.apache.spark.sql; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; + +import scala.collection.JavaConversions; +import scala.collection.Seq; + import com.google.common.collect.ImmutableMap; import com.google.common.primitives.Ints; +import org.junit.*; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; +import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; -import org.junit.*; - -import scala.collection.JavaConversions; -import scala.collection.Seq; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.Comparator; -import java.util.List; -import java.util.Map; - -import static org.apache.spark.sql.functions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; - private 
transient SQLContext context; + private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - TestData$.MODULE$.testData(); - jsc = new JavaSparkContext(TestSQLContext.sparkContext()); - context = TestSQLContext$.MODULE$; + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); } @After public void tearDown() { - jsc = null; + context.sparkContext().stop(); context = null; + jsc = null; } @Test @@ -230,7 +231,7 @@ public class JavaDataFrameSuite { @Test public void testSampleBy() { - DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 79d9273..bb02b58 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -23,12 +23,12 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable { @Before public void setUp() { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); } @After public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; } @SuppressWarnings("unchecked") http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 2706e01..6f9e7f6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -21,13 +21,14 @@ import java.io.File; import java.io.IOException; import java.util.*; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.*; import 
org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -52,8 +53,9 @@ public class JavaSaveLoadSuite { @Before public void setUp() throws IOException { - sqlContext = TestSQLContext$.MODULE$; - sc = new JavaSparkContext(sqlContext.sparkContext()); + SparkContext _sc = new SparkContext("local[*]", "testing"); + sqlContext = new SQLContext(_sc); + sc = new JavaSparkContext(_sc); originalDefaultSource = sqlContext.conf().defaultDataSourceName(); path = @@ -71,6 +73,13 @@ public class JavaSaveLoadSuite { df.registerTempTable("jsonTable"); } + @After + public void tearDown() { + sqlContext.sparkContext().stop(); + sqlContext = null; + sc = null; + } + @Test public void saveAndLoad() { Map<String, String> options = new HashMap<String, String>(); http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a88df91..af7590c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -18,24 +18,20 @@ package org.apache.spark.sql import scala.concurrent.duration._ -import scala.language.{implicitConversions, postfixOps} +import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.storage.{StorageLevel, RDDBlockId} -case class BigData(s: String) +private case class BigData(s: String) -class CachedTableSuite extends QueryTest { - TestData // Load test tables. 
- - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.sql +class CachedTableSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def rddIdOf(tableName: String): Int = { val executedPlan = ctx.table(tableName).queryExecution.executedPlan http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6a09a3b..ee74e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,16 +21,20 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ +class ColumnExpressionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - override def sqlContext(): SQLContext = ctx + private lazy val booleanData = { + ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + } test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) checkAnswer( - ctx.sql("select isnull(null), isnull(1)"), + sql("select isnull(null), isnull(1)"), Row(true, false)) } @@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) checkAnswer( - ctx.sql("select isnotnull(null), isnotnull('a')"), + sql("select isnotnull(null), isnotnull('a')"), Row(false, true)) } @@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - ctx.sql("select isnan(15), isnan('invalid')"), + sql("select isnan(15), isnan('invalid')"), Row(false, false)) } @@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) testData.registerTempTable("t") checkAnswer( - ctx.sql( + sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + " nanvl(b, e), nanvl(e, f) from t"), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) @@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { } } - val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( - Row(false, false) :: - Row(false, true) :: - Row(true, false) :: - Row(true, true) :: Nil), - StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) - test("&&") { checkAnswer( booleanData.filter($"a" && true), @@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - 
ctx.sql("SELECT upper('aB'), ucase('cDe')"), + sql("SELECT upper('aB'), ucase('cDe')"), Row("AB", "CDE")) } @@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { ) checkAnswer( - ctx.sql("SELECT lower('aB'), lcase('cDe')"), + sql("SELECT lower('aB'), lcase('cDe')"), Row("ab", "cde")) } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f9cff74..72cf7aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{BinaryType, DecimalType} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.DecimalType -class DataFrameAggregateSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("groupBy") { checkAnswer( http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 03116a3..9d96525 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** * Test suite for functions in [[org.apache.spark.sql.functions]]. 
*/ -class DataFrameFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") @@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest { test("constant functions") { checkAnswer( - ctx.sql("SELECT E()"), + sql("SELECT E()"), Row(scala.math.E) ) checkAnswer( - ctx.sql("SELECT PI()"), + sql("SELECT PI()"), Row(scala.math.Pi) ) } @@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("nvl function") { checkAnswer( - ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), Row("x", "y", null)) } @@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(-1) ) checkAnswer( - ctx.sql("SELECT least(a, 2) as l from testData2 order by l"), + sql("SELECT least(a, 2) as l from testData2 order by l"), Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2)) ) } @@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest { Row(3) ) checkAnswer( - ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"), + sql("SELECT greatest(a, 2) as g from testData2 order by g"), Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3)) ) } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index fbb3070..e5d7d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class DataFrameImplicitsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("RDD of tuples") { checkAnswer( http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e1c6c70..e2716d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameJoinSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameJoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") @@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - 
ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") .collect().toSeq) } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index dbe3b44..cdaa14a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameNaFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 8f5984e..28bdd6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,20 +19,17 @@ package org.apache.spark.sql import java.util.Random -import org.scalatest.Matchers._ - import org.apache.spark.sql.functions.col +import org.apache.spark.sql.test.SharedSQLContext -class DataFrameStatSuite extends QueryTest { - - private val sqlCtx = org.apache.spark.sql.test.TestSQLContext - import sqlCtx.implicits._ +class DataFrameStatSuite extends QueryTest with SharedSQLContext { + import testImplicits._ private def toLetter(i: Int): String = (i + 97).toChar.toString test("sample with replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest { test("sample without replacement") { val n = 100 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest { } test("Frequent Items 2") { - val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // 
counts than those that existed in the map when the map was full. This test should also fail // if anything like SPARK-9614 is observed once again @@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest { } test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val df = ctx.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 2feec29..10bfa9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,18 +23,12 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ -import org.apache.spark.sql.execution.datasources.json.JSONRelation -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext} -class DataFrameSuite extends QueryTest with SQLTestUtils { - import org.apache.spark.sql.TestData._ - - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("analysis error should be eagerly reported") { // Eager analysis. http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index bf8ef9a..77907e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** @@ -27,10 +27,8 @@ import org.apache.spark.sql.types._ * This is here for now so I can make sure Tungsten project is tested without refactoring existing * end-to-end test infra. In the long run this should just go away. 
*/ -class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { - - override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ +class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 17897ca..9080c53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,19 +22,18 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.unsafe.types.CalendarInterval -class DateFunctionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ +class DateFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function current_date") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } @@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( 0).getTime - System.currentTimeMillis()) < 5000) } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ae07eaf..f5c5046 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext -class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { - // Ensures tables are loaded. 
- TestData +class JoinSuite extends QueryTest with SharedSQLContext { + import testImplicits._ - override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - import ctx.logicalPlanToSparkQuery + setupTestData() test("equi-join is hash-join") { val x = testData2.as("x") @@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = ctx.sql(sqlString) + val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") { @@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("broadcasted hash outer join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", @@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { classOf[BroadcastHashOuterJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { @@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. 
checkAnswer( - ctx.sql( + sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - ctx.sql( + sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - ctx.sql( + sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { test("broadcasted left semi join operator selection") { ctx.cacheManager.clearCache() - ctx.sql("CACHE TABLE testData") + sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( @@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach { } } - ctx.sql("UNCACHE TABLE testData") + sql("UNCACHE TABLE testData") } test("left semi join") { - val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") checkAnswer(df, Row(1, 1) :: Row(1, 2) :: http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 71c26a6..045fea8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -class JsonFunctionsSuite extends QueryTest { +import org.apache.spark.sql.test.SharedSQLContext - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class JsonFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("function get_json_object") { val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b") http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 2089660..babf883 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -class ListTablesSuite extends QueryTest with BeforeAndAfter { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { + import testImplicits._ 
private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") @@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { Row("ListTablesSuiteTable", true)) checkAnswer( - ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) @@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - ctx.sql( + sql( "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 8cf2ef5..30289c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions.{log => logarithm} +import org.apache.spark.sql.test.SharedSQLContext private object MathExpressionsTestData { case class DoubleData(a: java.lang.Double, b: java.lang.Double) case class NullDoubles(a: java.lang.Double) } -class MathExpressionsSuite extends QueryTest { - +class MathExpressionsSuite extends QueryTest with SharedSQLContext { import MathExpressionsTestData._ - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ + import testImplicits._ private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() @@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest { test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) checkAnswer( - ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + sql("SELECT degrees(0), degrees(1), degrees(1.5)"), Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) ) } @@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest { test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) checkAnswer( - ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + sql("SELECT radians(0), radians(1), radians(1.5)"), Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) ) } @@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest { test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) checkAnswer( - ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + sql("SELECT ceiling(0), ceiling(1), 
ceiling(1.5)"), Row(0.0, 1.0, 2.0)) } @@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest { val pi = 3.1415 checkAnswer( - ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + + sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) @@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction[Double](signum, math.signum) checkAnswer( - ctx.sql("SELECT sign(10), signum(-11)"), + sql("SELECT sign(10), signum(-11)"), Row(1, -1)) } @@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(pow, pow, math.pow) checkAnswer( - ctx.sql("SELECT pow(1, 2), power(2, 1)"), + sql("SELECT pow(1, 2), power(2, 1)"), Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) ) } @@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest { test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) checkAnswer( - ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + sql("SELECT ln(0), ln(1), ln(1.5)"), Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) ) } @@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest { df.select(log2("b") + log2("a")), Row(1)) - checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } test("sqrt") { @@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest { df.select(sqrt("a"), sqrt("b")), Row(1.0, 2.0)) - checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) } test("negative") { checkAnswer( - ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + sql("SELECT negative(1), negative(0), negative(-1)"), Row(-1, 0, 1)) } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c9..4adcefb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -71,12 +71,6 @@ class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { - test(sqlString) { - checkAnswer(sqlContext.sql(sqlString), expectedAnswer) - } - } - /** * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 8a679c7..795d4e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class RowSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class RowSuite extends SparkFunSuite with SharedSQLContext { + import testImplicits._ test("create row") { val expected = new GenericMutableRow(4) http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 75791e9..7699ada 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SharedSQLContext -class SQLConfSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" private val testVal = "test.val.0" @@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { ctx.conf.clear() - ctx.sql(s"set $testKey=$testVal") + sql(s"set $testKey=$testVal") assert(ctx.getConf(testKey, testVal + "_") === testVal) assert(ctx.getConf(testKey, testVal + "_") === testVal) - ctx.sql("set some.property=20") + sql("set some.property=20") assert(ctx.getConf("some.property", "0") === "20") - ctx.sql("set some.property = 40") + sql("set some.property = 40") assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - ctx.sql(s"set $key=$vs") + sql(s"set $key=$vs") assert(ctx.getConf(key, "0") === vs) - ctx.sql(s"set $key=") + sql(s"set $key=") assert(ctx.getConf(key, "0") === "") ctx.conf.clear() @@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() - ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") assert(ctx.conf.numShufflePartitions === 10) } test("invalid conf value") { ctx.conf.clear() val e = intercept[IllegalArgumentException] { - ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index c8d8796..007be12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext -class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(ctx) + try { + SQLContext.setLastInstantiatedContext(ctx) + } finally { + super.afterAll() + } } test("getOrCreate instantiates SQLContext") { http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b14ef9b..8c2c328 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,28 +19,23 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { - // Make sure the tables are loaded. 
- TestData +class SQLQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ - val sqlContext = org.apache.spark.sql.test.TestSQLContext - import sqlContext.implicits._ - import sqlContext.sql + setupTestData() test("having clause") { Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") @@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("show functions") { - checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + checkAnswer(sql("SHOW functions"), + FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) } test("describe functions") { @@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUDF = udf(() => UUID.randomUUID().toString) + val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString) val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) @@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer( sql( - """ - |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3 - """.stripMargin), + "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"), Row(2, 1, 2, 2, 1)) } @@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) - validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql( + "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) } test("SPARK-3371 Renaming a function expression with group by gives error") { @@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .toDF("num", "str") df.registerTempTable("1one") - checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10)) + checkAnswer(sql("select count(num) from 1one"), Row(10)) sqlContext.dropTempTable("1one") } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ab6d3dd..295f02f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSQLContext case class ReflectData( stringField: String, @@ -71,17 +72,15 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { + 
import testImplicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") - assert(ctx.sql("SELECT * FROM reflectData").collect().head === + assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = NullReflectData(null, null, null, null, null, null, null) Seq(data).toDF().registerTempTable("reflectNullData") - assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { val data = OptionalReflectData(None, None, None, None, None, None, None) Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === + assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } @@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query binary data") { Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = ctx.sql("SELECT data FROM reflectBinary") + val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Nested(None, "abc"))) Seq(data).toDF().registerTempTable("reflectComplexData") - assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === + assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), Seq(1, 2, null), http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index e55c9e4..45d0ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.sql.test.SharedSQLContext -class SerializationSuite extends SparkFunSuite { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext +class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(ctx.sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) + val _sqlContext = new SQLContext(sqlContext.sparkContext) + new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } http://git-wip-us.apache.org/repos/asf/spark/blob/9df2a2d7/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ca298b2..cc95eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.Decimal -class StringFunctionsSuite extends QueryTest { - - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ +class StringFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("string concat") { val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
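For readers tracing the refactor, the sketch below (not part of the patch) shows the shape of a suite written against the new SharedSQLContext pattern that the hunks above apply. It assumes, as the diffs indicate, that SharedSQLContext supplies the per-suite context (ctx), the sql(...) helper, testImplicits, and withSQLConf; the suite name ExampleQuerySuite and its temp table name are hypothetical.

package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class ExampleQuerySuite extends QueryTest with SharedSQLContext {
  // Replaces the old singleton-based `import ctx.implicits._`.
  import testImplicits._

  test("simple aggregation over a temp table") {
    // Data is registered against the suite's own shared context,
    // not a global TestSQLContext singleton.
    Seq((1, "a"), (2, "b")).toDF("key", "value").registerTempTable("exampleData")

    // `sql(...)` now resolves through SharedSQLContext, replacing `ctx.sql(...)`.
    checkAnswer(
      sql("SELECT count(*) FROM exampleData"),
      Row(2))
  }

  test("config changes stay scoped to the block") {
    // withSQLConf restores the previous value when the block exits,
    // so the per-suite context is not polluted across tests.
    withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
      assert(ctx.getConf(SQLConf.SORTMERGE_JOIN.key) === "false")
    }
  }
}

Because each suite now owns its context, settings changed through withSQLConf remain local to that suite rather than leaking through a shared singleton.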
