This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d103b8  fix: Fix corrupted AggregateMode when transforming plan parameters (#118)
4d103b8 is described below

commit 4d103b88bf9d0165954e04a98f3eb928fdda2291
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Feb 28 00:53:15 2024 -0800

    fix: Fix corrupted AggregateMode when transforming plan parameters (#118)
---
 .../apache/comet/CometSparkSessionExtensions.scala | 20 +++++++----
 .../org/apache/spark/sql/comet/operators.scala     | 40 ++++++++++++++--------
 .../org/apache/comet/exec/CometExecSuite.scala     | 16 ++++++++-
 3 files changed, 53 insertions(+), 23 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f2aba74..10c3328 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -237,7 +237,13 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometProjectExec(nativeOp, op, op.projectList, op.output, 
op.child, None)
+              CometProjectExec(
+                nativeOp,
+                op,
+                op.projectList,
+                op.output,
+                op.child,
+                SerializedPlan(None))
             case None =>
               op
           }
@@ -246,7 +252,7 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometFilterExec(nativeOp, op, op.condition, op.child, None)
+              CometFilterExec(nativeOp, op, op.condition, op.child, 
SerializedPlan(None))
             case None =>
               op
           }
@@ -255,7 +261,7 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometSortExec(nativeOp, op, op.sortOrder, op.child, None)
+              CometSortExec(nativeOp, op, op.sortOrder, op.child, 
SerializedPlan(None))
             case None =>
               op
           }
@@ -264,7 +270,7 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometLocalLimitExec(nativeOp, op, op.limit, op.child, None)
+              CometLocalLimitExec(nativeOp, op, op.limit, op.child, 
SerializedPlan(None))
             case None =>
               op
           }
@@ -273,7 +279,7 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometGlobalLimitExec(nativeOp, op, op.limit, op.child, None)
+              CometGlobalLimitExec(nativeOp, op, op.limit, op.child, 
SerializedPlan(None))
             case None =>
               op
           }
@@ -282,7 +288,7 @@ class CometSparkSessionExtensions
           val newOp = transform1(op)
           newOp match {
             case Some(nativeOp) =>
-              CometExpandExec(nativeOp, op, op.projections, op.child, None)
+              CometExpandExec(nativeOp, op, op.projections, op.child, 
SerializedPlan(None))
             case None =>
               op
           }
@@ -305,7 +311,7 @@ class CometSparkSessionExtensions
                 child.output,
                 if (modes.nonEmpty) Some(modes.head) else None,
                 child,
-                None)
+                SerializedPlan(None))
             case None =>
               op
           }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0298bc6..e75f9a4 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -150,7 +150,7 @@ abstract class CometNativeExec extends CometExec {
    * The serialized native query plan, optional. This is only defined when the 
current node is the
    * "boundary" node between native and Spark.
    */
-  def serializedPlanOpt: Option[Array[Byte]]
+  def serializedPlanOpt: SerializedPlan
 
   /** The Comet native operator */
   def nativeOp: Operator
@@ -200,7 +200,7 @@ abstract class CometNativeExec extends CometExec {
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
-    serializedPlanOpt match {
+    serializedPlanOpt.plan match {
       case None =>
         // This is in the middle of a native execution, it should not be 
executed directly.
         throw new CometRuntimeException(
@@ -282,11 +282,11 @@ abstract class CometNativeExec extends CometExec {
    */
   def convertBlock(): CometNativeExec = {
     def transform(arg: Any): AnyRef = arg match {
-      case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty =>
+      case serializedPlan: SerializedPlan if serializedPlan.isEmpty =>
         val out = new ByteArrayOutputStream()
         nativeOp.writeTo(out)
         out.close()
-        Some(out.toByteArray)
+        SerializedPlan(Some(out.toByteArray))
       case other: AnyRef => other
       case null => null
     }
@@ -300,8 +300,8 @@ abstract class CometNativeExec extends CometExec {
    */
   def cleanBlock(): CometNativeExec = {
     def transform(arg: Any): AnyRef = arg match {
-      case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined =>
-        None
+      case serializedPlan: SerializedPlan if serializedPlan.isDefined =>
+        SerializedPlan(None)
       case other: AnyRef => other
       case null => null
     }
@@ -323,13 +323,23 @@ abstract class CometNativeExec extends CometExec {
 
 abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
 
+/**
+ * Represents the serialized plan of Comet native operators. Only the first 
operator in a block of
+ * continuous Comet native operators has defined plan bytes which contains the 
serialization of
+ * the plan tree of the block.
+ */
+case class SerializedPlan(plan: Option[Array[Byte]]) {
+  def isDefined: Boolean = plan.isDefined
+  def isEmpty: Boolean = plan.isEmpty
+}
+
 case class CometProjectExec(
     override val nativeOp: Operator,
     override val originalPlan: SparkPlan,
     projectList: Seq[NamedExpression],
     override val output: Seq[Attribute],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override def producedAttributes: AttributeSet = outputSet
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -356,7 +366,7 @@ case class CometFilterExec(
     override val originalPlan: SparkPlan,
     condition: Expression,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -390,7 +400,7 @@ case class CometSortExec(
     override val originalPlan: SparkPlan,
     sortOrder: Seq[SortOrder],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -422,7 +432,7 @@ case class CometLocalLimitExec(
     override val originalPlan: SparkPlan,
     limit: Int,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -449,7 +459,7 @@ case class CometGlobalLimitExec(
     override val originalPlan: SparkPlan,
     limit: Int,
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -474,7 +484,7 @@ case class CometExpandExec(
     override val originalPlan: SparkPlan,
     projections: Seq[Seq[Expression]],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override def producedAttributes: AttributeSet = outputSet
 
@@ -538,7 +548,7 @@ case class CometHashAggregateExec(
     input: Seq[Attribute],
     mode: Option[AggregateMode],
     child: SparkPlan,
-    override val serializedPlanOpt: Option[Array[Byte]])
+    override val serializedPlanOpt: SerializedPlan)
     extends CometUnaryExec {
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     this.copy(child = newChild)
@@ -576,7 +586,7 @@ case class CometHashAggregateExec(
 case class CometScanWrapper(override val nativeOp: Operator, override val 
originalPlan: SparkPlan)
     extends CometNativeExec
     with LeafExecNode {
-  override val serializedPlanOpt: Option[Array[Byte]] = None
+  override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
   override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, 
originalPlan)
 }
 
@@ -592,7 +602,7 @@ case class CometSinkPlaceHolder(
     override val originalPlan: SparkPlan,
     child: SparkPlan)
     extends CometUnaryExec {
-  override val serializedPlanOpt: Option[Array[Byte]] = None
+  override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan 
= {
     this.copy(child = newChild)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 29b6e12..05be34c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column, 
CometTestBase, DataFrame
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, 
CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, 
CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, 
CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr}
+import org.apache.spark.sql.functions.{date_add, expr, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
@@ -57,6 +58,19 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("Fix corrupted AggregateMode when transforming plan parameters") {
+    withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
+      val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
+      val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
+        case s: CometHashAggregateExec => s
+      }.get
+
+      assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
+      val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
+      assert(newAgg.mode.isDefined && 
newAgg.mode.get.isInstanceOf[AggregateMode])
+    }
+  }
+
   test("CometBroadcastExchangeExec") {
     withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
       withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {

Reply via email to