Repository: spark
Updated Branches:
  refs/heads/master d492cc5a2 -> 29b1f6b09

[SPARK-21256][SQL] Add withSQLConf to Catalyst Test

### What changes were proposed in this pull request?
SQLConf is moved to Catalyst. We are adding more and more test cases for verifying the conf-specific behaviors. It is nice to add a helper function to simplify the test cases.

### How was this patch tested?
N/A

Author: gatorsmile <[email protected]>

Closes #18469 from gatorsmile/withSQLConf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29b1f6b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29b1f6b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29b1f6b0

Branch: refs/heads/master
Commit: 29b1f6b09f98e216af71e893a9da0c4717c80679
Parents: d492cc5
Author: gatorsmile <[email protected]>
Authored: Tue Jul 4 08:54:07 2017 -0700
Committer: gatorsmile <[email protected]>
Committed: Tue Jul 4 08:54:07 2017 -0700

----------------------------------------------------------------------
 .../InferFiltersFromConstraintsSuite.scala      |  5 +--
 .../optimizer/OuterJoinEliminationSuite.scala   |  6 +---
 .../catalyst/optimizer/PruneFiltersSuite.scala  |  6 +---
 .../plans/ConstraintPropagationSuite.scala      | 24 +++++++--------
 .../spark/sql/catalyst/plans/PlanTest.scala     | 32 +++++++++++++++++++-
 .../AggregateEstimationSuite.scala              |  9 ++----
 .../BasicStatsEstimationSuite.scala             | 12 +++-----
 .../spark/sql/SparkSessionBuilderSuite.scala    |  3 ++
 .../apache/spark/sql/test/SQLTestUtils.scala    | 30 +++++-------------
 9 files changed, 64 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index cdc9f25..d2dd469 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -206,13 +206,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }
 
   test("No inferred filter when constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }
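
For illustration only (this snippet is not part of the commit): with PlanTest
now providing the helper, the try/finally boilerplate above collapses into a
single block. The test name and relation below are made up; the conf key is
the one the patched suites use, and the Catalyst DSL imports these suites
already have are assumed.

  test("no constraints when propagation is disabled (illustrative)") {
    val relation = LocalRelation('a.int, 'b.string)
    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
      // The conf reads "false" only inside this block; on exit the previous
      // value is restored, or the key is unset if it had no prior value.
      assert(relation.where('a.attr > 10).analyze.constraints.isEmpty)
    }
  }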

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
index 623ff3d..893c111 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala
@@ -234,9 +234,7 @@ class OuterJoinEliminationSuite extends PlanTest {
   }
 
   test("no outer join elimination if constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
-
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val x = testRelation.subquery('x)
       val y = testRelation1.subquery('y)
 
@@ -251,8 +249,6 @@ class OuterJoinEliminationSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery.analyze)
 
       comparePlans(optimized, originalQuery.analyze)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 706634c..6d1a05f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
 
 class PruneFiltersSuite extends PlanTest {
 
@@ -149,8 +148,7 @@ class PruneFiltersSuite extends PlanTest {
         ("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
           'd.attr < 100)
 
-    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
-    try {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val optimized = Optimize.execute(queryWithUselessFilter.analyze)
       // When constraint propagation is disabled, the useless filter won't be pruned.
       // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
@@ -160,8 +158,6 @@ class PruneFiltersSuite extends PlanTest {
         .join(tr2.where('d.attr < 100).where('d.attr < 100),
           Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
       comparePlans(optimized, correctAnswer)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index a3948d9..a37e06d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
-class ConstraintPropagationSuite extends SparkFunSuite {
+class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
 
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     resolveColumn(tr.analyze, columnName)
@@ -400,26 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   }
 
   test("enable/disable constraint propagation") {
-    try {
-      val tr = LocalRelation('a.int, 'b.string, 'c.int)
-      val filterRelation = tr.where('a.attr > 10)
+    val tr = LocalRelation('a.int, 'b.string, 'c.int)
+    val filterRelation = tr.where('a.attr > 10)
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(filterRelation.analyze.constraints.nonEmpty)
+    }
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(filterRelation.analyze.constraints.isEmpty)
+    }
 
-      val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
-        .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
+    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(aliasedRelation.analyze.constraints.nonEmpty)
+    }
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(aliasedRelation.analyze.constraints.isEmpty)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 6883d23..e9679d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -28,8 +29,9 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Provides helper methods for comparing plans.
  */
-abstract class PlanTest extends SparkFunSuite with PredicateHelper {
+trait PlanTest extends SparkFunSuite with PredicateHelper {
 
+  // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules
   protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)
 
   /**
@@ -142,4 +144,32 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
       plan1 == plan2
     }
   }
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (SQLConf.staticConfKeys.contains(k)) {
+        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
 }
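
A sketch of the semantics this implementation provides (illustrative, not from
the patch): a pre-existing value is restored even when the body throws, because
the restore runs in the finally block; static confs are rejected before
anything is mutated. SQLConf.CBO_ENABLED is a real conf key; the values and the
thrown exception below are arbitrary.

  val conf = SQLConf.get
  conf.setConfString(SQLConf.CBO_ENABLED.key, "true")    // pre-existing value
  intercept[RuntimeException] {
    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
      assert(conf.getConfString(SQLConf.CBO_ENABLED.key) == "false")
      throw new RuntimeException("boom")                 // body fails ...
    }
  }
  // ... yet the finally block still restored the original value.
  assert(conf.getConfString(SQLConf.CBO_ENABLED.key) == "true")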

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 30ddf03..23f95a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.internal.SQLConf
 
 
-class AggregateEstimationSuite extends StatsEstimationTestBase {
+class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
 
   /** Columns for testing */
   private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
@@ -100,9 +101,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       size = Some(4 * (8 + 4)),
       attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
 
-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       val noGroupAgg = Aggregate(groupingExpressions = Nil,
         aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
       assert(noGroupAgg.stats ==
@@ -114,8 +113,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       assert(hasGroupAgg.stats ==
         // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
         Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 31a8cbd..5fd21a0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
 
 
-class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
   val attribute = attr("key")
   val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
@@ -82,18 +83,15 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
       // Invalidate statistics
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
       assert(plan.stats == expectedStatsCboOn)
+    }
 
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
       assert(plan.stats == expectedStatsCboOff)
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
     }
   }
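
Because the helper takes varargs pairs, the single-conf toggles above also
generalize to setting several confs in one block, all rolled back together on
exit. A sketch reusing checkStats's parameters (the conf combination is
arbitrary, not something the patch does):

  withSQLConf(
      SQLConf.CBO_ENABLED.key -> "true",
      SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
    // Both settings are visible to any code running inside the block.
    plan.invalidateStatsCache()
    assert(plan.stats == expectedStatsCboOn)
  }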

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 4f6d5f7..cdac682 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Test cases for the builder pattern of [[SparkSession]].
@@ -67,6 +68,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
     assert(activeSession != defaultSession)
     assert(session == activeSession)
     assert(session.conf.get("spark-config2") == "a")
+    assert(session.sessionState.conf == SQLConf.get)
+    assert(SQLConf.get.getConfString("spark-config2") == "a")
     SparkSession.clearActiveSession()
 
     assert(SparkSession.builder().getOrCreate() == defaultSession)

http://git-wip-us.apache.org/repos/asf/spark/blob/29b1f6b0/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d74a7cc..92ee7d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -35,9 +35,11 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.{UninterruptibleThread, Utils}
 
 /**
@@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
 private[sql] trait SQLTestUtils
   extends SparkFunSuite with Eventually
   with BeforeAndAfterAll
-  with SQLTestData { self =>
+  with SQLTestData
+  with PlanTest { self =>
 
   protected def sparkContext = spark.sparkContext
 
@@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils
     }
   }
 
-  /**
-   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
-   * configurations.
-   *
-   * @todo Probably this method should be moved to a more general place
-   */
-  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map { key =>
-      if (spark.conf.contains(key)) {
-        Some(spark.conf.get(key))
-      } else {
-        None
-      }
-    }
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
+  protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    SparkSession.setActiveSession(spark)
+    super.withSQLConf(pairs: _*)(f)
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
