This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 81a641f fix: CometExecRule should handle ShuffleQueryStage and
ReusedExchange (#186)
81a641f is described below
commit 81a641f30844d76b417a600392a0cc97a74a919b
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Mar 13 11:14:45 2024 -0700
fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange (#186)
* fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange
* fix
* Add comment and move tests
* Remove unused table in test.
---
.../apache/comet/CometSparkSessionExtensions.scala | 44 ++++++++++++++++----
.../org/apache/comet/serde/QueryPlanSerde.scala | 6 ++-
.../comet/exec/CometColumnarShuffleSuite.scala | 48 +++++++++++++++++++++-
.../comet/exec/CometNativeShuffleSuite.scala | 5 ++-
4 files changed, 92 insertions(+), 11 deletions(-)
diff --git
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 5720b69..39c83ae 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -31,12 +31,13 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -221,12 +222,16 @@ class CometSparkSessionExtensions
*/
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
- def transform1(op: UnaryExecNode): Option[Operator] = {
- op.child match {
- case childNativeOp: CometNativeExec =>
- QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp)
- case _ =>
- None
+ def transform1(op: SparkPlan): Option[Operator] = {
+ val allNativeExec = op.children.map {
+ case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
+ case _ => None
+ }
+
+ if (allNativeExec.forall(_.isDefined)) {
+ QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
+ } else {
+ None
}
}
@@ -377,6 +382,31 @@ class CometSparkSessionExtensions
case None => b
}
+ // For AQE shuffle stage on a Comet shuffle exchange
+ case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
+ val newOp = transform1(s)
+ newOp match {
+ case Some(nativeOp) =>
+ CometSinkPlaceHolder(nativeOp, s, s)
+ case None =>
+ s
+ }
+
+ // For AQE shuffle stage on a reused Comet shuffle exchange
+ // Note that we don't need to handle `ReusedExchangeExec` for non-AQE
case, because
+ // the query plan won't be re-optimized/planned in non-AQE mode.
+ case s @ ShuffleQueryStageExec(
+ _,
+ ReusedExchangeExec(_, _: CometShuffleExchangeExec),
+ _) =>
+ val newOp = transform1(s)
+ newOp match {
+ case Some(nativeOp) =>
+ CometSinkPlaceHolder(nativeOp, s, s)
+ case None =>
+ s
+ }
+
// Native shuffle for Comet operators
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) &&
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 902f703..5da926e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,10 +29,12 @@ import
org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ShuffleExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde {
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
+ case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
+ case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _:
CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case _: BroadcastExchangeExec => true
case _ => false
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 1a92f71..216b690 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkConf}
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency,
CometShuffleExchangeExec, CometShuffleManager}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper,
AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions.col
@@ -933,6 +933,52 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
override protected val asyncShuffleEnable: Boolean = false
protected val adaptiveExecutionEnabled: Boolean = true
+
+ import testImplicits._
+
+ test("Comet native operator after ShuffleQueryStage") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+ val df = sql("SELECT * FROM tbl_a")
+ val shuffled = df
+ .select($"_1" + 1 as ("a"))
+ .filter($"a" > 4)
+ .repartition(10)
+ .sortWithinPartitions($"a")
+ checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
+ }
+ }
+ }
+
+ test("Comet native operator after ShuffleQueryStage + ReusedExchange") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+ CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+ withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
+ withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
+ val df = sql("SELECT * FROM tbl_a")
+ val left = df
+ .select($"_1" + 1 as ("a"))
+ .filter($"a" > 4)
+ val right = left.select($"a" as ("b"))
+ val join = left.join(right, $"a" === $"b")
+ checkSparkAnswerAndOperator(
+ join,
+ classOf[ShuffleQueryStageExec],
+ classOf[SortMergeJoinExec],
+ classOf[AQEShuffleReadExec])
+ }
+ }
+ }
+ }
}
class DisableAQECometShuffleSuite extends CometColumnarShuffleSuite {
diff --git
a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index c35763c..59e27fd 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -64,8 +64,9 @@ class CometNativeShuffleSuite extends CometTestBase with
AdaptiveSparkPlanHelper
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled,
1000)
var allTypes: Seq[Int] = (1 to 20)
- if (isSpark34Plus) {
- allTypes = allTypes.filterNot(Set(14, 17).contains)
+ if (!isSpark34Plus) {
+ // TODO: Remove this once after
https://github.com/apache/arrow/issues/40038 is fixed
+ allTypes = allTypes.filterNot(Set(14).contains)
}
allTypes.map(i => s"_$i").foreach { c =>
withSQLConf(