This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ce63ff8  test: Enable Comet shuffle in Spark SQL tests (#210)
ce63ff8 is described below

commit ce63ff8ba00a4717c8c9244385b15839ee934ce1
Author: Chao Sun <[email protected]>
AuthorDate: Tue Mar 26 10:08:52 2024 -0700

    test: Enable Comet shuffle in Spark SQL tests (#210)
    
    * test: Enable Comet shuffle in Spark SQL tests
    
    * disable some tests
    
    * disable another test
    
    * update
    
    * update
---
 dev/diffs/3.4.2.diff | 202 +++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 188 insertions(+), 14 deletions(-)

diff --git a/dev/diffs/3.4.2.diff b/dev/diffs/3.4.2.diff
index 94e7139..98c5f42 100644
--- a/dev/diffs/3.4.2.diff
+++ b/dev/diffs/3.4.2.diff
@@ -234,6 +234,20 @@ index 56e9520fdab..917932336df 100644
            spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
            spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
  
+diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+index 9ddb4abe98b..1bebe99f1cc 100644
+--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
++++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+@@ -3311,7 +3311,8 @@ class DataFrameSuite extends QueryTest
+     assert(df2.isLocal)
+   }
+ 
+-  test("SPARK-35886: PromotePrecision should be subexpr replaced") {
++  test("SPARK-35886: PromotePrecision should be subexpr replaced",
++    IgnoreComet("TODO: fix Comet for this test")) {
+     withTable("tbl") {
+       sql(
+         """
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
 index f33432ddb6f..fe9f74ff8f1 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -365,7 +379,7 @@ index 00000000000..4b31bea33de
 +  }
 +}
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
-index 5125708be32..e274a497996 100644
+index 5125708be32..a1f1ae90796 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
 @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
@@ -376,18 +390,30 @@ index 5125708be32..e274a497996 100644
  import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, 
ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
  import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
  import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, 
ShuffleExchangeLike}
-@@ -1371,7 +1372,7 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1369,9 +1370,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+           }
+           val plan = sql(getJoinQuery(selectExpr, 
joinType)).queryExecution.executedPlan
            assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
-           assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 
3)
+-          assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 
3)
++          assert(collect(plan) {
++            case _: SortMergeJoinExec => true
++            case _: CometSortMergeJoinExec => true
++          }.size === 3)
            // No extra sort on left side before last sort merge join
 -          assert(collect(plan) { case _: SortExec => true }.size === 5)
 +          assert(collect(plan) { case _: SortExec | _: CometSortExec => true 
}.size === 5)
        }
  
        // Test output ordering is not preserved
-@@ -1382,7 +1383,7 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+@@ -1380,9 +1384,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
+           val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
+           val plan = sql(getJoinQuery(selectExpr, 
joinType)).queryExecution.executedPlan
            assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true 
}.size === 1)
-           assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 
3)
+-          assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 
3)
++          assert(collect(plan) {
++            case _: SortMergeJoinExec => true
++            case _: CometSortMergeJoinExec => true
++          }.size === 3)
            // Have sort on left side before last sort merge join
 -          assert(collect(plan) { case _: SortExec => true }.size === 6)
 +          assert(collect(plan) { case _: SortExec | _: CometSortExec => true 
}.size === 6)
@@ -408,18 +434,19 @@ index b5b34922694..a72403780c4 100644
    protected val baseResourcePath = {
      // use the same way as `SQLQueryTestSuite` to get the resource path
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-index 3cfda19134a..afcfba37c6f 100644
+index 3cfda19134a..278bb1060c4 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
-@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
+@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer
  
  import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
  import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, 
LogicalPlan, Project, Sort, Union}
 +import org.apache.spark.sql.comet.CometScanExec
++import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
  import org.apache.spark.sql.execution._
  import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, 
DisableAdaptiveExecution}
  import org.apache.spark.sql.execution.datasources.FileScanRDD
-@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest
+@@ -1543,6 +1545,12 @@ class SubquerySuite extends QueryTest
              fs.inputRDDs().forall(
                _.asInstanceOf[FileScanRDD].filePartitions.forall(
                  _.files.forall(_.urlEncodedPath.contains("p=0"))))
@@ -432,6 +459,14 @@ index 3cfda19134a..afcfba37c6f 100644
          case _ => false
        })
      }
+@@ -2109,6 +2117,7 @@ class SubquerySuite extends QueryTest
+       df.collect()
+       val exchanges = collect(df.queryExecution.executedPlan) {
+         case s: ShuffleExchangeExec => s
++        case s: CometShuffleExchangeExec => s
+       }
+       assert(exchanges.size === 1)
+     }
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
 index cfc8b2cc845..c6fcfd7bd08 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -541,6 +576,28 @@ index ac710c32296..37746bd470d 100644
            val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as 
newC$i")
            val df = spark.read.parquet(path).selectExpr(projection: _*)
  
+diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+index 593bd7bb4ba..be1b82d0030 100644
+--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
++++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+@@ -29,6 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, 
SparkListenerEvent, SparkListe
+ import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
+ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
++import org.apache.spark.sql.comet._
+ import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, 
PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, 
ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
+ import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+ import org.apache.spark.sql.execution.command.DataWritingCommandExec
+@@ -116,6 +117,9 @@ class AdaptiveQueryExecSuite
+   private def findTopLevelSortMergeJoin(plan: SparkPlan): 
Seq[SortMergeJoinExec] = {
+     collect(plan) {
+       case j: SortMergeJoinExec => j
++      case j: CometSortMergeJoinExec =>
++        assert(j.originalPlan.isInstanceOf[SortMergeJoinExec])
++        j.originalPlan.asInstanceOf[SortMergeJoinExec]
+     }
+   }
+ 
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
 index bd9c79e5b96..ab7584e768e 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -935,7 +992,7 @@ index d083cac48ff..3c11bcde807 100644
    import testImplicits._
  
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
-index 266bb343526..f393606997c 100644
+index 266bb343526..b33bb677f0d 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
 @@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -1051,7 +1108,16 @@ index 266bb343526..f393606997c 100644
        checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
      }
    }
-@@ -1031,10 +1060,16 @@ abstract class BucketedReadSuite extends QueryTest 
with SQLTestUtils with Adapti
+@@ -1026,15 +1055,24 @@ abstract class BucketedReadSuite extends QueryTest 
with SQLTestUtils with Adapti
+             expectedNumShuffles: Int,
+             expectedCoalescedNumBuckets: Option[Int]): Unit = {
+           val plan = sql(query).queryExecution.executedPlan
+-          val shuffles = plan.collect { case s: ShuffleExchangeExec => s }
++          val shuffles = plan.collect {
++            case s: ShuffleExchangeExec => s
++            case s: CometShuffleExchangeExec => s
++          }
+           assert(shuffles.length == expectedNumShuffles)
  
            val scans = plan.collect {
              case f: FileSourceScanExec if 
f.optionalNumCoalescedBuckets.isDefined => f
@@ -1139,6 +1205,94 @@ index 75f440caefc..36b1146bc3a 100644
        }.headOption.getOrElse {
          fail(s"No FileScan in query\n${df.queryExecution}")
        }
+diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
+index b597a244710..b2e8be41065 100644
+--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
++++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
+@@ -21,6 +21,7 @@ import java.io.File
+ 
+ import org.apache.commons.io.FileUtils
+ 
++import org.apache.spark.sql.IgnoreComet
+ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
+ import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, 
MemoryStream}
+ import org.apache.spark.sql.internal.SQLConf
+@@ -91,7 +92,7 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
+   }
+ 
+   test("SPARK-38204: flatMapGroupsWithState should require 
StatefulOpClusteredDistribution " +
+-    "from children - without initial state") {
++    "from children - without initial state", IgnoreComet("TODO: fix Comet for 
this test")) {
+     // function will return -1 on timeout and returns count of the state 
otherwise
+     val stateFunc =
+       (key: (String, String), values: Iterator[(String, String, Long)],
+@@ -243,7 +244,8 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
+   }
+ 
+   test("SPARK-38204: flatMapGroupsWithState should require 
ClusteredDistribution " +
+-    "from children if the query starts from checkpoint in 3.2.x - without 
initial state") {
++    "from children if the query starts from checkpoint in 3.2.x - without 
initial state",
++    IgnoreComet("TODO: fix Comet for this test")) {
+     // function will return -1 on timeout and returns count of the state 
otherwise
+     val stateFunc =
+       (key: (String, String), values: Iterator[(String, String, Long)],
+@@ -335,7 +337,8 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
+   }
+ 
+   test("SPARK-38204: flatMapGroupsWithState should require 
ClusteredDistribution " +
+-    "from children if the query starts from checkpoint in prior to 3.2") {
++    "from children if the query starts from checkpoint in prior to 3.2",
++    IgnoreComet("TODO: fix Comet for this test")) {
+     // function will return -1 on timeout and returns count of the state 
otherwise
+     val stateFunc =
+       (key: (String, String), values: Iterator[(String, String, Long)],
+diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+index 6aa7d0945c7..38523536154 100644
+--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
++++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
+ 
+ import org.apache.spark.SparkException
+ import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
+-import org.apache.spark.sql.{DataFrame, Encoder}
++import org.apache.spark.sql.{DataFrame, Encoder, IgnoreCometSuite}
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
+ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
+@@ -46,8 +46,9 @@ case class RunningCount(count: Long)
+ 
+ case class Result(key: Long, count: Int)
+ 
++// TODO: fix Comet to enable this suite
+ @SlowSQLTest
+-class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
++class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with 
IgnoreCometSuite {
+ 
+   import testImplicits._
+ 
+diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
+index 2a2a83d35e1..e3b7b290b3e 100644
+--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
++++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
+@@ -18,7 +18,7 @@
+ package org.apache.spark.sql.streaming
+ 
+ import org.apache.spark.SparkException
+-import org.apache.spark.sql.{AnalysisException, Dataset, 
KeyValueGroupedDataset}
++import org.apache.spark.sql.{AnalysisException, Dataset, IgnoreComet, 
KeyValueGroupedDataset}
+ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
+ import org.apache.spark.sql.execution.streaming.MemoryStream
+ import 
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+@@ -253,7 +253,8 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends 
StateStoreMetricsTest
+     assert(e.message.contains(expectedError))
+   }
+ 
+-  test("flatMapGroupsWithState - initial state - initial state has 
flatMapGroupsWithState") {
++  test("flatMapGroupsWithState - initial state - initial state has 
flatMapGroupsWithState",
++    IgnoreComet("TODO: fix Comet for this test")) {
+     val initialStateDS = Seq(("keyInStateAndData", new 
RunningCount(1))).toDS()
+     val initialState: KeyValueGroupedDataset[String, RunningCount] =
+       initialStateDS.groupByKey(_._1).mapValues(_._2)
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
 index abe606ad9c1..2d930b64cca 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -1221,10 +1375,10 @@ index dd55fcfe42c..cc18147d17a 100644
  
      spark.internalCreateDataFrame(withoutFilters.execute(), schema)
 diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
-index ed2e309fa07..3767d4e7ca4 100644
+index ed2e309fa07..4cfe0093da7 100644
 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
 +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
-@@ -74,6 +74,18 @@ trait SharedSparkSessionBase
+@@ -74,6 +74,21 @@ trait SharedSparkSessionBase
        // this rule may potentially block testing of other optimization rules 
such as
        // ConstantPropagation etc.
        .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, 
ConvertToLocalRelation.ruleName)
@@ -1238,6 +1392,9 @@ index ed2e309fa07..3767d4e7ca4 100644
 +        conf
 +          .set("spark.comet.exec.enabled", "true")
 +          .set("spark.comet.exec.all.enabled", "true")
++          .set("spark.shuffle.manager",
++            
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
++          .set("spark.comet.exec.shuffle.enabled", "true")
 +      }
 +    }
      conf.set(
@@ -1265,11 +1422,25 @@ index 52abd248f3a..7a199931a08 100644
        case h: HiveTableScanExec => h.partitionPruningPred.collect {
          case d: DynamicPruningExpression => d.child
        }
+diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+index 1966e1e64fd..cde97a0aafe 100644
+--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
++++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+@@ -656,7 +656,8 @@ abstract class AggregationQuerySuite extends QueryTest 
with SQLTestUtils with Te
+         Row(3, 4, 4, 3, null) :: Nil)
+   }
+ 
+-  test("single distinct multiple columns set") {
++  test("single distinct multiple columns set",
++    IgnoreComet("TODO: fix Comet for this test")) {
+     checkAnswer(
+       spark.sql(
+         """
 diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-index 07361cfdce9..545b3184c23 100644
+index 07361cfdce9..c5d94c92e32 100644
 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
 +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
-@@ -55,25 +55,43 @@ object TestHive
+@@ -55,25 +55,46 @@ object TestHive
      new SparkContext(
        System.getProperty("spark.sql.test.master", "local[1]"),
        "TestSQLContext",
@@ -1322,6 +1493,9 @@ index 07361cfdce9..545b3184c23 100644
 +            conf
 +              .set("spark.comet.exec.enabled", "true")
 +              .set("spark.comet.exec.all.enabled", "true")
++              .set("spark.shuffle.manager",
++                
"org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
++              .set("spark.comet.exec.shuffle.enabled", "true")
 +          }
 +        }
  

Reply via email to