This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 72db36e59e [Minor] Refactor test utility to let users compare the
query result (#10565)
72db36e59e is described below
commit 72db36e59e6ea554507d8e543425a93d06e9dede
Author: Jin Chengcheng <[email protected]>
AuthorDate: Tue Sep 2 11:44:33 2025 +0100
[Minor] Refactor test utility to let users compare the query result (#10565)
---
.../GlutenClickHouseExcelFormatSuite.scala | 6 +-
.../GlutenClickHouseSyntheticDataSuite.scala | 6 +-
.../GlutenClickHouseTPCDSAbstractSuite.scala | 2 +-
.../GlutenClickhouseCountDistinctSuite.scala | 4 +-
.../extension/GlutenCustomAggExpressionSuite.scala | 2 +-
.../execution/GlutenQueryComparisonTest.scala | 155 +++++++++++++++++++++
.../execution/WholeStageTransformerSuite.scala | 129 +----------------
7 files changed, 167 insertions(+), 137 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
index 4606b36d18..91a1419b2d 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
@@ -79,7 +79,7 @@ class GlutenClickHouseExcelFormatSuite extends
GlutenClickHouseWholeStageTransfo
|""".stripMargin
val df2 = spark.sql(sql)
df2.collect()
- WholeStageTransformerSuite.checkFallBack(df2)
+ GlutenQueryComparisonTest.checkFallBack(df2)
checkAnswer(df2, df1)
}
}
@@ -103,7 +103,7 @@ class GlutenClickHouseExcelFormatSuite extends
GlutenClickHouseWholeStageTransfo
|""".stripMargin
val df2 = spark.sql(sql)
df2.collect()
- WholeStageTransformerSuite.checkFallBack(df2)
+ GlutenQueryComparisonTest.checkFallBack(df2)
checkAnswer(df2, df1)
}
}
@@ -136,7 +136,7 @@ class GlutenClickHouseExcelFormatSuite extends
GlutenClickHouseWholeStageTransfo
|""".stripMargin
val df2 = spark.sql(sql)
df2.collect()
- WholeStageTransformerSuite.checkFallBack(df2)
+ GlutenQueryComparisonTest.checkFallBack(df2)
checkAnswer(df2, df1)
}
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
index b68409324b..7f3aa5eb1c 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
@@ -163,7 +163,7 @@ class GlutenClickHouseSyntheticDataSuite
}
val df = spark.sql(sqlStr)
df.collect()
- WholeStageTransformerSuite.checkFallBack(df)
+ GlutenQueryComparisonTest.checkFallBack(df)
checkAnswer(df, expected)
}
@@ -195,7 +195,7 @@ class GlutenClickHouseSyntheticDataSuite
}
val df = spark.sql(sqlStr)
df.collect()
- WholeStageTransformerSuite.checkFallBack(df)
+ GlutenQueryComparisonTest.checkFallBack(df)
checkAnswer(df, expected)
}
@@ -225,7 +225,7 @@ class GlutenClickHouseSyntheticDataSuite
}
val df = spark.sql(sqlStr)
df.collect()
- WholeStageTransformerSuite.checkFallBack(df)
+ GlutenQueryComparisonTest.checkFallBack(df)
checkAnswer(df, expected)
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 08b84b7428..5f70d35c56 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -205,7 +205,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
log.warn(s"query: $queryNum skipped comparing, time cost to collect:
${System
.currentTimeMillis() - start} ms, ret size: ${ret.length}")
}
- WholeStageTransformerSuite.checkFallBack(df, noFallBack)
+ GlutenQueryComparisonTest.checkFallBack(df, noFallBack)
customCheck(df)
}
}
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
index bad39ef09b..22a82a9439 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
@@ -56,7 +56,7 @@ class GlutenClickhouseCountDistinctSuite extends
GlutenClickHouseWholeStageTrans
"values (0, null,1), (1, 1,1), (2, 2,1), (1, 2,1) ,(2,2,2) as
data(a,b,c) group by c"
val df = spark.sql(sql)
- WholeStageTransformerSuite.checkFallBack(df)
+ GlutenQueryComparisonTest.checkFallBack(df)
val planExecs = df.queryExecution.executedPlan.collect {
case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
@@ -115,7 +115,7 @@ class GlutenClickhouseCountDistinctSuite extends
GlutenClickHouseWholeStageTrans
values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by
try_add(c,b)
"""
val df = spark.sql(sql)
- WholeStageTransformerSuite.checkFallBack(df, noFallback =
isSparkVersionGE("3.5"))
+ GlutenQueryComparisonTest.checkFallBack(df, noFallback =
isSparkVersionGE("3.5"))
}
test("check count distinct with filter") {
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
index 900ab11ea2..f0d371f1aa 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
@@ -70,7 +70,7 @@ class GlutenCustomAggExpressionSuite extends ParquetSuite {
|""".stripMargin
val df = spark.sql(sql)
// Final stage is not supported, it will be fallback
- WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
+ GlutenQueryComparisonTest.checkFallBack(df, noFallback = false)
val planExecs = df.queryExecution.executedPlan.collect {
case agg: HashAggregateExec => agg
diff --git
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
new file mode 100644
index 0000000000..8a22b706db
--- /dev/null
+++
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+/**
+ * This test utility allows developers to compare the test results with vanilla
Spark easily, and can
+ * check the fallback status.
+ */
+import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.test.FallbackUtil
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+abstract class GlutenQueryComparisonTest extends GlutenQueryTest {
+
+ private val isFallbackCheckDisabled0 = new AtomicBoolean(false)
+
+ final protected def disableFallbackCheck: Boolean =
+ isFallbackCheckDisabled0.compareAndSet(false, true)
+
+ protected def needCheckFallback: Boolean = !isFallbackCheckDisabled0.get()
+
+ protected def vanillaSparkConfs(): Seq[(String, String)] = {
+ List((GlutenConfig.GLUTEN_ENABLED.key, "false"))
+ }
+
+ /**
+ * Some rules on LogicalPlan apply not only to select queries; the whole
df.load() should also run in
+ * a Spark environment with the Gluten-disabled config.
+ *
+ * @param sql
+ * @return
+ */
+ protected def runAndCompare(sql: String): DataFrame = {
+ var expected: Seq[Row] = null
+ withSQLConf(vanillaSparkConfs(): _*) {
+ expected = spark.sql(sql).collect()
+ }
+ val df = spark.sql(sql)
+ checkAnswer(df, expected)
+ df
+ }
+
+ protected def runQueryAndCompare(
+ sqlStr: String,
+ compareResult: Boolean = true,
+ noFallBack: Boolean = true,
+ cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
+
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sqlStr),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
+ }
+
+ protected def compareResultsAgainstVanillaSpark(
+ sql: String,
+ compareResult: Boolean = true,
+ customCheck: DataFrame => Unit,
+ noFallBack: Boolean = true,
+ cache: Boolean = false): DataFrame = {
+ compareDfResultsAgainstVanillaSpark(
+ () => spark.sql(sql),
+ compareResult,
+ customCheck,
+ noFallBack,
+ cache)
+ }
+
+ /**
+ * Run a query with the native engine as well as vanilla Spark, then compare the
result sets for a
+ * correctness check.
+ */
+ protected def compareDfResultsAgainstVanillaSpark(
+ dataframe: () => DataFrame,
+ compareResult: Boolean = true,
+ customCheck: DataFrame => Unit,
+ noFallBack: Boolean = true,
+ cache: Boolean = false): DataFrame = {
+ var expected: Seq[Row] = null
+ withSQLConf(vanillaSparkConfs(): _*) {
+ val df = dataframe()
+ expected = df.collect()
+ }
+ // By default, we will fallback complex type scan but here we should allow
+ // to test support of complex type
+ spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled",
"false");
+ val df = dataframe()
+ if (cache) {
+ df.cache()
+ }
+ try {
+ if (compareResult) {
+ checkAnswer(df, expected)
+ } else {
+ df.collect()
+ }
+ } finally {
+ if (cache) {
+ df.unpersist()
+ }
+ }
+ checkDataFrame(noFallBack, customCheck, df)
+ df
+ }
+
+ protected def checkDataFrame(
+ noFallBack: Boolean,
+ customCheck: DataFrame => Unit,
+ df: DataFrame): Unit = {
+ if (needCheckFallback) {
+ GlutenQueryComparisonTest.checkFallBack(df, noFallBack)
+ }
+ customCheck(df)
+ }
+
+}
+
+object GlutenQueryComparisonTest extends Logging {
+
+ def checkFallBack(
+ df: DataFrame,
+ noFallback: Boolean = true,
+ skipAssert: Boolean = false): Unit = {
+ // When noFallBack is true, it means there is no fallback plan,
+ // otherwise there must be some fallback plans.
+ val hasFallbacks = FallbackUtil.hasFallback(df.queryExecution.executedPlan)
+ if (!skipAssert) {
+ assert(
+ !hasFallbacks == noFallback,
+ s"FallBack $noFallback check error: ${df.queryExecution.executedPlan}")
+ } else {
+ logWarning(s"FallBack $noFallback check error:
${df.queryExecution.executedPlan}")
+ }
+ }
+}
diff --git
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index af25b13555..aa86fecf0f 100644
---
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -17,25 +17,22 @@
package org.apache.gluten.execution
import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.test.FallbackUtil
import org.apache.gluten.utils.Arm
import org.apache.spark.SparkConf
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.DoubleType
import java.io.File
-import java.util.concurrent.atomic.AtomicBoolean
import scala.io.Source
case class Table(name: String, partitionColumns: Seq[String])
abstract class WholeStageTransformerSuite
- extends GlutenQueryTest
+ extends GlutenQueryComparisonTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
@@ -56,13 +53,6 @@ abstract class WholeStageTransformerSuite
protected var TPCHTableDataFrames: Map[String, DataFrame] = _
- private val isFallbackCheckDisabled0 = new AtomicBoolean(false)
-
- final protected def disableFallbackCheck: Boolean =
- isFallbackCheckDisabled0.compareAndSet(false, true)
-
- protected def needCheckFallback: Boolean = !isFallbackCheckDisabled0.get()
-
override def beforeAll(): Unit = {
super.beforeAll()
sparkContext.setLogLevel(logLevel)
@@ -176,88 +166,6 @@ abstract class WholeStageTransformerSuite
result
}
- protected def compareResultsAgainstVanillaSpark(
- sql: String,
- compareResult: Boolean = true,
- customCheck: DataFrame => Unit,
- noFallBack: Boolean = true,
- cache: Boolean = false): DataFrame = {
- compareDfResultsAgainstVanillaSpark(
- () => spark.sql(sql),
- compareResult,
- customCheck,
- noFallBack,
- cache)
- }
-
- /**
- * run a query with native engine as well as vanilla spark then compare the
result set for
- * correctness check
- */
- protected def compareDfResultsAgainstVanillaSpark(
- dataframe: () => DataFrame,
- compareResult: Boolean = true,
- customCheck: DataFrame => Unit,
- noFallBack: Boolean = true,
- cache: Boolean = false): DataFrame = {
- var expected: Seq[Row] = null
- withSQLConf(vanillaSparkConfs(): _*) {
- val df = dataframe()
- expected = df.collect()
- }
- // By default, we will fallback complex type scan but here we should allow
- // to test support of complex type
- spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled",
"false");
- val df = dataframe()
- if (cache) {
- df.cache()
- }
- try {
- if (compareResult) {
- checkAnswer(df, expected)
- } else {
- df.collect()
- }
- } finally {
- if (cache) {
- df.unpersist()
- }
- }
- checkDataFrame(noFallBack, customCheck, df)
- df
- }
-
- /**
- * Some rule on LogicalPlan will not only apply in select query, the total
df.load() should in
- * spark environment with gluten disabled config.
- *
- * @param sql
- * @return
- */
- protected def runAndCompare(sql: String): DataFrame = {
- var expected: Seq[Row] = null
- withSQLConf(vanillaSparkConfs(): _*) {
- expected = spark.sql(sql).collect()
- }
- val df = spark.sql(sql)
- checkAnswer(df, expected)
- df
- }
-
- protected def runQueryAndCompare(
- sqlStr: String,
- compareResult: Boolean = true,
- noFallBack: Boolean = true,
- cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
-
- compareDfResultsAgainstVanillaSpark(
- () => spark.sql(sqlStr),
- compareResult,
- customCheck,
- noFallBack,
- cache)
- }
-
/**
* run a query with native engine as well as vanilla spark then compare the
result set for
* correctness check
@@ -278,19 +186,6 @@ abstract class WholeStageTransformerSuite
customCheck,
noFallBack)
- protected def vanillaSparkConfs(): Seq[(String, String)] = {
- List((GlutenConfig.GLUTEN_ENABLED.key, "false"))
- }
-
- protected def checkDataFrame(
- noFallBack: Boolean,
- customCheck: DataFrame => Unit,
- df: DataFrame): Unit = {
- if (needCheckFallback) {
- WholeStageTransformerSuite.checkFallBack(df, noFallBack)
- }
- customCheck(df)
- }
protected def withDataFrame[R](sql: String)(f: DataFrame => R): R =
f(spark.sql(sql))
protected def tpchSQL(queryNum: Int, tpchQueries: String): String =
@@ -305,23 +200,3 @@ abstract class WholeStageTransformerSuite
}
}
}
-
-object WholeStageTransformerSuite extends Logging {
-
- /** Check whether the sql is fallback */
- def checkFallBack(
- df: DataFrame,
- noFallback: Boolean = true,
- skipAssert: Boolean = false): Unit = {
- // When noFallBack is true, it means there is no fallback plan,
- // otherwise there must be some fallback plans.
- val hasFallbacks = FallbackUtil.hasFallback(df.queryExecution.executedPlan)
- if (!skipAssert) {
- assert(
- !hasFallbacks == noFallback,
- s"FallBack $noFallback check error: ${df.queryExecution.executedPlan}")
- } else {
- logWarning(s"FallBack $noFallback check error:
${df.queryExecution.executedPlan}")
- }
- }
-}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]