This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 72db36e59e [Minor] Refactor test utility to let users compare the 
query result (#10565)
72db36e59e is described below

commit 72db36e59e6ea554507d8e543425a93d06e9dede
Author: Jin Chengcheng <[email protected]>
AuthorDate: Tue Sep 2 11:44:33 2025 +0100

    [Minor] Refactor test utility to let users compare the query result (#10565)
---
 .../GlutenClickHouseExcelFormatSuite.scala         |   6 +-
 .../GlutenClickHouseSyntheticDataSuite.scala       |   6 +-
 .../GlutenClickHouseTPCDSAbstractSuite.scala       |   2 +-
 .../GlutenClickhouseCountDistinctSuite.scala       |   4 +-
 .../extension/GlutenCustomAggExpressionSuite.scala |   2 +-
 .../execution/GlutenQueryComparisonTest.scala      | 155 +++++++++++++++++++++
 .../execution/WholeStageTransformerSuite.scala     | 129 +----------------
 7 files changed, 167 insertions(+), 137 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
index 4606b36d18..91a1419b2d 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseExcelFormatSuite.scala
@@ -79,7 +79,7 @@ class GlutenClickHouseExcelFormatSuite extends 
GlutenClickHouseWholeStageTransfo
            |""".stripMargin
       val df2 = spark.sql(sql)
       df2.collect()
-      WholeStageTransformerSuite.checkFallBack(df2)
+      GlutenQueryComparisonTest.checkFallBack(df2)
       checkAnswer(df2, df1)
     }
   }
@@ -103,7 +103,7 @@ class GlutenClickHouseExcelFormatSuite extends 
GlutenClickHouseWholeStageTransfo
            |""".stripMargin
       val df2 = spark.sql(sql)
       df2.collect()
-      WholeStageTransformerSuite.checkFallBack(df2)
+      GlutenQueryComparisonTest.checkFallBack(df2)
       checkAnswer(df2, df1)
     }
   }
@@ -136,7 +136,7 @@ class GlutenClickHouseExcelFormatSuite extends 
GlutenClickHouseWholeStageTransfo
            |""".stripMargin
       val df2 = spark.sql(sql)
       df2.collect()
-      WholeStageTransformerSuite.checkFallBack(df2)
+      GlutenQueryComparisonTest.checkFallBack(df2)
       checkAnswer(df2, df1)
     }
   }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
index b68409324b..7f3aa5eb1c 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseSyntheticDataSuite.scala
@@ -163,7 +163,7 @@ class GlutenClickHouseSyntheticDataSuite
     }
     val df = spark.sql(sqlStr)
     df.collect()
-    WholeStageTransformerSuite.checkFallBack(df)
+    GlutenQueryComparisonTest.checkFallBack(df)
     checkAnswer(df, expected)
   }
 
@@ -195,7 +195,7 @@ class GlutenClickHouseSyntheticDataSuite
     }
     val df = spark.sql(sqlStr)
     df.collect()
-    WholeStageTransformerSuite.checkFallBack(df)
+    GlutenQueryComparisonTest.checkFallBack(df)
     checkAnswer(df, expected)
   }
 
@@ -225,7 +225,7 @@ class GlutenClickHouseSyntheticDataSuite
     }
     val df = spark.sql(sqlStr)
     df.collect()
-    WholeStageTransformerSuite.checkFallBack(df)
+    GlutenQueryComparisonTest.checkFallBack(df)
     checkAnswer(df, expected)
   }
 
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
index 08b84b7428..5f70d35c56 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala
@@ -205,7 +205,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
       log.warn(s"query: $queryNum skipped comparing, time cost to collect: 
${System
           .currentTimeMillis() - start} ms, ret size: ${ret.length}")
     }
-    WholeStageTransformerSuite.checkFallBack(df, noFallBack)
+    GlutenQueryComparisonTest.checkFallBack(df, noFallBack)
     customCheck(df)
   }
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
index bad39ef09b..22a82a9439 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala
@@ -56,7 +56,7 @@ class GlutenClickhouseCountDistinctSuite extends 
GlutenClickHouseWholeStageTrans
         "values (0, null,1), (1, 1,1), (2, 2,1), (1, 2,1) ,(2,2,2) as 
data(a,b,c) group by c"
 
     val df = spark.sql(sql)
-    WholeStageTransformerSuite.checkFallBack(df)
+    GlutenQueryComparisonTest.checkFallBack(df)
 
     val planExecs = df.queryExecution.executedPlan.collect {
       case aggTransformer: HashAggregateExecBaseTransformer => aggTransformer
@@ -115,7 +115,7 @@ class GlutenClickhouseCountDistinctSuite extends 
GlutenClickHouseWholeStageTrans
       values (0, null,1), (0,null,2), (1, 1,4) as data(a,b,c) group by 
try_add(c,b)
       """
     val df = spark.sql(sql)
-    WholeStageTransformerSuite.checkFallBack(df, noFallback = 
isSparkVersionGE("3.5"))
+    GlutenQueryComparisonTest.checkFallBack(df, noFallback = 
isSparkVersionGE("3.5"))
   }
 
   test("check count distinct with filter") {
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
index 900ab11ea2..f0d371f1aa 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/extension/GlutenCustomAggExpressionSuite.scala
@@ -70,7 +70,7 @@ class GlutenCustomAggExpressionSuite extends ParquetSuite {
          |""".stripMargin
     val df = spark.sql(sql)
     // Final stage is not supported, it will be fallback
-    WholeStageTransformerSuite.checkFallBack(df, noFallback = false)
+    GlutenQueryComparisonTest.checkFallBack(df, noFallback = false)
 
     val planExecs = df.queryExecution.executedPlan.collect {
       case agg: HashAggregateExec => agg
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
new file mode 100644
index 0000000000..8a22b706db
--- /dev/null
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/GlutenQueryComparisonTest.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.test.FallbackUtil
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+/**
+ * Test utility that lets developers easily compare a query's result with Gluten enabled against
+ * the result produced by vanilla Spark (Gluten disabled), and check the fallback status of the
+ * executed plan.
+ */
+abstract class GlutenQueryComparisonTest extends GlutenQueryTest {
+
+  /** Flipped to true (once) by `disableFallbackCheck`; read by `needCheckFallback`. */
+  private val isFallbackCheckDisabled0 = new AtomicBoolean(false)
+
+  /**
+   * Disables the fallback assertion performed by `checkDataFrame` for this suite instance.
+   *
+   * @return true if this call disabled the check, false if it had already been disabled
+   */
+  final protected def disableFallbackCheck: Boolean =
+    isFallbackCheckDisabled0.compareAndSet(false, true)
+
+  /** Whether `checkDataFrame` should assert on the fallback status of the executed plan. */
+  protected def needCheckFallback: Boolean = !isFallbackCheckDisabled0.get()
+
+  /** SQL configs applied while computing the vanilla Spark reference result (Gluten disabled). */
+  protected def vanillaSparkConfs(): Seq[(String, String)] = {
+    List((GlutenConfig.GLUTEN_ENABLED.key, "false"))
+  }
+
+  /**
+   * Runs `sql` once under `vanillaSparkConfs()` and once with the current session settings, and
+   * asserts both runs return the same rows. Some LogicalPlan rules apply beyond the select query
+   * itself, so the whole query (e.g. including `df.load()`) is evaluated in the Gluten-disabled
+   * environment for the reference result.
+   *
+   * @param sql the query to run
+   * @return the DataFrame produced with the current session settings
+   */
+  protected def runAndCompare(sql: String): DataFrame = {
+    var expected: Seq[Row] = null
+    withSQLConf(vanillaSparkConfs(): _*) {
+      expected = spark.sql(sql).collect()
+    }
+    val df = spark.sql(sql)
+    checkAnswer(df, expected)
+    df
+  }
+
+  /**
+   * Curried convenience wrapper around `compareDfResultsAgainstVanillaSpark` for a SQL string.
+   *
+   * @param sqlStr the query to run
+   * @param compareResult when false, the Gluten result is only collected, not compared
+   * @param noFallBack expected fallback status passed to the fallback check
+   * @param cache whether to cache the Gluten DataFrame while it is evaluated
+   * @param customCheck additional assertions run on the Gluten DataFrame
+   * @return the DataFrame evaluated with Gluten enabled
+   */
+  protected def runQueryAndCompare(
+      sqlStr: String,
+      compareResult: Boolean = true,
+      noFallBack: Boolean = true,
+      cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
+
+    compareDfResultsAgainstVanillaSpark(
+      () => spark.sql(sqlStr),
+      compareResult,
+      customCheck,
+      noFallBack,
+      cache)
+  }
+
+  /**
+   * Same as `runQueryAndCompare`, but takes `customCheck` as a regular parameter.
+   *
+   * @return the DataFrame evaluated with Gluten enabled
+   */
+  protected def compareResultsAgainstVanillaSpark(
+      sql: String,
+      compareResult: Boolean = true,
+      customCheck: DataFrame => Unit,
+      noFallBack: Boolean = true,
+      cache: Boolean = false): DataFrame = {
+    compareDfResultsAgainstVanillaSpark(
+      () => spark.sql(sql),
+      compareResult,
+      customCheck,
+      noFallBack,
+      cache)
+  }
+
+  /**
+   * Runs a query with the native engine as well as vanilla Spark, then compares the result sets
+   * for a correctness check.
+   *
+   * @param dataframe builds the DataFrame; invoked once per engine so each run gets a fresh plan
+   * @param compareResult when false, the Gluten result is only collected, not compared
+   * @param customCheck additional assertions run on the Gluten DataFrame
+   * @param noFallBack expected fallback status passed to the fallback check
+   * @param cache whether to cache the Gluten DataFrame while it is evaluated (unpersisted after)
+   * @return the DataFrame evaluated with Gluten enabled
+   */
+  protected def compareDfResultsAgainstVanillaSpark(
+      dataframe: () => DataFrame,
+      compareResult: Boolean = true,
+      customCheck: DataFrame => Unit,
+      noFallBack: Boolean = true,
+      cache: Boolean = false): DataFrame = {
+    var expected: Seq[Row] = null
+    withSQLConf(vanillaSparkConfs(): _*) {
+      expected = dataframe().collect()
+    }
+    // By default, complex type scans fall back; disable that here so complex type support is
+    // actually exercised. The override is scoped to this comparison (and restored afterwards)
+    // rather than set on the shared session, so it cannot leak into unrelated tests.
+    var df: DataFrame = null
+    withSQLConf("spark.gluten.sql.complexType.scan.fallback.enabled" -> "false") {
+      df = dataframe()
+      if (cache) {
+        df.cache()
+      }
+      try {
+        if (compareResult) {
+          checkAnswer(df, expected)
+        } else {
+          df.collect()
+        }
+      } finally {
+        if (cache) {
+          df.unpersist()
+        }
+      }
+      checkDataFrame(noFallBack, customCheck, df)
+    }
+    df
+  }
+
+  /**
+   * Asserts the fallback status of `df` (unless disabled via `disableFallbackCheck`), then runs
+   * `customCheck` on it.
+   */
+  protected def checkDataFrame(
+      noFallBack: Boolean,
+      customCheck: DataFrame => Unit,
+      df: DataFrame): Unit = {
+    if (needCheckFallback) {
+      GlutenQueryComparisonTest.checkFallBack(df, noFallBack)
+    }
+    customCheck(df)
+  }
+
+}
+
+object GlutenQueryComparisonTest extends Logging {
+
+  /**
+   * Checks whether the executed plan of `df` contains fallback operators, as reported by
+   * `FallbackUtil.hasFallback`.
+   *
+   * @param df the DataFrame whose executed plan is inspected
+   * @param noFallback true to expect no fallback at all; false to expect at least one fallback
+   * @param skipAssert when true, a mismatch is logged as a warning instead of failing
+   */
+  def checkFallBack(
+      df: DataFrame,
+      noFallback: Boolean = true,
+      skipAssert: Boolean = false): Unit = {
+    val plan = df.queryExecution.executedPlan
+    val hasFallbacks = FallbackUtil.hasFallback(plan)
+    // When noFallback is true there must be no fallback plan; otherwise there must be some.
+    val matchesExpectation = !hasFallbacks == noFallback
+    if (!skipAssert) {
+      assert(matchesExpectation, s"FallBack $noFallback check error: $plan")
+    } else if (!matchesExpectation) {
+      // Only warn on an actual mismatch; a passing check stays silent.
+      logWarning(s"FallBack $noFallback check error: $plan")
+    }
+  }
+}
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
index af25b13555..aa86fecf0f 100644
--- 
a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala
@@ -17,25 +17,22 @@
 package org.apache.gluten.execution
 
 import org.apache.gluten.config.GlutenConfig
-import org.apache.gluten.test.FallbackUtil
 import org.apache.gluten.utils.Arm
 
 import org.apache.spark.SparkConf
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, GlutenQueryTest, Row}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.DoubleType
 
 import java.io.File
-import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.io.Source
 
 case class Table(name: String, partitionColumns: Seq[String])
 
 abstract class WholeStageTransformerSuite
-  extends GlutenQueryTest
+  extends GlutenQueryComparisonTest
   with SharedSparkSession
   with AdaptiveSparkPlanHelper {
 
@@ -56,13 +53,6 @@ abstract class WholeStageTransformerSuite
 
   protected var TPCHTableDataFrames: Map[String, DataFrame] = _
 
-  private val isFallbackCheckDisabled0 = new AtomicBoolean(false)
-
-  final protected def disableFallbackCheck: Boolean =
-    isFallbackCheckDisabled0.compareAndSet(false, true)
-
-  protected def needCheckFallback: Boolean = !isFallbackCheckDisabled0.get()
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     sparkContext.setLogLevel(logLevel)
@@ -176,88 +166,6 @@ abstract class WholeStageTransformerSuite
       result
   }
 
-  protected def compareResultsAgainstVanillaSpark(
-      sql: String,
-      compareResult: Boolean = true,
-      customCheck: DataFrame => Unit,
-      noFallBack: Boolean = true,
-      cache: Boolean = false): DataFrame = {
-    compareDfResultsAgainstVanillaSpark(
-      () => spark.sql(sql),
-      compareResult,
-      customCheck,
-      noFallBack,
-      cache)
-  }
-
-  /**
-   * run a query with native engine as well as vanilla spark then compare the 
result set for
-   * correctness check
-   */
-  protected def compareDfResultsAgainstVanillaSpark(
-      dataframe: () => DataFrame,
-      compareResult: Boolean = true,
-      customCheck: DataFrame => Unit,
-      noFallBack: Boolean = true,
-      cache: Boolean = false): DataFrame = {
-    var expected: Seq[Row] = null
-    withSQLConf(vanillaSparkConfs(): _*) {
-      val df = dataframe()
-      expected = df.collect()
-    }
-    // By default, we will fallback complex type scan but here we should allow
-    // to test support of complex type
-    spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", 
"false");
-    val df = dataframe()
-    if (cache) {
-      df.cache()
-    }
-    try {
-      if (compareResult) {
-        checkAnswer(df, expected)
-      } else {
-        df.collect()
-      }
-    } finally {
-      if (cache) {
-        df.unpersist()
-      }
-    }
-    checkDataFrame(noFallBack, customCheck, df)
-    df
-  }
-
-  /**
-   * Some rule on LogicalPlan will not only apply in select query, the total 
df.load() should in
-   * spark environment with gluten disabled config.
-   *
-   * @param sql
-   * @return
-   */
-  protected def runAndCompare(sql: String): DataFrame = {
-    var expected: Seq[Row] = null
-    withSQLConf(vanillaSparkConfs(): _*) {
-      expected = spark.sql(sql).collect()
-    }
-    val df = spark.sql(sql)
-    checkAnswer(df, expected)
-    df
-  }
-
-  protected def runQueryAndCompare(
-      sqlStr: String,
-      compareResult: Boolean = true,
-      noFallBack: Boolean = true,
-      cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = {
-
-    compareDfResultsAgainstVanillaSpark(
-      () => spark.sql(sqlStr),
-      compareResult,
-      customCheck,
-      noFallBack,
-      cache)
-  }
-
   /**
    * run a query with native engine as well as vanilla spark then compare the 
result set for
    * correctness check
@@ -278,19 +186,6 @@ abstract class WholeStageTransformerSuite
       customCheck,
       noFallBack)
 
-  protected def vanillaSparkConfs(): Seq[(String, String)] = {
-    List((GlutenConfig.GLUTEN_ENABLED.key, "false"))
-  }
-
-  protected def checkDataFrame(
-      noFallBack: Boolean,
-      customCheck: DataFrame => Unit,
-      df: DataFrame): Unit = {
-    if (needCheckFallback) {
-      WholeStageTransformerSuite.checkFallBack(df, noFallBack)
-    }
-    customCheck(df)
-  }
   protected def withDataFrame[R](sql: String)(f: DataFrame => R): R = 
f(spark.sql(sql))
 
   protected def tpchSQL(queryNum: Int, tpchQueries: String): String =
@@ -305,23 +200,3 @@ abstract class WholeStageTransformerSuite
     }
   }
 }
-
-object WholeStageTransformerSuite extends Logging {
-
-  /** Check whether the sql is fallback */
-  def checkFallBack(
-      df: DataFrame,
-      noFallback: Boolean = true,
-      skipAssert: Boolean = false): Unit = {
-    // When noFallBack is true, it means there is no fallback plan,
-    // otherwise there must be some fallback plans.
-    val hasFallbacks = FallbackUtil.hasFallback(df.queryExecution.executedPlan)
-    if (!skipAssert) {
-      assert(
-        !hasFallbacks == noFallback,
-        s"FallBack $noFallback check error: ${df.queryExecution.executedPlan}")
-    } else {
-      logWarning(s"FallBack $noFallback check error: 
${df.queryExecution.executedPlan}")
-    }
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to