This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new fe49e4074 feat: CometNativeWriteExec support with native scan as a child (#2839)
fe49e4074 is described below
commit fe49e4074857ff398093516dbbeb551cbc5d3d07
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Dec 4 11:32:44 2025 -0500
feat: CometNativeWriteExec support with native scan as a child (#2839)
---
.../org/apache/comet/rules/CometExecRule.scala | 11 +
.../comet/parquet/CometParquetWriterSuite.scala | 221 +++++++++++++--------
2 files changed, 144 insertions(+), 88 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index a92082ae1..9152b9f78 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -536,6 +536,17 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
firstNativeOp = true
}
+ // CometNativeWriteExec is special: it has two separate plans:
+ // 1. A protobuf plan (nativeOp) describing the write operation
+ // 2. A Spark plan (child) that produces the data to write
+ // The serializedPlanOpt is a def that always returns Some(...) by serializing
+ // nativeOp on-demand, so it doesn't need convertBlock(). However, its child
+ // (e.g., CometNativeScanExec) may need its own serialization. Reset the flag
+ // so children can start their own native execution blocks.
+ if (op.isInstanceOf[CometNativeWriteExec]) {
+ firstNativeOp = true
+ }
+
newPlan
case op =>
firstNativeOp = true
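
For context, the executed plan shape this rule change targets looks roughly like
the following (operator nesting is inferred from the test assertions below; the
exact treeString formatting is an assumption):

    DataWritingCommandExec
    +- CometNativeWriteExec      // carries its own protobuf plan (nativeOp)
       +- CometNativeScanExec    // child that now starts a fresh native block
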
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index e4b8b5385..2ea697fd4 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -24,7 +24,7 @@ import java.io.File
import scala.util.Random
import org.apache.spark.sql.{CometTestBase, DataFrame}
-import org.apache.spark.sql.comet.CometNativeWriteExec
+import org.apache.spark.sql.comet.{CometNativeScanExec, CometNativeWriteExec}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.internal.SQLConf
@@ -34,122 +34,167 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOpt
class CometParquetWriterSuite extends CometTestBase {
- test("basic parquet write") {
- // no support for fully native scan as input yet
- assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
+ private def createTestData(inputDir: File): String = {
+ val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
+ val schema = FuzzDataGenerator.generateSchema(
+ SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
+ val df = FuzzDataGenerator.generateDataFrame(
+ new Random(42),
+ spark,
+ schema,
+ 1000,
+ DataGenOptions(generateNegativeZero = false))
+ withSQLConf(
+ CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
+ df.write.parquet(inputPath)
+ }
+ inputPath
+ }
+
+ private def writeWithCometNativeWriteExec(
+ inputPath: String,
+ outputPath: String): Option[QueryExecution] = {
+ val df = spark.read.parquet(inputPath)
+
+ // Use a listener to capture the execution plan during write
+ var capturedPlan: Option[QueryExecution] = None
+
+ val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ // Capture plans from write operations
+ if (funcName == "save" || funcName.contains("command")) {
+ capturedPlan = Some(qe)
+ }
+ }
+
+ override def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception): Unit = {}
+ }
+
+ spark.listenerManager.register(listener)
+
+ try {
+ // Perform native write
+ df.write.parquet(outputPath)
+
+ // Wait for listener to be called with timeout
+ val maxWaitTimeMs = 15000
+ val checkIntervalMs = 100
+ val maxIterations = maxWaitTimeMs / checkIntervalMs
+ var iterations = 0
+
+ while (capturedPlan.isEmpty && iterations < maxIterations) {
+ Thread.sleep(checkIntervalMs)
+ iterations += 1
+ }
+
+ // Verify that CometNativeWriteExec was used
+ assert(
+ capturedPlan.isDefined,
+ s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+
+ capturedPlan.foreach { qe =>
+ val executedPlan = qe.executedPlan
+ val hasNativeWrite = executedPlan.exists {
+ case _: CometNativeWriteExec => true
+ case d: DataWritingCommandExec =>
+ d.child.exists {
+ case _: CometNativeWriteExec => true
+ case _ => false
+ }
+ case _ => false
+ }
+
+ assert(
+ hasNativeWrite,
+ s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+ }
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
+ capturedPlan
+ }
+
+ private def verifyWrittenFile(outputPath: String): Unit = {
+ // Verify the data was written correctly
+ val resultDf = spark.read.parquet(outputPath)
+ assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+
+ // Verify multiple part files were created
+ val outputDir = new File(outputPath)
+ val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
+ // With 1000 rows and default parallelism, we should get multiple partitions
+ assert(partFiles.length > 1, "Expected multiple part files to be created")
+
+ // read with and without Comet and compare
+ var sparkDf: DataFrame = null
+ var cometDf: DataFrame = null
+ withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
+ sparkDf = spark.read.parquet(outputPath)
+ }
+ withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
+ cometDf = spark.read.parquet(outputPath)
+ }
+ checkAnswer(sparkDf, cometDf)
+ }
+ test("basic parquet write") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath
// Create test data and write it to a temp parquet file first
withTempPath { inputDir =>
- val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
- val schema = FuzzDataGenerator.generateSchema(
- SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false))
- val df = FuzzDataGenerator.generateDataFrame(
- new Random(42),
- spark,
- schema,
- 1000,
- DataGenOptions(generateNegativeZero = false))
- withSQLConf(
- CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "false",
- SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Denver") {
- df.write.parquet(inputPath)
- }
+ val inputPath = createTestData(inputDir)
withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
- val df = spark.read.parquet(inputPath)
-
- // Use a listener to capture the execution plan during write
- var capturedPlan: Option[QueryExecution] = None
-
- val listener = new org.apache.spark.sql.util.QueryExecutionListener {
- override def onSuccess(
- funcName: String,
- qe: QueryExecution,
- durationNs: Long): Unit = {
- // Capture plans from write operations
- if (funcName == "save" || funcName.contains("command")) {
- capturedPlan = Some(qe)
- }
- }
- override def onFailure(
- funcName: String,
- qe: QueryExecution,
- exception: Exception): Unit = {}
- }
+ writeWithCometNativeWriteExec(inputPath, outputPath)
- spark.listenerManager.register(listener)
-
- try {
- // Perform native write
- df.write.parquet(outputPath)
+ verifyWrittenFile(outputPath)
+ }
+ }
+ }
+ }
- // Wait for listener to be called with timeout
- val maxWaitTimeMs = 15000
- val checkIntervalMs = 100
- val maxIterations = maxWaitTimeMs / checkIntervalMs
- var iterations = 0
+ test("basic parquet write with native scan child") {
+ withTempPath { dir =>
+ val outputPath = new File(dir, "output.parquet").getAbsolutePath
- while (capturedPlan.isEmpty && iterations < maxIterations) {
- Thread.sleep(checkIntervalMs)
- iterations += 1
- }
+ // Create test data and write it to a temp parquet file first
+ withTempPath { inputDir =>
+ val inputPath = createTestData(inputDir)
- // Verify that CometNativeWriteExec was used
- assert(
- capturedPlan.isDefined,
- s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
+ withSQLConf(
+ CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+ CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true") {
+ withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
+ val capturedPlan = writeWithCometNativeWriteExec(inputPath, outputPath)
capturedPlan.foreach { qe =>
val executedPlan = qe.executedPlan
- val hasNativeWrite = executedPlan.exists {
- case _: CometNativeWriteExec => true
- case d: DataWritingCommandExec =>
- d.child.exists {
- case _: CometNativeWriteExec => true
- case _ => false
- }
+ val hasNativeScan = executedPlan.exists {
+ case _: CometNativeScanExec => true
case _ => false
}
assert(
- hasNativeWrite,
- s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+ hasNativeScan,
+ s"Expected CometNativeScanExec in the plan, but got:\n${executedPlan.treeString}")
}
- } finally {
- spark.listenerManager.unregister(listener)
- }
- // Verify the data was written correctly
- val resultDf = spark.read.parquet(outputPath)
- assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
-
- // Verify multiple part files were created
- val outputDir = new File(outputPath)
- val partFiles = outputDir.listFiles().filter(_.getName.startsWith("part-"))
- // With 1000 rows and default parallelism, we should get multiple partitions
- assert(partFiles.length > 1, "Expected multiple part files to be created")
-
- // read with and without Comet and compare
- var sparkDf: DataFrame = null
- var cometDf: DataFrame = null
- withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false") {
- sparkDf = spark.read.parquet(outputPath)
- }
- withSQLConf(CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {
- cometDf = spark.read.parquet(outputPath)
+ verifyWrittenFile(outputPath)
}
- checkAnswer(sparkDf, cometDf)
}
}
}
}
-
}
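
As a rough end-to-end sketch of how the new path can be exercised: the config
constants below are the ones used in the suite above, and withSQLConf is the
test-suite helper, so in a plain application the same keys would be set via
spark.conf.set. The paths and the SparkSession `spark` are assumptions.

    import org.apache.comet.CometConf
    import org.apache.spark.sql.execution.command.DataWritingCommandExec

    // Hypothetical input/output locations; any existing Parquet input works.
    val inputPath = "/tmp/input.parquet"
    val outputPath = "/tmp/output.parquet"

    withSQLConf(
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
      CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") {
      // With these settings the read should plan as CometNativeScanExec and the
      // write as CometNativeWriteExec, with the native scan as its child.
      spark.read.parquet(inputPath).write.parquet(outputPath)
    }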
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]