This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e804f5360144 [SPARK-46179][SQL] Pull out code into reusable functions in SQLQueryTestSuite
e804f5360144 is described below
commit e804f53601444df71c6df8bd6237cc350bfec076
Author: Andy Lam <[email protected]>
AuthorDate: Thu Dec 21 11:00:29 2023 +0800
[SPARK-46179][SQL] Pull out code into reusable functions in SQLQueryTestSuite
### What changes were proposed in this pull request?
### Why are the changes needed?
As a prelude to https://github.com/apache/spark/pull/44084, in this PR, I
refactored SQLQueryTestSuite by pulling out code into functions for reuse in
subclasses.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Simple refactor, no testing.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44405 from andylam-db/crossdbms-pre.
Authored-by: Andy Lam <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/sql/SQLQueryTestHelper.scala | 30 ++--
.../org/apache/spark/sql/SQLQueryTestSuite.scala | 190 ++++++++++++++++-----
.../thriftserver/ThriftServerQueryTestSuite.scala | 2 +-
3 files changed, 163 insertions(+), 59 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index d8956961440d..c08569150e2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -98,20 +98,24 @@ trait SQLQueryTestHelper extends Logging {
}
}
+ /**
+ * Uses the Spark logical plan to determine whether the plan is semantically sorted. This is
+ * important to make non-sorted queries test cases more deterministic.
+ */
+ protected def isSemanticallySorted(plan: LogicalPlan): Boolean = plan match {
+ case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
+ case _: DescribeCommandBase
+ | _: DescribeColumnCommand
+ | _: DescribeRelation
+ | _: DescribeColumn => true
+ case PhysicalOperation(_, _, Sort(_, true, _)) => true
+ case _ => plan.children.iterator.exists(isSemanticallySorted)
+ }
+
/** Executes a query and returns the result as (schema of the output, normalized output). */
protected def getNormalizedQueryExecutionResult(
session: SparkSession, sql: String): (String, Seq[String]) = {
// Returns true if the plan is supposed to be sorted.
- def isSorted(plan: LogicalPlan): Boolean = plan match {
- case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
- case _: DescribeCommandBase
- | _: DescribeColumnCommand
- | _: DescribeRelation
- | _: DescribeColumn => true
- case PhysicalOperation(_, _, Sort(_, true, _)) => true
- case _ => plan.children.iterator.exists(isSorted)
- }
-
val df = session.sql(sql)
val schema = df.schema.catalogString
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
@@ -120,7 +124,11 @@ trait SQLQueryTestHelper extends Logging {
}
// If the output is not pre-sorted, sort it.
- if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
+ if (isSemanticallySorted(df.queryExecution.analyzed)) {
+ (schema, answer)
+ } else {
+ (schema, answer.sorted)
+ }
}
/**
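To illustrate the helper pulled out above, here is a minimal sketch (not part of this commit; it assumes a session with a table `t` that has a column `a`) of how isSemanticallySorted is meant to be used by getNormalizedQueryExecutionResult:

// Sketch: isSemanticallySorted reports whether the analyzed plan guarantees row order,
// which tells getNormalizedQueryExecutionResult whether it must sort the answer itself.
val sortedPlan = spark.sql("SELECT * FROM t ORDER BY a").queryExecution.analyzed
val unsortedPlan = spark.sql("SELECT * FROM t").queryExecution.analyzed
assert(isSemanticallySorted(sortedPlan))    // global Sort on top => order is guaranteed
assert(!isSemanticallySorted(unsortedPlan)) // no ordering guarantee => the answer gets sorted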
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 032964766792..9a78b7f52b74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -349,20 +349,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}
}
- /** Run a test case. */
- protected def runSqlTestCase(testCase: TestCase, listTestCases: Seq[TestCase]): Unit = {
- def splitWithSemicolon(seq: Seq[String]) = {
- seq.mkString("\n").split("(?<=[^\\\\]);")
- }
-
- def splitCommentsAndCodes(input: String) = input.split("\n").partition { line =>
+ protected def splitCommentsAndCodes(input: String) =
+ input.split("\n").partition { line =>
val newLine = line.trim
newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER")
}
- val input = fileToString(new File(testCase.inputFile))
-
- val (comments, code) = splitCommentsAndCodes(input)
+ protected def getQueries(code: Array[String], comments: Array[String]) = {
+ def splitWithSemicolon(seq: Seq[String]) = {
+ seq.mkString("\n").split("(?<=[^\\\\]);")
+ }
// If `--IMPORT` found, load code from another test case file, then insert them
// into the head in this test.
@@ -406,52 +402,74 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}
// List of SQL queries to run
- val queries = tempQueries.map(_.trim).filter(_ != "").toSeq
+ tempQueries.map(_.trim).filter(_ != "")
// Fix misplacement when comment is at the end of the query.
.map(_.split("\n").filterNot(_.startsWith("--")).mkString("\n")).map(_.trim).filter(_ != "")
+ }
+ protected def getSparkSettings(comments: Array[String]): Array[(String, String)] = {
val settingLines = comments.filter(_.startsWith("--SET ")).map(_.substring(6))
- val settings = settingLines.flatMap(_.split(",").map { kv =>
+ settingLines.flatMap(_.split(",").map { kv =>
val (conf, value) = kv.span(_ != '=')
conf.trim -> value.substring(1).trim
})
+ }
- if (regenerateGoldenFiles) {
- runQueries(queries, testCase, settings.toImmutableArraySeq)
- } else {
- // A config dimension has multiple config sets, and a config set has multiple configs.
- // - config dim: Seq[Seq[(String, String)]]
- // - config set: Seq[(String, String)]
- // - config: (String, String))
- // We need to do cartesian product for all the config dimensions, to get a list of
- // config sets, and run the query once for each config set.
- val configDimLines = comments.filter(_.startsWith("--CONFIG_DIM")).map(_.substring(12))
- val configDims = configDimLines.groupBy(_.takeWhile(_ != ' ')).transform { (_, lines) =>
- lines.map(_.dropWhile(_ != ' ').substring(1)).map(_.split(",").map { kv =>
- val (conf, value) = kv.span(_ != '=')
- conf.trim -> value.substring(1).trim
- }.toSeq).toSeq
- }
+ protected def getSparkConfigDimensions(comments: Array[String]): Seq[Seq[(String, String)]] = {
+ // A config dimension has multiple config sets, and a config set has multiple configs.
+ // - config dim: Seq[Seq[(String, String)]]
+ // - config set: Seq[(String, String)]
+ // - config: (String, String))
+ // We need to do cartesian product for all the config dimensions, to get a list of
+ // config sets, and run the query once for each config set.
+ val configDimLines = comments.filter(_.startsWith("--CONFIG_DIM")).map(_.substring(12))
+ val configDims = configDimLines.groupBy(_.takeWhile(_ != ' ')).view.mapValues { lines =>
+ lines.map(_.dropWhile(_ != ' ').substring(1)).map(_.split(",").map { kv =>
+ val (conf, value) = kv.span(_ != '=')
+ conf.trim -> value.substring(1).trim
+ }.toSeq).toSeq
+ }
- val configSets = configDims.values.foldLeft(Seq(Seq[(String, String)]())) { (res, dim) =>
- dim.flatMap { configSet => res.map(_ ++ configSet) }
- }
+ configDims.values.foldLeft(Seq(Seq[(String, String)]())) { (res, dim) =>
+ dim.flatMap { configSet => res.map(_ ++ configSet) }
+ }
+ }
- configSets.foreach { configSet =>
- try {
- runQueries(queries, testCase, (settings ++ configSet).toImmutableArraySeq)
- } catch {
- case e: Throwable =>
- val configs = configSet.map {
- case (k, v) => s"$k=$v"
- }
- logError(s"Error using configs: ${configs.mkString(",")}")
- throw e
- }
+ protected def runQueriesWithSparkConfigDimensions(
+ queries: Seq[String],
+ testCase: TestCase,
+ sparkConfigSet: Array[(String, String)],
+ sparkConfigDims: Seq[Seq[(String, String)]]): Unit = {
+ sparkConfigDims.foreach { configDim =>
+ try {
+ runQueries(queries, testCase, (sparkConfigSet ++ configDim).toImmutableArraySeq)
+ } catch {
+ case e: Throwable =>
+ val configs = configDim.map {
+ case (k, v) => s"$k=$v"
+ }
+ logError(s"Error using configs: ${configs.mkString(",")}")
+ throw e
}
}
}
+ /** Run a test case. */
+ protected def runSqlTestCase(testCase: TestCase, listTestCases: Seq[TestCase]): Unit = {
+ val input = fileToString(new File(testCase.inputFile))
+ val (comments, code) = splitCommentsAndCodes(input)
+ val queries = getQueries(code, comments)
+ val settings = getSparkSettings(comments)
+
+ if (regenerateGoldenFiles) {
+ runQueries(queries, testCase, settings.toImmutableArraySeq)
+ } else {
+ val configSets = getSparkConfigDimensions(comments)
+ runQueriesWithSparkConfigDimensions(
+ queries, testCase, settings, configSets)
+ }
+ }
+
def hasNoDuplicateColumns(schema: String): Boolean = {
val columnAndTypes = schema.replaceFirst("^struct<", "").stripSuffix(">").split(",")
columnAndTypes.size == columnAndTypes.distinct.length
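As a side note on the query splitting that getQueries preserves above: a small sketch (the SQL literals are made up) of how the "(?<=[^\\\\]);" lookbehind in splitWithSemicolon splits only on semicolons that are not preceded by a backslash:

// Sketch: the same lookbehind regex as splitWithSemicolon above.
val joined = Seq("SELECT 'a\\;b' AS c;", "SELECT 1;").mkString("\n")
val statements = joined.split("(?<=[^\\\\]);").map(_.trim).filter(_.nonEmpty)
// statements == Array("SELECT 'a\;b' AS c", "SELECT 1"): the escaped semicolon survives the split.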
@@ -498,7 +516,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
protected def runQueries(
queries: Seq[String],
testCase: TestCase,
- configSet: Seq[(String, String)]): Unit = {
+ sparkConfigSet: Seq[(String, String)]): Unit = {
// Create a local SparkSession to have stronger isolation between different test cases.
// This does not isolate catalog changes.
val localSparkSession = spark.newSession()
@@ -529,9 +547,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
}
- if (configSet.nonEmpty) {
+ if (sparkConfigSet.nonEmpty) {
// Execute the list of set operation in order to add the desired configs
- val setOperations = configSet.map { case (key, value) => s"set $key=$value" }
+ val setOperations = sparkConfigSet.map { case (key, value) => s"set $key=$value" }
logInfo(s"Setting configs: ${setOperations.mkString(", ")}")
setOperations.foreach(localSparkSession.sql)
}
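For reference, a hedged sketch (the comment lines and config values are illustrative, not taken from this commit) of the inputs that produce the sparkConfigSet applied above: getSparkSettings parses the --SET lines, getSparkConfigDimensions builds the cartesian product of the --CONFIG_DIM lines, and runQueriesWithSparkConfigDimensions runs the queries once per resulting config set:

// Illustrative test-file header comments.
val comments = Array(
  "--SET spark.sql.ansi.enabled=true",
  "--CONFIG_DIM1 spark.sql.codegen.wholeStage=true",
  "--CONFIG_DIM1 spark.sql.codegen.wholeStage=false",
  "--CONFIG_DIM2 spark.sql.codegen.factoryMode=CODEGEN_ONLY",
  "--CONFIG_DIM2 spark.sql.codegen.factoryMode=NO_CODEGEN")
// getSparkSettings(comments) yields Array(("spark.sql.ansi.enabled", "true")).
// getSparkConfigDimensions(comments) yields the cartesian product of the two dimensions:
// four config sets such as Seq(("spark.sql.codegen.wholeStage", "true"),
// ("spark.sql.codegen.factoryMode", "CODEGEN_ONLY")), each run once on top of the --SET settings.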
@@ -612,9 +630,18 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}
}
+ /**
+ * Returns the desired file path for results, given the input file. This is implemented as a
+ * function because different Suites extending this class may want their results files with
+ * different names or in different locations.
+ */
+ protected def resultFileForInputFile(file: File): String = {
+ file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+ }
+
protected lazy val listTestCases: Seq[TestCase] = {
listFilesRecursively(new File(inputFilePath)).flatMap { file =>
- var resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+ var resultFile = resultFileForInputFile(file)
var analyzerResultFile =
file.getAbsolutePath.replace(inputFilePath, analyzerGoldenFilePath) + ".out"
// JDK-4511638 changes 'toString' result of Float/Double
@@ -625,7 +652,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}
val absPath = file.getAbsolutePath
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
- val analyzerTestCaseName = s"${testCaseName}_analyzer_test"
// Create test cases of test types that depend on the input filename.
val newTestCases: Seq[TestCase] = if (file.getAbsolutePath.startsWith(
@@ -934,4 +960,74 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}
override def numSegments: Int = 2
}
+
+ test("test splitCommentsAndCodes") {
+ {
+ // Correctly split comments and codes
+ val input =
+ """-- Comment 1
+ |SELECT * FROM table1;
+ |-- Comment 2
+ |SELECT * FROM table2;
+ |""".stripMargin
+
+ val (comments, codes) = splitCommentsAndCodes(input)
+ assert(comments.toSet == Set("-- Comment 1", "-- Comment 2"))
+ assert(codes.toSet == Set("SELECT * FROM table1;", "SELECT * FROM table2;"))
+ }
+
+ {
+ // Handle input with no comments
+ val input = "SELECT * FROM table;"
+ val (comments, codes) = splitCommentsAndCodes(input)
+ assert(comments.isEmpty)
+ assert(codes.toSet == Set("SELECT * FROM table;"))
+ }
+
+ {
+ // Handle input with no codes
+ val input =
+ """-- Comment 1
+ |-- Comment 2
+ |""".stripMargin
+
+ val (comments, codes) = splitCommentsAndCodes(input)
+ assert(comments.toSet == Set("-- Comment 1", "-- Comment 2"))
+ assert(codes.isEmpty)
+ }
+ }
+
+ test("Test logic for determining whether a query is semantically sorted") {
+ withTable("t1", "t2") {
+ spark.sql("CREATE TABLE t1(a int, b int) USING parquet")
+ spark.sql("CREATE TABLE t2(a int, b int) USING parquet")
+
+ val unsortedSelectQuery = "select * from t1"
+ val sortedSelectQuery = "select * from t1 order by a, b"
+
+ val unsortedJoinQuery = "select * from t1 join t2 on t1.a = t2.a"
+ val sortedJoinQuery = "select * from t1 join t2 on t1.a = t2.a order by t1.a"
+
+ val unsortedAggQuery = "select a, max(b) from t1 group by a"
+ val sortedAggQuery = "select a, max(b) from t1 group by a order by a"
+
+ val unsortedDistinctQuery = "select distinct a from t1"
+ val sortedDistinctQuery = "select distinct a from t1 order by a"
+
+ val unsortedWindowQuery = "SELECT a, b, SUM(b) OVER (ORDER BY a) AS cumulative_sum FROM t1;"
+ val sortedWindowQuery = "SELECT a, b, SUM(b) OVER (ORDER BY a) AS cumulative_sum FROM " +
+ "t1 ORDER BY a, b;"
+
+ assert(!isSemanticallySorted(spark.sql(unsortedSelectQuery).logicalPlan))
+ assert(!isSemanticallySorted(spark.sql(unsortedJoinQuery).logicalPlan))
+
+ assert(!isSemanticallySorted(spark.sql(unsortedDistinctQuery).logicalPlan))
+ assert(!isSemanticallySorted(spark.sql(unsortedWindowQuery).logicalPlan))
+ assert(!isSemanticallySorted(spark.sql(unsortedAggQuery).logicalPlan))
+ assert(isSemanticallySorted(spark.sql(sortedSelectQuery).logicalPlan))
+ assert(isSemanticallySorted(spark.sql(sortedJoinQuery).logicalPlan))
+ assert(isSemanticallySorted(spark.sql(sortedAggQuery).logicalPlan))
+ assert(isSemanticallySorted(spark.sql(sortedWindowQuery).logicalPlan))
+ assert(isSemanticallySorted(spark.sql(sortedDistinctQuery).logicalPlan))
+ }
+ }
}
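To show the kind of reuse the new resultFileForInputFile hook enables (ThriftServerQueryTestSuite below is the in-tree example), a rough sketch with a hypothetical subclass name and directory; inputFilePath and goldenFilePath come from SQLQueryTestSuite:

import java.io.File

// Hypothetical suite; only resultFileForInputFile itself is part of this commit.
class MyCustomQueryTestSuite extends SQLQueryTestSuite {
  override protected def resultFileForInputFile(file: File): String = {
    // Illustrative: keep this suite's golden files in their own subdirectory.
    file.getAbsolutePath.replace(inputFilePath, goldenFilePath + File.separator + "my-custom") + ".out"
  }
}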
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
index ee336b13b0b2..8d90d47e1bf5 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
@@ -246,7 +246,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
override lazy val listTestCases: Seq[TestCase] = {
listFilesRecursively(new File(inputFilePath)).flatMap { file =>
- var resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+ var resultFile = resultFileForInputFile(file)
// JDK-4511638 changes 'toString' result of Float/Double
// JDK-8282081 changes DataTimeFormatter 'F' symbol
if (Utils.isJavaVersionAtLeast21 && (new File(resultFile + ".java21")).exists()) {