This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 24a1dcd4c9 [GLUTEN-8738] Update GlutenSQLQueryTestSuite to match with
the original file (#8837)
24a1dcd4c9 is described below
commit 24a1dcd4c954a87edc05d8b4678687ca3c39b89e
Author: Rong Ma <[email protected]>
AuthorDate: Sun Mar 2 08:31:26 2025 +0000
[GLUTEN-8738] Update GlutenSQLQueryTestSuite to match with the original
file (#8837)
---
.../GlutenClickHouseTPCHBucketSuite.scala | 4 +-
.../execution/GlutenClickHouseTPCHSuite.scala | 14 +-
.../execution/GlutenFunctionValidateSuite.scala | 4 +-
.../gluten/execution/VeloxMetricsSuite.scala | 4 +-
.../apache/gluten/execution/VeloxTPCHSuite.scala | 6 +-
.../sql/{TestUtils.scala => GlutenTestUtils.scala} | 6 +-
.../ClickHouseSQLQueryTestSettings.scala | 5 +-
.../utils/velox/VeloxSQLQueryTestSettings.scala | 2 +-
.../apache/spark/sql/GlutenSQLQueryTestSuite.scala | 746 +++++++++++++--------
9 files changed, 493 insertions(+), 298 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
index 5eba3977fa..9b32ecf609 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.execution
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row, TestUtils}
+import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
import org.apache.spark.sql.execution.InputIteratorTransformer
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
@@ -590,7 +590,7 @@ class GlutenClickHouseTPCHBucketSuite
case o => o
})
}
- TestUtils.compareAnswers(sortedRes, exceptedResult)
+ GlutenTestUtils.compareAnswers(sortedRes, exceptedResult)
}
val SQL =
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index 65a01dea30..10cd1b8aac 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.execution
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{Row, TestUtils}
+import org.apache.spark.sql.{GlutenTestUtils, Row}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.types.{DecimalType, StructType}
@@ -342,7 +342,7 @@ class GlutenClickHouseTPCHSuite extends
GlutenClickHouseTPCHAbstractSuite {
assert(result.size == 7)
val expected =
Seq(Row(465.0), Row(67.0), Row(160.0), Row(371.0), Row(732.0),
Row(138.0), Row(785.0))
- TestUtils.compareAnswers(result, expected)
+ GlutenTestUtils.compareAnswers(result, expected)
}
test("test 'order by' two keys") {
@@ -358,7 +358,7 @@ class GlutenClickHouseTPCHSuite extends
GlutenClickHouseTPCHAbstractSuite {
val result = df.take(3)
val expected =
Seq(Row(0, "ALGERIA", 0), Row(1, "ARGENTINA", 1), Row(2, "BRAZIL",
1))
- TestUtils.compareAnswers(result, expected)
+ GlutenTestUtils.compareAnswers(result, expected)
}
}
@@ -373,7 +373,7 @@ class GlutenClickHouseTPCHSuite extends
GlutenClickHouseTPCHAbstractSuite {
assert(sortExec.size == 1)
val result = df.collect()
val expectedResult = Seq(Row(0), Row(1), Row(2), Row(3), Row(4))
- TestUtils.compareAnswers(result, expectedResult)
+ GlutenTestUtils.compareAnswers(result, expectedResult)
}
}
@@ -416,7 +416,7 @@ class GlutenClickHouseTPCHSuite extends
GlutenClickHouseTPCHAbstractSuite {
new java.math.BigDecimal("123456789.223456789012345678901234567"),
Seq(new java.math.BigDecimal("123456789.123456789012345678901234567"))
))
- TestUtils.compareAnswers(result, expectedResult)
+ GlutenTestUtils.compareAnswers(result, expectedResult)
}
test("test decimal128") {
@@ -434,8 +434,8 @@ class GlutenClickHouseTPCHSuite extends
GlutenClickHouseTPCHAbstractSuite {
.add("b1", DecimalType(38, 27)))
val df2 = spark.createDataFrame(data, schema)
- TestUtils.compareAnswers(df2.select("b").collect(), Seq(Row(struct)))
- TestUtils.compareAnswers(
+ GlutenTestUtils.compareAnswers(df2.select("b").collect(), Seq(Row(struct)))
+ GlutenTestUtils.compareAnswers(
df2.select("a").collect(),
Seq(Row(new
java.math.BigDecimal("123456789.123456789012345678901234566"))))
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 5192dd0e47..9189cb017c 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.execution
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row, TestUtils}
+import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetJsonObject,
Literal}
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding,
NullPropagation}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
@@ -595,7 +595,7 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
// check the result
val result = df.collect()
assert(result.length === exceptedResult.size)
- TestUtils.compareAnswers(result, exceptedResult)
+ GlutenTestUtils.compareAnswers(result, exceptedResult)
}
runSql("select round(0.41875d * id , 4) from range(10);")(
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
index 01edf875d8..38e2bcc45e 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.TestUtils
+import org.apache.spark.sql.GlutenTestUtils
import org.apache.spark.sql.execution.{ColumnarInputAdapter,
CommandResultExec, InputIteratorTransformer}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper,
BroadcastQueryStageExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
@@ -219,7 +219,7 @@ class VeloxMetricsSuite extends
VeloxWholeStageTransformerSuite with AdaptiveSpa
}
}
- TestUtils.withListener(spark.sparkContext, inputMetricsListener) {
+ GlutenTestUtils.withListener(spark.sparkContext, inputMetricsListener) {
_ =>
val df = spark.sql("""
|select /*+ BROADCAST(part) */ * from part join
lineitem
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
index 62f7c1f0a1..2d7645bedf 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.config.VeloxConfig
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row, TestUtils}
+import org.apache.spark.sql.{DataFrame, GlutenTestUtils, Row}
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec,
FormattedMode}
import org.apache.commons.io.FileUtils
@@ -265,7 +265,7 @@ class VeloxTPCHDistinctSpillSuite extends
VeloxTPCHTableSupport {
test("distinct spill") {
val df = spark.sql("select count(distinct *) from lineitem limit 1")
- TestUtils.compareAnswers(df.collect(), Seq(Row(60175)))
+ GlutenTestUtils.compareAnswers(df.collect(), Seq(Row(60175)))
}
}
@@ -287,7 +287,7 @@ class VeloxTPCHMiscSuite extends VeloxTPCHTableSupport {
val result = df.collect()
df.explain(true)
val expectedResult = Seq(Row(0), Row(1), Row(2), Row(3), Row(4))
- TestUtils.compareAnswers(result, expectedResult)
+ GlutenTestUtils.compareAnswers(result, expectedResult)
}
}
diff --git
a/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala
similarity index 91%
rename from gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
rename to
gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala
index c87f594663..35fe9518ce 100644
--- a/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
+++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql
import org.apache.gluten.exception.GlutenException
-import org.apache.spark.{TestUtils => SparkTestUtils}
import org.apache.spark.SparkContext
+import org.apache.spark.TestUtils
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.sql.test.SQLTestUtils
-object TestUtils {
+object GlutenTestUtils {
def compareAnswers(actual: Seq[Row], expected: Seq[Row], sort: Boolean =
false): Unit = {
val result = SQLTestUtils.compareAnswers(actual, expected, sort)
if (result.isDefined) {
@@ -32,6 +32,6 @@ object TestUtils {
}
def withListener[L <: SparkListener](sc: SparkContext, listener: L)(body: L
=> Unit): Unit = {
- SparkTestUtils.withListener(sc, listener)(body)
+ TestUtils.withListener(sc, listener)(body)
}
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseSQLQueryTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseSQLQueryTestSettings.scala
index 46167fa0b3..937e3494d0 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseSQLQueryTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseSQLQueryTestSettings.scala
@@ -268,6 +268,9 @@ object ClickHouseSQLQueryTestSettings extends
SQLQueryTestSettings {
"window.sql",
"udf/udf-window.sql",
"group-by.sql",
- "udf/udf-group-by.sql - Scala UDF"
+ "udf/udf-group-by.sql - Scala UDF",
+ "udaf/udaf-group-analytics.sql",
+ "udaf/udaf-group-by-ordinal.sql",
+ "udaf/udaf-group-by.sql"
)
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala
index 53368be729..0a86f68962 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala
@@ -220,7 +220,7 @@ object VeloxSQLQueryTestSettings extends
SQLQueryTestSettings {
"typeCoercion/native/promoteStrings.sql",
"typeCoercion/native/widenSetOperationTypes.sql",
"typeCoercion/native/windowFrameCoercion.sql",
- "udaf/udaf.sql",
+ "udaf/udaf.sql - Grouped Aggregate Pandas UDF",
"udf/udf-union.sql - Scala UDF",
"udf/udf-intersect-all.sql - Scala UDF",
"udf/udf-except-all.sql - Scala UDF",
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenSQLQueryTestSuite.scala
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenSQLQueryTestSuite.scala
index c2acedb463..6c439d756c 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenSQLQueryTestSuite.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenSQLQueryTestSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.utils.{BackendTestSettings, BackendTestUtils}
-import org.apache.spark.{SparkConf, SparkException, SparkThrowable}
+import org.apache.spark.{SparkConf, SparkException, SparkThrowable, TestUtils}
import org.apache.spark.ErrorMessageFormat.MINIMAL
import org.apache.spark.SparkThrowableHelper.getMessage
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding,
ConvertToLocalRelation, NullPropagation}
+import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND
@@ -41,91 +43,18 @@ import java.util.Locale
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.sys.process.{Process, ProcessLogger}
-import scala.util.Try
import scala.util.control.NonFatal
/**
- * End-to-end test cases for SQL queries.
+ * Originated from org.apache.spark.sql.SQLQueryTestSuite, with the following
modifications:
+ * - Overwrite the generated golden files to remove failed queries.
+ * - Overwrite the generated golden files to update expected results for
exception message
+ * mismatches, result order mismatches in non-order-sensitive queries, and
minor precision scale
+ * mismatches.
+ * - Remove the AnalyzerTest as it's not within the scope of the Gluten
project.
*
- * Each case is loaded from a file in
"spark/sql/core/src/test/resources/sql-tests/inputs". Each
- * case has a golden result file in
"spark/sql/core/src/test/resources/sql-tests/results".
- *
- * To run the entire test suite:
- * {{{
- * build/sbt "sql/testOnly *SQLQueryTestSuite"
- * }}}
- *
- * To run a single test file upon change:
- * {{{
- * build/sbt "~sql/testOnly *SQLQueryTestSuite -- -z inline-table.sql"
- * }}}
- *
- * To re-generate golden files for entire suite, run:
- * {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite"
- * }}}
- *
- * To re-generate golden file for a single test, run:
- * {{{
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite
-- -z describe.sql"
- * }}}
- *
- * The format for input files is simple:
- * 1. A list of SQL queries separated by semicolons by default. If the
semicolon cannot
- * effectively separate the SQL queries in the test file(e.g. bracketed
comments), please use
- * --QUERY-DELIMITER-START and --QUERY-DELIMITER-END. Lines starting with
--QUERY-DELIMITER-START
- * and --QUERY-DELIMITER-END represent the beginning and end of a query,
respectively. Code that is
- * not surrounded by lines that begin with --QUERY-DELIMITER-START and
--QUERY-DELIMITER-END is
- * still separated by semicolons. 2. Lines starting with -- are treated as
comments and ignored. 3.
- * Lines starting with --SET are used to specify the configs when running this
testing file. You can
- * set multiple configs in one --SET, using comma to separate them. Or you can
use multiple --SET
- * statements. 4. Lines starting with --IMPORT are used to load queries from
another test file. 5.
- * Lines starting with --CONFIG_DIM are used to specify config dimensions of
this testing file. The
- * dimension name is decided by the string after --CONFIG_DIM. For example,
--CONFIG_DIM1 belongs to
- * dimension 1. One dimension can have multiple lines, each line representing
one config set (one or
- * more configs, separated by comma). Spark will run this testing file many
times, each time picks
- * one config set from each dimension, until all the combinations are tried.
For example, if
- * dimension 1 has 2 lines, dimension 2 has 3 lines, this testing file will be
run 6 times
- * (cartesian product).
- *
- * For example:
- * {{{
- * -- this is a comment
- * select 1, -1;
- * select current_date;
- * }}}
- *
- * The format for golden result files look roughly like:
- * {{{
- * -- some header information
- *
- * -- !query
- * select 1, -1
- * -- !query schema
- * struct<...schema...>
- * -- !query output
- * ... data row 1 ...
- * ... data row 2 ...
- * ...
- *
- * -- !query
- * ...
- * }}}
- *
- * Note that UDF tests work differently. After the test files under
'inputs/udf' directory are
- * detected, it creates three test cases:
- *
- * - Scala UDF test case with a Scalar UDF registered as the name 'udf'.
- *
- * - Python UDF test case with a Python UDF registered as the name 'udf' iff
Python executable and
- * pyspark are available.
- *
- * - Scalar Pandas UDF test case with a Scalar Pandas UDF registered as the
name 'udf' iff Python
- * executable, pyspark, pandas and pyarrow are available.
- *
- * Therefore, UDF test cases should have single input and output files but
executed by three
- * different types of UDFs. See 'udf/udf-inner-join.sql' as an example.
+ * NOTE: DO NOT simply copy-paste this file for supporting new Spark versions.
SQLQueryTestSuite is
+ * actively modified in Spark, so compare the difference and apply the
necessary changes.
*/
@ExtendedSQLTest
class GlutenSQLQueryTestSuite
@@ -136,28 +65,10 @@ class GlutenSQLQueryTestSuite
import IntegratedUDFTestUtils._
- override protected val regenerateGoldenFiles: Boolean =
- System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
-
- // FIXME it's not needed to install Spark in testing since the following
code only fetchs
- // some resource files from source folder
-
- protected val baseResourcePath = {
- // We use a path based on Spark home for 2 reasons:
- // 1. Maven can't get correct resource directory when resources in other
jars.
- // 2. We test subclasses in the hive-thriftserver module.
- getWorkspaceFilePath("sql", "core", "src", "test", "resources",
"sql-tests").toFile
- }
-
+ // ==== Start of modifications for Gluten. ====
protected val resourcesPath = {
- // We use a path based on Spark home for 2 reasons:
- // 1. Maven can't get correct resource directory when resources in other
jars.
- // 2. We test subclasses in the hive-thriftserver module.
getWorkspaceFilePath("sql", "core", "src", "test", "resources").toFile
}
-
- protected val inputFilePath = new File(baseResourcePath,
"inputs").getAbsolutePath
- protected val goldenFilePath = new File(baseResourcePath,
"results").getAbsolutePath
protected val testDataPath = new File(resourcesPath,
"test-data").getAbsolutePath
protected val overwriteInputFilePath = new File(
@@ -168,24 +79,34 @@ class GlutenSQLQueryTestSuite
BackendTestSettings.instance.getSQLQueryTestSettings.getResourceFilePath,
"results").getAbsolutePath
- /** Test if a command is available. */
- def testCommandAvailable(command: String): Boolean = {
- val attempt = if (Utils.isWindows) {
- Try(Process(Seq("cmd.exe", "/C", s"where $command")).run(ProcessLogger(_
=> ())).exitValue())
- } else {
- Try(Process(Seq("sh", "-c", s"command -v $command")).run(ProcessLogger(_
=> ())).exitValue())
- }
- attempt.isSuccess && attempt.get == 0
- }
-
private val isCHBackend = BackendTestUtils.isCHBackendLoaded()
+ // List of supported cases to run with a certain backend, in lower case.
+ private val supportedList: Set[String] =
+
BackendTestSettings.instance.getSQLQueryTestSettings.getSupportedSQLQueryTests
++
+
BackendTestSettings.instance.getSQLQueryTestSettings.getOverwriteSQLQueryTests
+
+ private val normalizeRegex = "#\\d+L?".r
+ private val nodeNumberRegex = "[\\^*]\\(\\d+\\)".r
+ private def normalizeIds(plan: String): String = {
+ val normalizedPlan = nodeNumberRegex.replaceAllIn(plan, "")
+ val map = new mutable.HashMap[String, String]()
+ normalizeRegex
+ .findAllMatchIn(normalizedPlan)
+ .map(_.toString)
+ .foreach(map.getOrElseUpdate(_, (map.size + 1).toString))
+ normalizeRegex.replaceAllIn(normalizedPlan, regexMatch =>
s"#${map(regexMatch.toString)}")
+ }
+
override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
// Fewer shuffle partitions to speed up testing.
.set(SQLConf.SHUFFLE_PARTITIONS, 4)
// use Java 8 time API to handle negative years properly
.set(SQLConf.DATETIME_JAVA8API_ENABLED, true)
+ // SPARK-39564: don't print out serde to avoid introducing complicated
and error-prone
+ // regex magic.
+ .set("spark.test.noSerdeInExplain", "true")
.setAppName("Gluten-UT")
.set("spark.driver.memory", "1G")
.set("spark.sql.adaptive.enabled", "true")
@@ -214,31 +135,132 @@ class GlutenSQLQueryTestSuite
conf
}
- // SPARK-32106 Since we add SQL test 'transform.sql' will use `cat` command,
- // here we need to ignore it.
- private val otherIgnoreList =
- if (testCommandAvailable("/bin/bash")) Nil else Set("transform.sql")
-
- // 3.4 inadvertently enabled with "group-by.sql" and "group-by-ordinal.sql"
- private val udafIgnoreList = Set(
- "udaf/udaf-group-analytics.sql",
- "udaf/udaf-group-by-ordinal.sql",
- "udaf/udaf-group-by.sql"
- )
-
/** List of test cases to ignore, in lower cases. */
protected def ignoreList: Set[String] = Set(
"ignored.sql", // Do NOT remove this one. It is here to test the ignore
functionality.
- "explain-aqe.sql", // explain plan is different
- "explain-cbo.sql", // explain
- "explain.sql" // explain
- ) ++ otherIgnoreList ++ udafIgnoreList ++
+ "explain-aqe.sql", // Explain is different in Gluten.
+ "explain-cbo.sql", // Explain is different in Gluten.
+ "explain.sql" // Explain is different in Gluten.
+ ) ++ otherIgnoreList ++
BackendTestSettings.instance.getSQLQueryTestSettings.getIgnoredSQLQueryTests
- // List of supported cases to run with a certain backend, in lower case.
- private val supportedList: Set[String] =
-
BackendTestSettings.instance.getSQLQueryTestSettings.getSupportedSQLQueryTests
++
-
BackendTestSettings.instance.getSQLQueryTestSettings.getOverwriteSQLQueryTests
+ /**
+ * This method handles exceptions occurred during query execution as they
may need special care to
+ * become comparable to the expected output.
+ *
+ * Modified for Gluten by truncating exception output to only include the
exception class and
+ * message.
+ *
+ * @param result
+ * a function that returns a pair of schema and output
+ */
+ override protected def handleExceptions(
+ result: => (String, Seq[String])): (String, Seq[String]) = {
+ val format = MINIMAL
+ try {
+ result
+ } catch {
+ case e: SparkThrowable with Throwable if e.getErrorClass != null =>
+ (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
+ case a: AnalysisException =>
+ // Do not output the logical plan tree which contains expression IDs.
+ // Also implement a crude way of masking expression IDs in the error
message
+ // with a generic pattern "###".
+ (emptySchema, Seq(a.getClass.getName,
a.getSimpleMessage.replaceAll("#\\d+", "#x")))
+ case s: SparkException if s.getCause != null =>
+ // For a runtime exception, it is hard to match because its message
contains
+ // information of stage, task ID, etc.
+ // To make result matching simpler, here we match the cause of the
exception if it exists.
+ s.getCause match {
+ case e: SparkThrowable with Throwable if e.getErrorClass != null =>
+ (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
+ case e: GlutenException =>
+ val reasonPattern = "Reason: (.*)".r
+ val reason =
reasonPattern.findFirstMatchIn(e.getMessage).map(_.group(1))
+
+ reason match {
+ case Some(r) =>
+ (emptySchema, Seq(e.getClass.getName, r))
+ case None => (emptySchema, Seq())
+ }
+
+ case cause =>
+ (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
+ }
+ case NonFatal(e) =>
+ // If there is an exception, put the exception class followed by the
message.
+ (emptySchema, Seq(e.getClass.getName, e.getMessage))
+ }
+ }
+
+ protected lazy val listTestCases: Seq[TestCase] = {
+ val createTestCase = (file: File, parentDir: String, resultPath: String)
=> {
+ val resultFile = file.getAbsolutePath.replace(parentDir, resultPath) +
".out"
+ val absPath = file.getAbsolutePath
+ val testCaseName =
absPath.stripPrefix(parentDir).stripPrefix(File.separator)
+
+ if (
+ file.getAbsolutePath.startsWith(
+ s"$parentDir${File.separator}udf${File.separator}postgreSQL")
+ ) {
+ Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
TestScalarPandasUDF("udf")).map {
+ udf => UDFPgSQLTestCase(s"$testCaseName - ${udf.prettyName}",
absPath, resultFile, udf)
+ }
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}udf")) {
+ Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
TestScalarPandasUDF("udf")).map {
+ udf => UDFTestCase(s"$testCaseName - ${udf.prettyName}", absPath,
resultFile, udf)
+ }
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}udaf")) {
+ Seq(TestGroupedAggPandasUDF("udaf")).map {
+ udf => UDAFTestCase(s"$testCaseName - ${udf.prettyName}", absPath,
resultFile, udf)
+ }
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}udtf")) {
+ Seq(TestPythonUDTF("udtf")).map {
+ udtf =>
+ UDTFTestCase(
+ s"$testCaseName - ${udtf.prettyName}",
+ absPath,
+ resultFile,
+ udtf
+ )
+ }
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}postgreSQL")) {
+ PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}ansi")) {
+ AnsiTestCase(testCaseName, absPath, resultFile) :: Nil
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}timestampNTZ")) {
+ TimestampNTZTestCase(testCaseName, absPath, resultFile) :: Nil
+ } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}cte.sql")) {
+ CTETestCase(testCaseName, absPath, resultFile) :: Nil
+ } else {
+ RegularTestCase(testCaseName, absPath, resultFile) :: Nil
+ }
+ }
+ val overwriteTestCases = listFilesRecursively(new
File(overwriteInputFilePath))
+ .flatMap(createTestCase(_, overwriteInputFilePath,
overwriteGoldenFilePath))
+ val overwriteTestCaseNames = overwriteTestCases.map(_.name)
+ (listFilesRecursively(new File(inputFilePath))
+ .flatMap(createTestCase(_, inputFilePath, goldenFilePath))
+ .filterNot(testCase => overwriteTestCaseNames.contains(testCase.name))
++ overwriteTestCases)
+ .sortBy(_.name)
+ }
+ // ==== End of modifications for Gluten. ====
+
+ protected val baseResourcePath = {
+ // We use a path based on Spark home for 2 reasons:
+ // 1. Maven can't get correct resource directory when resources in other
jars.
+ // 2. We test subclasses in the hive-thriftserver module.
+ getWorkspaceFilePath("sql", "core", "src", "test", "resources",
"sql-tests").toFile
+ }
+
+ protected val inputFilePath = new File(baseResourcePath,
"inputs").getAbsolutePath
+ protected val goldenFilePath = new File(baseResourcePath,
"results").getAbsolutePath
+
+ // SPARK-32106 Since we add SQL test 'transform.sql' will use `cat` command,
+ // here we need to ignore it.
+ private val otherIgnoreList =
+ if (TestUtils.testCommandAvailable("/bin/bash")) Nil else
Set("transform.sql")
+
// Create all the test cases.
listTestCases.foreach(createScalaTestCase)
@@ -260,6 +282,7 @@ class GlutenSQLQueryTestSuite
val name: String
val inputFile: String
val resultFile: String
+ def asAnalyzerTest(newName: String, newResultFile: String): TestCase
}
/**
@@ -268,24 +291,55 @@ class GlutenSQLQueryTestSuite
*/
protected trait PgSQLTest
- /** traits that indicate ANSI-related tests with the ANSI mode enabled. */
+ /** Trait that indicates ANSI-related tests with the ANSI mode enabled. */
protected trait AnsiTest
- /** traits that indicate the default timestamp type is TimestampNTZType. */
+ /** Trait that indicates an analyzer test that shows the analyzed plan
string as output. */
+ protected trait AnalyzerTest extends TestCase {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
AnalyzerTest = this
+ }
+
+ /** Trait that indicates the default timestamp type is TimestampNTZType. */
protected trait TimestampNTZTest
+ /** Trait that indicates CTE test cases need their create view versions */
+ protected trait CTETest
+
protected trait UDFTest {
val udf: TestUDF
}
+ protected trait UDTFTest {
+ val udtf: TestUDTF
+ }
+
/** A regular test case. */
protected case class RegularTestCase(name: String, inputFile: String,
resultFile: String)
+ extends TestCase {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ RegularAnalyzerTestCase(newName, inputFile, newResultFile)
+ }
+
+ /** An ANSI-related test case. */
+ protected case class AnsiTestCase(name: String, inputFile: String,
resultFile: String)
extends TestCase
+ with AnsiTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ AnsiAnalyzerTestCase(newName, inputFile, newResultFile)
+ }
+
+ /** An analyzer test that shows the analyzed plan string as output. */
+ protected case class AnalyzerTestCase(name: String, inputFile: String,
resultFile: String)
+ extends TestCase
+ with AnalyzerTest
/** A PostgreSQL test case. */
protected case class PgSQLTestCase(name: String, inputFile: String,
resultFile: String)
extends TestCase
- with PgSQLTest
+ with PgSQLTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ PgSQLAnalyzerTestCase(newName, inputFile, newResultFile)
+ }
/** A UDF test case. */
protected case class UDFTestCase(
@@ -294,7 +348,34 @@ class GlutenSQLQueryTestSuite
resultFile: String,
udf: TestUDF)
extends TestCase
- with UDFTest
+ with UDFTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ UDFAnalyzerTestCase(newName, inputFile, newResultFile, udf)
+ }
+
+ protected case class UDTFTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udtf: TestUDTF)
+ extends TestCase
+ with UDTFTest {
+
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ UDTFAnalyzerTestCase(newName, inputFile, newResultFile, udtf)
+ }
+
+ /** A UDAF test case. */
+ protected case class UDAFTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udf: TestUDF)
+ extends TestCase
+ with UDFTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ UDAFAnalyzerTestCase(newName, inputFile, newResultFile, udf)
+ }
/** A UDF PostgreSQL test case. */
protected case class UDFPgSQLTestCase(
@@ -304,21 +385,78 @@ class GlutenSQLQueryTestSuite
udf: TestUDF)
extends TestCase
with UDFTest
- with PgSQLTest
-
- /** An ANSI-related test case. */
- protected case class AnsiTestCase(name: String, inputFile: String,
resultFile: String)
- extends TestCase
- with AnsiTest
+ with PgSQLTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ UDFPgSQLAnalyzerTestCase(newName, inputFile, newResultFile, udf)
+ }
/** An date time test case with default timestamp as TimestampNTZType */
protected case class TimestampNTZTestCase(name: String, inputFile: String,
resultFile: String)
extends TestCase
+ with TimestampNTZTest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ TimestampNTZAnalyzerTestCase(newName, inputFile, newResultFile)
+ }
+
+ /** A CTE test case with special handling */
+ protected case class CTETestCase(name: String, inputFile: String,
resultFile: String)
+ extends TestCase
+ with CTETest {
+ override def asAnalyzerTest(newName: String, newResultFile: String):
TestCase =
+ CTEAnalyzerTestCase(newName, inputFile, newResultFile)
+ }
+
+ /** These are versions of the above test cases, but only exercising
analysis. */
+ protected case class RegularAnalyzerTestCase(name: String, inputFile:
String, resultFile: String)
+ extends AnalyzerTest
+ protected case class AnsiAnalyzerTestCase(name: String, inputFile: String,
resultFile: String)
+ extends AnalyzerTest
+ with AnsiTest
+ protected case class PgSQLAnalyzerTestCase(name: String, inputFile: String,
resultFile: String)
+ extends AnalyzerTest
+ with PgSQLTest
+ protected case class UDFAnalyzerTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udf: TestUDF)
+ extends AnalyzerTest
+ with UDFTest
+ protected case class UDTFAnalyzerTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udtf: TestUDTF)
+ extends AnalyzerTest
+ with UDTFTest
+ protected case class UDAFAnalyzerTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udf: TestUDF)
+ extends AnalyzerTest
+ with UDFTest
+ protected case class UDFPgSQLAnalyzerTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String,
+ udf: TestUDF)
+ extends AnalyzerTest
+ with UDFTest
+ with PgSQLTest
+ protected case class TimestampNTZAnalyzerTestCase(
+ name: String,
+ inputFile: String,
+ resultFile: String)
+ extends AnalyzerTest
with TimestampNTZTest
+ protected case class CTEAnalyzerTestCase(name: String, inputFile: String,
resultFile: String)
+ extends AnalyzerTest
+ with CTETest
protected def createScalaTestCase(testCase: TestCase): Unit = {
- // If a test case is not in the test list, or it is in the ignore list,
ignore this test case.
if (
+ // Modified for Gluten to use exact name matching.
!supportedList.exists(
t => testCase.name.toLowerCase(Locale.ROOT) ==
t.toLowerCase(Locale.ROOT)) ||
ignoreList.exists(t => testCase.name.toLowerCase(Locale.ROOT) ==
t.toLowerCase(Locale.ROOT))
@@ -341,17 +479,25 @@ class GlutenSQLQueryTestSuite
s"pandas and/or pyarrow were not available in [$pythonExec].") {
/* Do nothing */
}
+ case udfTestCase: UDFTest
+ if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
+ !shouldTestPandasUDFs =>
+ ignore(
+ s"${testCase.name} is skipped because pyspark," +
+ s"pandas and/or pyarrow were not available in [$pythonExec].") {
+ /* Do nothing */
+ }
case _ =>
// Create a test case to run this case.
test(testCase.name) {
- runTest(testCase)
+ runSqlTestCase(testCase, listTestCases)
}
}
}
}
/** Run a test case. */
- protected def runTest(testCase: TestCase): Unit = {
+ protected def runSqlTestCase(testCase: TestCase, listTestCases:
Seq[TestCase]): Unit = {
def splitWithSemicolon(seq: Seq[String]) = {
seq.mkString("\n").split("(?<=[^\\\\]);")
}
@@ -468,6 +614,50 @@ class GlutenSQLQueryTestSuite
}
}
+ def hasNoDuplicateColumns(schema: String): Boolean = {
+ val columnAndTypes = schema.replaceFirst("^struct<",
"").stripSuffix(">").split(",")
+ columnAndTypes.size == columnAndTypes.distinct.size
+ }
+
+ def expandCTEQueryAndCompareResult(
+ session: SparkSession,
+ query: String,
+ output: ExecutionOutput): Unit = {
+ val triggerCreateViewTest =
+ try {
+ val logicalPlan: LogicalPlan =
session.sessionState.sqlParser.parsePlan(query)
+ !logicalPlan.isInstanceOf[Command] &&
+ output.schema.get != emptySchema &&
+ hasNoDuplicateColumns(output.schema.get)
+ } catch {
+ case _: ParseException => return
+ }
+
+ // For non-command query with CTE, compare the results of selecting from
view created on the
+ // original query.
+ if (triggerCreateViewTest) {
+ val createView = s"CREATE temporary VIEW cte_view AS $query"
+ val selectFromView = "SELECT * FROM cte_view"
+ val dropViewIfExists = "DROP VIEW IF EXISTS cte_view"
+ session.sql(createView)
+ val (selectViewSchema, selectViewOutput) =
+ handleExceptions(getNormalizedQueryExecutionResult(session,
selectFromView))
+ // Compare results.
+ assertResult(
+ output.schema.get,
+ s"Schema did not match for CTE query and select from its view:
\n$output") {
+ selectViewSchema
+ }
+ assertResult(
+ output.output,
+ s"Result did not match for CTE query and select from its view:
\n${output.sql}") {
+ selectViewOutput.mkString("\n").replaceAll("\\s+$", "")
+ }
+ // Drop view.
+ session.sql(dropViewIfExists)
+ }
+ }
+
protected def runQueries(
queries: Seq[String],
testCase: TestCase,
@@ -479,6 +669,8 @@ class GlutenSQLQueryTestSuite
testCase match {
case udfTestCase: UDFTest =>
registerTestUDF(udfTestCase.udf, localSparkSession)
+ case udtfTestCase: UDTFTest =>
+ registerTestUDTF(udtfTestCase.udtf, localSparkSession)
case _ =>
}
@@ -498,6 +690,7 @@ class GlutenSQLQueryTestSuite
SQLConf.TIMESTAMP_TYPE.key,
TimestampTypes.TIMESTAMP_NTZ.toString)
case _ =>
+ localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
}
if (configSet.nonEmpty) {
@@ -508,22 +701,38 @@ class GlutenSQLQueryTestSuite
}
// Run the SQL queries preparing them for comparison.
- val outputs: Seq[QueryOutput] = queries.map {
+ val outputs: Seq[QueryTestOutput] = queries.map {
sql =>
- val (schema, output) =
-
handleExceptions(getNormalizedQueryExecutionResult(localSparkSession, sql))
- // We might need to do some query canonicalization in the future.
- QueryOutput(
- sql = sql,
- schema = schema,
- output = normalizeIds(output.mkString("\n").replaceAll("\\s+$", "")))
+ testCase match {
+ case _: AnalyzerTest =>
+ val (_, output) =
+
handleExceptions(getNormalizedQueryAnalysisResult(localSparkSession, sql))
+ // We might need to do some query canonicalization in the future.
+ AnalyzerOutput(
+ sql = sql,
+ schema = None,
+ output = output.mkString("\n").replaceAll("\\s+$", ""))
+ case _ =>
+ val (schema, output) =
+
handleExceptions(getNormalizedQueryExecutionResult(localSparkSession, sql))
+ // We might need to do some query canonicalization in the future.
+ val executionOutput = ExecutionOutput(
+ sql = sql,
+ schema = Some(schema),
+ // GLUTEN-3559: Overwrite scalar-subquery-select.sql test
+ output = normalizeIds(output.mkString("\n").replaceAll("\\s+$",
""))
+ )
+ if (testCase.isInstanceOf[CTETest]) {
+ expandCTEQueryAndCompareResult(localSparkSession, sql,
executionOutput)
+ }
+ executionOutput
+ }
}
if (regenerateGoldenFiles) {
// Again, we are explicitly not using multi-line string due to
stripMargin removing "|".
val goldenOutput = {
s"-- Automatically generated by ${getClass.getSimpleName}\n" +
- s"-- Number of queries: ${outputs.size}\n\n\n" +
outputs.mkString("\n\n\n") + "\n"
}
val resultFile = new File(testCase.resultFile)
@@ -547,100 +756,30 @@ class GlutenSQLQueryTestSuite
if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] &&
shouldTestPandasUDFs =>
s"${testCase.name}${System.lineSeparator()}" +
s"Python: $pythonVer Pandas: $pandasVer PyArrow:
$pyarrowVer${System.lineSeparator()}"
+ case udfTestCase: UDFTest
+ if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
+ shouldTestPandasUDFs =>
+ s"${testCase.name}${System.lineSeparator()}" +
+ s"Python: $pythonVer Pandas: $pandasVer PyArrow:
$pyarrowVer${System.lineSeparator()}"
+ case udtfTestCase: UDTFTest
+ if udtfTestCase.udtf.isInstanceOf[TestPythonUDTF] &&
shouldTestPythonUDFs =>
+ s"${testCase.name}${System.lineSeparator()}Python:
$pythonVer${System.lineSeparator()}"
case _ =>
s"${testCase.name}${System.lineSeparator()}"
}
withClue(clue) {
- // Read back the golden file.
- val expectedOutputs: Seq[QueryOutput] = {
- val goldenOutput = fileToString(new File(testCase.resultFile))
- val segments = goldenOutput.split("-- !query.*\n")
-
- // each query has 3 segments, plus the header
- assert(
- segments.size == outputs.size * 3 + 1,
- s"Expected ${outputs.size * 3 + 1} blocks in result file but got
${segments.size}. " +
- s"Try regenerate the result files.")
- Seq.tabulate(outputs.size) {
- i =>
- QueryOutput(
- sql = segments(i * 3 + 1).trim,
- schema = segments(i * 3 + 2).trim,
- output = segments(i * 3 + 3).replaceAll("\\s+$", "")
- )
- }
- }
-
- // Compare results.
- assertResult(expectedOutputs.size, s"Number of queries should be
${expectedOutputs.size}") {
- outputs.size
- }
-
- outputs.zip(expectedOutputs).zipWithIndex.foreach {
- case ((output, expected), i) =>
- assertResult(expected.sql, s"SQL query did not match for query
#$i\n${expected.sql}") {
- output.sql
- }
- assertResult(
- expected.schema,
- s"Schema did not match for query #$i\n${expected.sql}: $output") {
- output.schema
- }
- assertResult(
- expected.output,
- s"Result did not match" +
- s" for query #$i\n${expected.sql}")(output.output)
+ testCase match {
+ case _: AnalyzerTest =>
+ readGoldenFileAndCompareResults(testCase.resultFile, outputs,
AnalyzerOutput)
+ case _ =>
+ readGoldenFileAndCompareResults(testCase.resultFile, outputs,
ExecutionOutput)
}
}
}
- protected val normalizeRegex = "#\\d+L?".r
- protected val nodeNumberRegex = "[\\^*]\\(\\d+\\)".r
- protected def normalizeIds(plan: String): String = {
- val normalizedPlan = nodeNumberRegex.replaceAllIn(plan, "")
- val map = new mutable.HashMap[String, String]()
- normalizeRegex
- .findAllMatchIn(normalizedPlan)
- .map(_.toString)
- .foreach(map.getOrElseUpdate(_, (map.size + 1).toString))
- normalizeRegex.replaceAllIn(normalizedPlan, regexMatch =>
s"#${map(regexMatch.toString)}")
- }
-
- protected lazy val listTestCases: Seq[TestCase] = {
- val createTestCase = (file: File, parentDir: String, resultPath: String)
=> {
- val resultFile = file.getAbsolutePath.replace(parentDir, resultPath) +
".out"
- val absPath = file.getAbsolutePath
- val testCaseName =
absPath.stripPrefix(parentDir).stripPrefix(File.separator)
-
- if (
- file.getAbsolutePath.startsWith(
- s"$parentDir${File.separator}udf${File.separator}postgreSQL")
- ) {
- Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
TestScalarPandasUDF("udf")).map {
- udf => UDFPgSQLTestCase(s"$testCaseName - ${udf.prettyName}",
absPath, resultFile, udf)
- }
- } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}udf")) {
- Seq(TestScalaUDF("udf"), TestPythonUDF("udf"),
TestScalarPandasUDF("udf")).map {
- udf => UDFTestCase(s"$testCaseName - ${udf.prettyName}", absPath,
resultFile, udf)
- }
- } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}postgreSQL")) {
- PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
- } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}ansi")) {
- AnsiTestCase(testCaseName, absPath, resultFile) :: Nil
- } else if
(file.getAbsolutePath.startsWith(s"$parentDir${File.separator}timestampNTZ")) {
- TimestampNTZTestCase(testCaseName, absPath, resultFile) :: Nil
- } else {
- RegularTestCase(testCaseName, absPath, resultFile) :: Nil
- }
- }
- val overwriteTestCases = listFilesRecursively(new
File(overwriteInputFilePath))
- .flatMap(createTestCase(_, overwriteInputFilePath,
overwriteGoldenFilePath))
- val overwriteTestCaseNames = overwriteTestCases.map(_.name)
- listFilesRecursively(new File(inputFilePath))
- .flatMap(createTestCase(_, inputFilePath, goldenFilePath))
- .filterNot(testCase => overwriteTestCaseNames.contains(testCase.name))
++ overwriteTestCases
- }
+ // ==== Start of modifications for Gluten. ====
+ // ==== End of modifications for Gluten. ====
/** Returns all the files (not directories) in a directory, recursively. */
protected def listFilesRecursively(path: File): Seq[File] = {
@@ -750,7 +889,7 @@ class GlutenSQLQueryTestSuite
.saveAsTable("tenk1")
}
- private def removeTestTables(session: SparkSession): Unit = {
+ protected def removeTestTables(session: SparkSession): Unit = {
session.sql("DROP TABLE IF EXISTS testdata")
session.sql("DROP TABLE IF EXISTS arraydata")
session.sql("DROP TABLE IF EXISTS mapdata")
@@ -789,48 +928,101 @@ class GlutenSQLQueryTestSuite
}
/**
- * This method handles exceptions occurred during query execution as they
may need special care to
- * become comparable to the expected output.
- *
- * @param result
- * a function that returns a pair of schema and output
+ * Consumes contents from a single golden file and compares the expected
results against the
+ * output of running a query.
*/
- override protected def handleExceptions(
- result: => (String, Seq[String])): (String, Seq[String]) = {
- val format = MINIMAL
- try {
- result
- } catch {
- case e: SparkThrowable with Throwable if e.getErrorClass != null =>
- (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
- case a: AnalysisException =>
- // Do not output the logical plan tree which contains expression IDs.
- // Also implement a crude way of masking expression IDs in the error
message
- // with a generic pattern "###".
- (emptySchema, Seq(a.getClass.getName,
a.getSimpleMessage.replaceAll("#\\d+", "#x")))
- case s: SparkException if s.getCause != null =>
- // For a runtime exception, it is hard to match because its message
contains
- // information of stage, task ID, etc.
- // To make result matching simpler, here we match the cause of the
exception if it exists.
- s.getCause match {
- case e: SparkThrowable with Throwable if e.getErrorClass != null =>
- (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
- case e: GlutenException =>
- val reasonPattern = "Reason: (.*)".r
- val reason =
reasonPattern.findFirstMatchIn(e.getMessage).map(_.group(1))
+ def readGoldenFileAndCompareResults(
+ resultFile: String,
+ outputs: Seq[QueryTestOutput],
+ makeOutput: (String, Option[String], String) => QueryTestOutput): Unit =
{
+ // Read back the golden file.
+ val expectedOutputs: Seq[QueryTestOutput] = {
+ val goldenOutput = fileToString(new File(resultFile))
+ val segments = goldenOutput.split("-- !query.*\n")
+
+ val numSegments = outputs.map(_.numSegments).sum + 1
+ assert(
+ segments.size == numSegments,
+ s"Expected $numSegments blocks in result file but got " +
+ s"${segments.size}. Try regenerate the result files.")
+ var curSegment = 0
+ outputs.map {
+ output =>
+ val result = if (output.numSegments == 3) {
+ makeOutput(
+ segments(curSegment + 1).trim, // SQL
+ Some(segments(curSegment + 2).trim), // Schema
+ segments(curSegment + 3).replaceAll("\\s+$", "")
+ ) // Output
+ } else {
+ makeOutput(
+ segments(curSegment + 1).trim, // SQL
+ None, // Schema
+ segments(curSegment + 2).replaceAll("\\s+$", "")
+ ) // Output
+ }
+ curSegment += output.numSegments
+ result
+ }
+ }
- reason match {
- case Some(r) =>
- (emptySchema, Seq(e.getClass.getName, r))
- case None => (emptySchema, Seq())
- }
+ // Compare results.
+ assertResult(expectedOutputs.size, s"Number of queries should be
${expectedOutputs.size}") {
+ outputs.size
+ }
- case cause =>
- (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
+ outputs.zip(expectedOutputs).zipWithIndex.foreach {
+ case ((output, expected), i) =>
+ assertResult(expected.sql, s"SQL query did not match for query
#$i\n${expected.sql}") {
+ output.sql
}
- case NonFatal(e) =>
- // If there is an exception, put the exception class followed by the
message.
- (emptySchema, Seq(e.getClass.getName, e.getMessage))
+ assertResult(
+ expected.schema,
+ s"Schema did not match for query #$i\n${expected.sql}: $output") {
+ output.schema
+ }
+ assertResult(
+ expected.output,
+ s"Result did not match" +
+ s" for query #$i\n${expected.sql}") {
+ output.output
+ }
+ }
+ }
+
+ /** A single SQL query's output. */
+ trait QueryTestOutput {
+ def sql: String
+ def schema: Option[String]
+ def output: String
+ def numSegments: Int
+ }
+
+ /** A single SQL query's execution output. */
+ case class ExecutionOutput(sql: String, schema: Option[String], output:
String)
+ extends QueryTestOutput {
+ override def toString: String = {
+ // We are explicitly not using multi-line string due to stripMargin
removing "|" in output.
+ s"-- !query\n" +
+ sql + "\n" +
+ s"-- !query schema\n" +
+ schema.get + "\n" +
+ s"-- !query output\n" +
+ output
+ }
+ override def numSegments: Int = 3
+ }
+
+ /** A single SQL query's analysis results. */
+ case class AnalyzerOutput(sql: String, schema: Option[String], output:
String)
+ extends QueryTestOutput {
+ override def toString: String = {
+ // We are explicitly not using multi-line string due to stripMargin
removing "|" in output.
+ s"-- !query\n" +
+ sql + "\n" +
+ s"-- !query analysis\n" +
+ output
}
+ override def numSegments: Int = 2
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]