This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e804f5360144 [SPARK-46179][SQL] Pull out code into reusable functions in SQLQueryTestSuite
e804f5360144 is described below

commit e804f53601444df71c6df8bd6237cc350bfec076
Author: Andy Lam <[email protected]>
AuthorDate: Thu Dec 21 11:00:29 2023 +0800

    [SPARK-46179][SQL] Pull out code into reusable functions in SQLQueryTestSuite
    
    ### What changes were proposed in this pull request?

    As a prelude to https://github.com/apache/spark/pull/44084, this PR refactors SQLQueryTestSuite by pulling code out into reusable functions: splitCommentsAndCodes, getQueries, getSparkSettings, getSparkConfigDimensions, runQueriesWithSparkConfigDimensions, and resultFileForInputFile. The sortedness check is also moved out of getNormalizedQueryExecutionResult into a reusable isSemanticallySorted function.

    ### Why are the changes needed?

    Subclasses of SQLQueryTestSuite, such as the cross-DBMS test suites proposed in the follow-up PR, can then reuse the query-parsing and execution logic instead of duplicating it.
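
    For illustration, a subclass could hook into these functions roughly as follows. This is a minimal sketch, not code from this PR: the suite name and the dialect-specific golden-file directory are hypothetical, while the helper names and signatures are the ones introduced below.

    ```scala
    import java.io.File

    import org.apache.spark.sql.catalyst.util.fileToString

    // Hypothetical subclass reusing the helpers extracted by this PR.
    class MyDialectQueryTestSuite extends SQLQueryTestSuite {
      // Hypothetical override: keep golden files in a dialect-specific directory.
      override protected def resultFileForInputFile(file: File): String = {
        file.getAbsolutePath.replace(inputFilePath, s"$goldenFilePath-dialect") + ".out"
      }

      // Reuse the extracted parsing/execution helpers instead of copying them.
      protected def runDialectTestCase(testCase: TestCase): Unit = {
        val input = fileToString(new File(testCase.inputFile))
        val (comments, code) = splitCommentsAndCodes(input)
        val queries = getQueries(code, comments)
        val settings = getSparkSettings(comments)
        runQueriesWithSparkConfigDimensions(
          queries, testCase, settings, getSparkConfigDimensions(comments))
      }
    }
    ```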
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Simple refactor. New unit tests cover splitCommentsAndCodes and isSemanticallySorted.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44405 from andylam-db/crossdbms-pre.
    
    Authored-by: Andy Lam <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../org/apache/spark/sql/SQLQueryTestHelper.scala  |  30 ++--
 .../org/apache/spark/sql/SQLQueryTestSuite.scala   | 190 ++++++++++++++++-----
 .../thriftserver/ThriftServerQueryTestSuite.scala  |   2 +-
 3 files changed, 163 insertions(+), 59 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index d8956961440d..c08569150e2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -98,20 +98,24 @@ trait SQLQueryTestHelper extends Logging {
     }
   }
 
+  /**
+   * Uses the Spark logical plan to determine whether the plan is semantically sorted. This is
+   * important to make test cases for non-sorted queries more deterministic.
+   */
+  protected def isSemanticallySorted(plan: LogicalPlan): Boolean = plan match {
+    case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
+    case _: DescribeCommandBase
+         | _: DescribeColumnCommand
+         | _: DescribeRelation
+         | _: DescribeColumn => true
+    case PhysicalOperation(_, _, Sort(_, true, _)) => true
+    case _ => plan.children.iterator.exists(isSemanticallySorted)
+  }
+
   /** Executes a query and returns the result as (schema of the output, normalized output). */
   protected def getNormalizedQueryExecutionResult(
       session: SparkSession, sql: String): (String, Seq[String]) = {
     // Returns true if the plan is supposed to be sorted.
-    def isSorted(plan: LogicalPlan): Boolean = plan match {
-      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
-      case _: DescribeCommandBase
-          | _: DescribeColumnCommand
-          | _: DescribeRelation
-          | _: DescribeColumn => true
-      case PhysicalOperation(_, _, Sort(_, true, _)) => true
-      case _ => plan.children.iterator.exists(isSorted)
-    }
-
     val df = session.sql(sql)
     val schema = df.schema.catalogString
    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
@@ -120,7 +124,11 @@ trait SQLQueryTestHelper extends Logging {
     }
 
     // If the output is not pre-sorted, sort it.
-    if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
+    if (isSemanticallySorted(df.queryExecution.analyzed)) {
+      (schema, answer)
+    } else {
+      (schema, answer.sorted)
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 032964766792..9a78b7f52b74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -349,20 +349,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     }
   }
 
-  /** Run a test case. */
-  protected def runSqlTestCase(testCase: TestCase, listTestCases: Seq[TestCase]): Unit = {
-    def splitWithSemicolon(seq: Seq[String]) = {
-      seq.mkString("\n").split("(?<=[^\\\\]);")
-    }
-
-    def splitCommentsAndCodes(input: String) = input.split("\n").partition { line =>
+  protected def splitCommentsAndCodes(input: String) =
+    input.split("\n").partition { line =>
       val newLine = line.trim
       newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER")
     }
 
-    val input = fileToString(new File(testCase.inputFile))
-
-    val (comments, code) = splitCommentsAndCodes(input)
+  protected def getQueries(code: Array[String], comments: Array[String]) = {
+    def splitWithSemicolon(seq: Seq[String]) = {
+      seq.mkString("\n").split("(?<=[^\\\\]);")
+    }
 
    // If `--IMPORT` found, load code from another test case file, then insert them
     // into the head in this test.
@@ -406,52 +402,74 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     }
 
     // List of SQL queries to run
-    val queries = tempQueries.map(_.trim).filter(_ != "").toSeq
+    tempQueries.map(_.trim).filter(_ != "")
      // Fix misplacement when comment is at the end of the query.
      .map(_.split("\n").filterNot(_.startsWith("--")).mkString("\n")).map(_.trim).filter(_ != "")
+  }
 
+  protected def getSparkSettings(comments: Array[String]): Array[(String, String)] = {
    val settingLines = comments.filter(_.startsWith("--SET ")).map(_.substring(6))
-    val settings = settingLines.flatMap(_.split(",").map { kv =>
+    settingLines.flatMap(_.split(",").map { kv =>
       val (conf, value) = kv.span(_ != '=')
       conf.trim -> value.substring(1).trim
     })
+  }
 
-    if (regenerateGoldenFiles) {
-      runQueries(queries, testCase, settings.toImmutableArraySeq)
-    } else {
-      // A config dimension has multiple config sets, and a config set has multiple configs.
-      // - config dim:     Seq[Seq[(String, String)]]
-      //   - config set:   Seq[(String, String)]
-      //     - config:     (String, String))
-      // We need to do cartesian product for all the config dimensions, to get a list of
-      // config sets, and run the query once for each config set.
-      val configDimLines = comments.filter(_.startsWith("--CONFIG_DIM")).map(_.substring(12))
-      val configDims = configDimLines.groupBy(_.takeWhile(_ != ' ')).transform { (_, lines) =>
-        lines.map(_.dropWhile(_ != ' ').substring(1)).map(_.split(",").map { kv =>
-          val (conf, value) = kv.span(_ != '=')
-          conf.trim -> value.substring(1).trim
-        }.toSeq).toSeq
-      }
+  protected def getSparkConfigDimensions(comments: Array[String]): Seq[Seq[(String, String)]] = {
+    // A config dimension has multiple config sets, and a config set has multiple configs.
+    // - config dim:     Seq[Seq[(String, String)]]
+    //   - config set:   Seq[(String, String)]
+    //     - config:     (String, String)
+    // We need to do cartesian product for all the config dimensions, to get a list of
+    // config sets, and run the query once for each config set.
+    val configDimLines = comments.filter(_.startsWith("--CONFIG_DIM")).map(_.substring(12))
+    val configDims = configDimLines.groupBy(_.takeWhile(_ != ' ')).view.mapValues { lines =>
+      lines.map(_.dropWhile(_ != ' ').substring(1)).map(_.split(",").map { kv =>
+        val (conf, value) = kv.span(_ != '=')
+        conf.trim -> value.substring(1).trim
+      }.toSeq).toSeq
+    }
 
-      val configSets = configDims.values.foldLeft(Seq(Seq[(String, String)]())) { (res, dim) =>
-        dim.flatMap { configSet => res.map(_ ++ configSet) }
-      }
+    configDims.values.foldLeft(Seq(Seq[(String, String)]())) { (res, dim) =>
+      dim.flatMap { configSet => res.map(_ ++ configSet) }
+    }
+  }
 
-      configSets.foreach { configSet =>
-        try {
-          runQueries(queries, testCase, (settings ++ configSet).toImmutableArraySeq)
-        } catch {
-          case e: Throwable =>
-            val configs = configSet.map {
-              case (k, v) => s"$k=$v"
-            }
-            logError(s"Error using configs: ${configs.mkString(",")}")
-            throw e
-        }
+  protected def runQueriesWithSparkConfigDimensions(
+      queries: Seq[String],
+      testCase: TestCase,
+      sparkConfigSet: Array[(String, String)],
+      sparkConfigDims: Seq[Seq[(String, String)]]): Unit = {
+    sparkConfigDims.foreach { configDim =>
+      try {
+        runQueries(queries, testCase, (sparkConfigSet ++ configDim).toImmutableArraySeq)
+      } catch {
+        case e: Throwable =>
+          val configs = configDim.map {
+            case (k, v) => s"$k=$v"
+          }
+          logError(s"Error using configs: ${configs.mkString(",")}")
+          throw e
       }
     }
   }
 
+  /** Run a test case. */
+  protected def runSqlTestCase(testCase: TestCase, listTestCases: Seq[TestCase]): Unit = {
+    val input = fileToString(new File(testCase.inputFile))
+    val (comments, code) = splitCommentsAndCodes(input)
+    val queries = getQueries(code, comments)
+    val settings = getSparkSettings(comments)
+
+    if (regenerateGoldenFiles) {
+      runQueries(queries, testCase, settings.toImmutableArraySeq)
+    } else {
+      val configSets = getSparkConfigDimensions(comments)
+      runQueriesWithSparkConfigDimensions(
+        queries, testCase, settings, configSets)
+    }
+  }
+
   def hasNoDuplicateColumns(schema: String): Boolean = {
    val columnAndTypes = schema.replaceFirst("^struct<", "").stripSuffix(">").split(",")
     columnAndTypes.size == columnAndTypes.distinct.length
@@ -498,7 +516,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
   protected def runQueries(
       queries: Seq[String],
       testCase: TestCase,
-      configSet: Seq[(String, String)]): Unit = {
+      sparkConfigSet: Seq[(String, String)]): Unit = {
    // Create a local SparkSession to have stronger isolation between different test cases.
     // This does not isolate catalog changes.
     val localSparkSession = spark.newSession()
@@ -529,9 +547,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
         localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
     }
 
-    if (configSet.nonEmpty) {
+    if (sparkConfigSet.nonEmpty) {
       // Execute the list of set operation in order to add the desired configs
-      val setOperations = configSet.map { case (key, value) => s"set $key=$value" }
+      val setOperations = sparkConfigSet.map { case (key, value) => s"set $key=$value" }
       logInfo(s"Setting configs: ${setOperations.mkString(", ")}")
       setOperations.foreach(localSparkSession.sql)
     }
@@ -612,9 +630,18 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     }
   }
 
+  /**
+   * Returns the desired file path for results, given the input file. This is implemented as a
+   * function because different suites extending this class may want their result files with
+   * different names or in different locations.
+   */
+  protected def resultFileForInputFile(file: File): String = {
+    file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+  }
+
   protected lazy val listTestCases: Seq[TestCase] = {
     listFilesRecursively(new File(inputFilePath)).flatMap { file =>
-      var resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+      var resultFile = resultFileForInputFile(file)
       var analyzerResultFile =
        file.getAbsolutePath.replace(inputFilePath, analyzerGoldenFilePath) + ".out"
       // JDK-4511638 changes 'toString' result of Float/Double
@@ -625,7 +652,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
       }
       val absPath = file.getAbsolutePath
      val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
-      val analyzerTestCaseName = s"${testCaseName}_analyzer_test"
 
       // Create test cases of test types that depend on the input filename.
       val newTestCases: Seq[TestCase] = if (file.getAbsolutePath.startsWith(
@@ -934,4 +960,74 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
     }
     override def numSegments: Int = 2
   }
+
+  test("test splitCommentsAndCodes") {
+    {
+      // Correctly split comments and codes
+      val input =
+        """-- Comment 1
+          |SELECT * FROM table1;
+          |-- Comment 2
+          |SELECT * FROM table2;
+          |""".stripMargin
+
+      val (comments, codes) = splitCommentsAndCodes(input)
+      assert(comments.toSet == Set("-- Comment 1", "-- Comment 2"))
+      assert(codes.toSet == Set("SELECT * FROM table1;", "SELECT * FROM table2;"))
+    }
+
+    {
+      // Handle input with no comments
+      val input = "SELECT * FROM table;"
+      val (comments, codes) = splitCommentsAndCodes(input)
+      assert(comments.isEmpty)
+      assert(codes.toSet == Set("SELECT * FROM table;"))
+    }
+
+    {
+      // Handle input with no codes
+      val input =
+        """-- Comment 1
+          |-- Comment 2
+          |""".stripMargin
+
+      val (comments, codes) = splitCommentsAndCodes(input)
+      assert(comments.toSet == Set("-- Comment 1", "-- Comment 2"))
+      assert(codes.isEmpty)
+    }
+  }
+
+  test("Test logic for determining whether a query is semantically sorted") {
+    withTable("t1", "t2") {
+      spark.sql("CREATE TABLE t1(a int, b int) USING parquet")
+      spark.sql("CREATE TABLE t2(a int, b int) USING parquet")
+
+      val unsortedSelectQuery = "select * from t1"
+      val sortedSelectQuery = "select * from t1 order by a, b"
+
+      val unsortedJoinQuery = "select * from t1 join t2 on t1.a = t2.a"
+      val sortedJoinQuery = "select * from t1 join t2 on t1.a = t2.a order by t1.a"
+
+      val unsortedAggQuery = "select a, max(b) from t1 group by a"
+      val sortedAggQuery = "select a, max(b) from t1 group by a order by a"
+
+      val unsortedDistinctQuery = "select distinct a from t1"
+      val sortedDistinctQuery = "select distinct a from t1 order by a"
+
+      val unsortedWindowQuery = "SELECT a, b, SUM(b) OVER (ORDER BY a) AS cumulative_sum FROM t1;"
+      val sortedWindowQuery = "SELECT a, b, SUM(b) OVER (ORDER BY a) AS cumulative_sum FROM " +
+        "t1 ORDER BY a, b;"
+
+      assert(!isSemanticallySorted(spark.sql(unsortedSelectQuery).logicalPlan))
+      assert(!isSemanticallySorted(spark.sql(unsortedJoinQuery).logicalPlan))
+      assert(!isSemanticallySorted(spark.sql(unsortedDistinctQuery).logicalPlan))
+      assert(!isSemanticallySorted(spark.sql(unsortedWindowQuery).logicalPlan))
+      assert(!isSemanticallySorted(spark.sql(unsortedAggQuery).logicalPlan))
+      assert(isSemanticallySorted(spark.sql(sortedSelectQuery).logicalPlan))
+      assert(isSemanticallySorted(spark.sql(sortedJoinQuery).logicalPlan))
+      assert(isSemanticallySorted(spark.sql(sortedAggQuery).logicalPlan))
+      assert(isSemanticallySorted(spark.sql(sortedWindowQuery).logicalPlan))
+      assert(isSemanticallySorted(spark.sql(sortedDistinctQuery).logicalPlan))
+    }
+  }
 }
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
index ee336b13b0b2..8d90d47e1bf5 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala
@@ -246,7 +246,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
 
   override lazy val listTestCases: Seq[TestCase] = {
     listFilesRecursively(new File(inputFilePath)).flatMap { file =>
-      var resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out"
+      var resultFile = resultFileForInputFile(file)
       // JDK-4511638 changes 'toString' result of Float/Double
      // JDK-8282081 changes DateTimeFormatter 'F' symbol
      if (Utils.isJavaVersionAtLeast21 && (new File(resultFile + ".java21")).exists()) {

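For readers skimming the diff: the cartesian-product fold that getSparkConfigDimensions encapsulates can be exercised in isolation. A standalone sketch follows, with made-up dimension values rather than anything from a real test file:

```scala
// Two dimensions, as if parsed from "--CONFIG_DIM1 ..." and "--CONFIG_DIM2 ..."
// comments: DIM1 contributes two config sets, DIM2 contributes one.
val dims: Seq[Seq[Seq[(String, String)]]] = Seq(
  Seq(Seq("spark.sql.codegen.wholeStage" -> "true"),
      Seq("spark.sql.codegen.wholeStage" -> "false")),
  Seq(Seq("spark.sql.ansi.enabled" -> "true")))

// The same fold used in getSparkConfigDimensions: each dimension's config sets
// are combined with every partial config set accumulated so far.
val configSets = dims.foldLeft(Seq(Seq.empty[(String, String)])) { (res, dim) =>
  dim.flatMap { configSet => res.map(_ ++ configSet) }
}

// configSets now holds two config sets, and runQueriesWithSparkConfigDimensions
// would run the queries once per set:
//   Seq(("spark.sql.codegen.wholeStage", "true"),  ("spark.sql.ansi.enabled", "true"))
//   Seq(("spark.sql.codegen.wholeStage", "false"), ("spark.sql.ansi.enabled", "true"))
```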
