This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 0ae651554 Avoid duplicated writer nodes when AQE enabled (#2982)
0ae651554 is described below
commit 0ae65155432b95f00c4ccbd58609eb271c591ba2
Author: Oleks V <[email protected]>
AuthorDate: Tue Dec 23 18:28:14 2025 -0800
Avoid duplicated writer nodes when AQE enabled (#2982)
* feat: Avoid duplicated write nodes for AQE execution
---
.../org/apache/comet/rules/CometExecRule.scala | 9 ++++
.../comet/parquet/CometParquetWriterSuite.scala | 55 +++++++++++++++++-----
2 files changed, 52 insertions(+), 12 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index ed48e36f0..bb4ce879d 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec,
AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec,
ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends
Rule[SparkPlan] {
case op if shouldApplySparkToColumnar(conf, op) =>
convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
+ // AQE reoptimization looks for `DataWritingCommandExec` or
`WriteFilesExec`;
+ // if there is none it would reinsert write nodes, and since Comet remaps
those nodes
+ // to their Comet counterparts, the write nodes are added twice to the plan.
+ // Check whether AQE inserted another write command on top of an existing
write command
+ case _ @DataWritingCommandExec(_, w: WriteFilesExec)
+ if w.child.isInstanceOf[CometNativeWriteExec] =>
+ w.child
+
case op: DataWritingCommandExec =>
convertToComet(op, CometDataWritingCommand).getOrElse(op)
diff --git
a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index 2ea697fd4..3ae7f949a 100644
---
a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++
b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {
private def writeWithCometNativeWriteExec(
inputPath: String,
- outputPath: String): Option[QueryExecution] = {
+ outputPath: String,
+ num_partitions: Option[Int] = None): Option[QueryExecution] = {
val df = spark.read.parquet(inputPath)
// Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@ class CometParquetWriterSuite extends CometTestBase {
spark.listenerManager.register(listener)
try {
- // Perform native write
- df.write.parquet(outputPath)
+ // Perform native write with optional partitioning
+ num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
// Wait for listener to be called with timeout
val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@ class CometParquetWriterSuite extends CometTestBase {
s"Listener was not called within ${maxWaitTimeMs}ms - no execution
plan captured")
capturedPlan.foreach { qe =>
- val executedPlan = qe.executedPlan
- val hasNativeWrite = executedPlan.exists {
- case _: CometNativeWriteExec => true
+ val executedPlan = stripAQEPlan(qe.executedPlan)
+
+ // Count CometNativeWriteExec instances in the plan
+ var nativeWriteCount = 0
+ executedPlan.foreach {
+ case _: CometNativeWriteExec =>
+ nativeWriteCount += 1
case d: DataWritingCommandExec =>
- d.child.exists {
- case _: CometNativeWriteExec => true
- case _ => false
+ d.child.foreach {
+ case _: CometNativeWriteExec =>
+ nativeWriteCount += 1
+ case _ =>
}
- case _ => false
+ case _ =>
}
assert(
- hasNativeWrite,
- s"Expected CometNativeWriteExec in the plan, but
got:\n${executedPlan.treeString}")
+ nativeWriteCount == 1,
+ s"Expected exactly one CometNativeWriteExec in the plan, but found
$nativeWriteCount:\n${executedPlan.treeString}")
}
} finally {
spark.listenerManager.unregister(listener)
@@ -197,4 +203,29 @@ class CometParquetWriterSuite extends CometTestBase {
}
}
}
+
+ test("basic parquet write with repartition") {
+ withTempPath { dir =>
+ // Create test data and write it to a temp parquet file first
+ withTempPath { inputDir =>
+ val inputPath = createTestData(inputDir)
+ Seq(true, false).foreach(adaptive => {
+ // Create a new output path for each AQE value
+ val outputPath = new File(dir,
s"output_aqe_$adaptive.parquet").getAbsolutePath
+
+ withSQLConf(
+ CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+ "spark.sql.adaptive.enabled" -> adaptive.toString,
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+ CometConf.getOperatorAllowIncompatConfigKey(
+ classOf[DataWritingCommandExec]) -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+ writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
+ verifyWrittenFile(outputPath)
+ }
+ })
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]