This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4d103b8 fix: Fix corrupted AggregateMode when transforming plan
parameters (#118)
4d103b8 is described below
commit 4d103b88bf9d0165954e04a98f3eb928fdda2291
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Feb 28 00:53:15 2024 -0800
fix: Fix corrupted AggregateMode when transforming plan parameters (#118)
---
.../apache/comet/CometSparkSessionExtensions.scala | 20 +++++++----
.../org/apache/spark/sql/comet/operators.scala | 40 ++++++++++++++--------
.../org/apache/comet/exec/CometExecSuite.scala | 16 ++++++++-
3 files changed, 53 insertions(+), 23 deletions(-)
diff --git
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index f2aba74..10c3328 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -237,7 +237,13 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometProjectExec(nativeOp, op, op.projectList, op.output,
op.child, None)
+ CometProjectExec(
+ nativeOp,
+ op,
+ op.projectList,
+ op.output,
+ op.child,
+ SerializedPlan(None))
case None =>
op
}
@@ -246,7 +252,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometFilterExec(nativeOp, op, op.condition, op.child, None)
+ CometFilterExec(nativeOp, op, op.condition, op.child,
SerializedPlan(None))
case None =>
op
}
@@ -255,7 +261,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometSortExec(nativeOp, op, op.sortOrder, op.child, None)
+ CometSortExec(nativeOp, op, op.sortOrder, op.child,
SerializedPlan(None))
case None =>
op
}
@@ -264,7 +270,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometLocalLimitExec(nativeOp, op, op.limit, op.child, None)
+ CometLocalLimitExec(nativeOp, op, op.limit, op.child,
SerializedPlan(None))
case None =>
op
}
@@ -273,7 +279,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometGlobalLimitExec(nativeOp, op, op.limit, op.child, None)
+ CometGlobalLimitExec(nativeOp, op, op.limit, op.child,
SerializedPlan(None))
case None =>
op
}
@@ -282,7 +288,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometExpandExec(nativeOp, op, op.projections, op.child, None)
+ CometExpandExec(nativeOp, op, op.projections, op.child,
SerializedPlan(None))
case None =>
op
}
@@ -305,7 +311,7 @@ class CometSparkSessionExtensions
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
child,
- None)
+ SerializedPlan(None))
case None =>
op
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0298bc6..e75f9a4 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -150,7 +150,7 @@ abstract class CometNativeExec extends CometExec {
* The serialized native query plan, optional. This is only defined when the
current node is the
* "boundary" node between native and Spark.
*/
- def serializedPlanOpt: Option[Array[Byte]]
+ def serializedPlanOpt: SerializedPlan
/** The Comet native operator */
def nativeOp: Operator
@@ -200,7 +200,7 @@ abstract class CometNativeExec extends CometExec {
}
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
- serializedPlanOpt match {
+ serializedPlanOpt.plan match {
case None =>
// This is in the middle of a native execution, it should not be
executed directly.
throw new CometRuntimeException(
@@ -282,11 +282,11 @@ abstract class CometNativeExec extends CometExec {
*/
def convertBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
- case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty =>
+ case serializedPlan: SerializedPlan if serializedPlan.isEmpty =>
val out = new ByteArrayOutputStream()
nativeOp.writeTo(out)
out.close()
- Some(out.toByteArray)
+ SerializedPlan(Some(out.toByteArray))
case other: AnyRef => other
case null => null
}
@@ -300,8 +300,8 @@ abstract class CometNativeExec extends CometExec {
*/
def cleanBlock(): CometNativeExec = {
def transform(arg: Any): AnyRef = arg match {
- case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined =>
- None
+ case serializedPlan: SerializedPlan if serializedPlan.isDefined =>
+ SerializedPlan(None)
case other: AnyRef => other
case null => null
}
@@ -323,13 +323,23 @@ abstract class CometNativeExec extends CometExec {
abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode
+/**
+ * Represents the serialized plan of Comet native operators. Only the first
operator in a block of
+ * contiguous Comet native operators has defined plan bytes which contains the
serialization of
+ * the plan tree of the block.
+ */
+case class SerializedPlan(plan: Option[Array[Byte]]) {
+ def isDefined: Boolean = plan.isDefined
+ def isEmpty: Boolean = plan.isEmpty
+}
+
case class CometProjectExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
projectList: Seq[NamedExpression],
override val output: Seq[Attribute],
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override def producedAttributes: AttributeSet = outputSet
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -356,7 +366,7 @@ case class CometFilterExec(
override val originalPlan: SparkPlan,
condition: Expression,
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -390,7 +400,7 @@ case class CometSortExec(
override val originalPlan: SparkPlan,
sortOrder: Seq[SortOrder],
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -422,7 +432,7 @@ case class CometLocalLimitExec(
override val originalPlan: SparkPlan,
limit: Int,
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -449,7 +459,7 @@ case class CometGlobalLimitExec(
override val originalPlan: SparkPlan,
limit: Int,
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -474,7 +484,7 @@ case class CometExpandExec(
override val originalPlan: SparkPlan,
projections: Seq[Seq[Expression]],
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override def producedAttributes: AttributeSet = outputSet
@@ -538,7 +548,7 @@ case class CometHashAggregateExec(
input: Seq[Attribute],
mode: Option[AggregateMode],
child: SparkPlan,
- override val serializedPlanOpt: Option[Array[Byte]])
+ override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)
@@ -576,7 +586,7 @@ case class CometHashAggregateExec(
case class CometScanWrapper(override val nativeOp: Operator, override val
originalPlan: SparkPlan)
extends CometNativeExec
with LeafExecNode {
- override val serializedPlanOpt: Option[Array[Byte]] = None
+ override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
override def stringArgs: Iterator[Any] = Iterator(originalPlan.output,
originalPlan)
}
@@ -592,7 +602,7 @@ case class CometSinkPlaceHolder(
override val originalPlan: SparkPlan,
child: SparkPlan)
extends CometUnaryExec {
- override val serializedPlanOpt: Option[Array[Byte]] = None
+ override val serializedPlanOpt: SerializedPlan = SerializedPlan(None)
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan
= {
this.copy(child = newChild)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 29b6e12..05be34c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column,
CometTestBase, DataFrame
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics,
CatalogTable}
import org.apache.spark.sql.catalyst.expressions.Hex
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec,
CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec,
CometTakeOrderedAndProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle,
CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec,
UnionExec}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec,
CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.functions.{date_add, expr}
+import org.apache.spark.sql.functions.{date_add, expr, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
@@ -57,6 +58,19 @@ class CometExecSuite extends CometTestBase {
}
}
+ test("Fix corrupted AggregateMode when transforming plan parameters") {
+ withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
+ val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
+ val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst {
+ case s: CometHashAggregateExec => s
+ }.get
+
+ assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
+ val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
+ assert(newAgg.mode.isDefined &&
newAgg.mode.get.isInstanceOf[AggregateMode])
+ }
+ }
+
test("CometBroadcastExchangeExec") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {