This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 28e13dde7 feat: CometExecRDD supports per-partition plan data, reduce Iceberg native scan serialization, add DPP for Iceberg scans (#3349)
28e13dde7 is described below
commit 28e13dde722c44e492ab6d81c969fd6489dfa67d
Author: Matt Butrovich <[email protected]>
AuthorDate: Sun Feb 8 08:36:40 2026 -0500
feat: CometExecRDD supports per-partition plan data, reduce Iceberg native scan serialization, add DPP for Iceberg scans (#3349)
---
.../core/src/execution/operators/iceberg_scan.rs | 40 +-
native/core/src/execution/planner.rs | 65 ++-
native/proto/src/lib.rs | 1 +
native/proto/src/proto/operator.proto | 47 +-
.../apache/comet/iceberg/IcebergReflection.scala | 2 +-
.../org/apache/comet/rules/CometScanRule.scala | 54 ++-
.../serde/operator/CometIcebergNativeScan.scala | 523 ++++++++++-----------
.../org/apache/spark/sql/comet/CometExecRDD.scala | 169 ++++++-
.../sql/comet/CometIcebergNativeScanExec.scala | 279 ++++++++---
.../spark/sql/comet/ZippedPartitionsRDD.scala | 67 ---
.../org/apache/spark/sql/comet/operators.scala | 234 +++++++--
.../apache/comet/shims/ShimSubqueryBroadcast.scala | 33 ++
.../apache/comet/shims/ShimSubqueryBroadcast.scala | 33 ++
.../apache/comet/shims/ShimSubqueryBroadcast.scala | 33 ++
.../org/apache/comet/CometIcebergNativeSuite.scala | 181 ++++++-
15 files changed, 1201 insertions(+), 560 deletions(-)
diff --git a/native/core/src/execution/operators/iceberg_scan.rs b/native/core/src/execution/operators/iceberg_scan.rs
index 2f639e9f7..bc20592e9 100644
--- a/native/core/src/execution/operators/iceberg_scan.rs
+++ b/native/core/src/execution/operators/iceberg_scan.rs
@@ -44,6 +44,7 @@ use crate::parquet::parquet_support::SparkParquetOptions;
use crate::parquet::schema_adapter::SparkSchemaAdapterFactory;
use datafusion::datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper};
use datafusion_comet_spark_expr::EvalMode;
+use iceberg::scan::FileScanTask;
/// Iceberg table scan operator that uses iceberg-rust to read Iceberg tables.
///
@@ -58,8 +59,8 @@ pub struct IcebergScanExec {
plan_properties: PlanProperties,
/// Catalog-specific configuration for FileIO
catalog_properties: HashMap<String, String>,
- /// Pre-planned file scan tasks, grouped by partition
- file_task_groups: Vec<Vec<iceberg::scan::FileScanTask>>,
+ /// Pre-planned file scan tasks
+ tasks: Vec<FileScanTask>,
/// Metrics
metrics: ExecutionPlanMetricsSet,
}
@@ -69,11 +70,10 @@ impl IcebergScanExec {
metadata_location: String,
schema: SchemaRef,
catalog_properties: HashMap<String, String>,
- file_task_groups: Vec<Vec<iceberg::scan::FileScanTask>>,
+ tasks: Vec<FileScanTask>,
) -> Result<Self, ExecutionError> {
let output_schema = schema;
- let num_partitions = file_task_groups.len();
- let plan_properties = Self::compute_properties(Arc::clone(&output_schema), num_partitions);
+ let plan_properties = Self::compute_properties(Arc::clone(&output_schema), 1);
let metrics = ExecutionPlanMetricsSet::new();
@@ -82,7 +82,7 @@ impl IcebergScanExec {
output_schema,
plan_properties,
catalog_properties,
- file_task_groups,
+ tasks,
metrics,
})
}
@@ -127,19 +127,10 @@ impl ExecutionPlan for IcebergScanExec {
fn execute(
&self,
- partition: usize,
+ _partition: usize,
context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
- if partition < self.file_task_groups.len() {
- let tasks = &self.file_task_groups[partition];
- self.execute_with_tasks(tasks.clone(), partition, context)
- } else {
- Err(DataFusionError::Execution(format!(
- "IcebergScanExec: Partition index {} out of range (only {}
task groups available)",
- partition,
- self.file_task_groups.len()
- )))
- }
+ self.execute_with_tasks(self.tasks.clone(), context)
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -152,15 +143,14 @@ impl IcebergScanExec {
/// deletes via iceberg-rust's ArrowReader.
fn execute_with_tasks(
&self,
- tasks: Vec<iceberg::scan::FileScanTask>,
- partition: usize,
+ tasks: Vec<FileScanTask>,
context: Arc<TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
let output_schema = Arc::clone(&self.output_schema);
let file_io = Self::load_file_io(&self.catalog_properties, &self.metadata_location)?;
let batch_size = context.session_config().batch_size();
- let metrics = IcebergScanMetrics::new(&self.metrics, partition);
+ let metrics = IcebergScanMetrics::new(&self.metrics);
let num_tasks = tasks.len();
metrics.num_splits.add(num_tasks);
@@ -221,10 +211,10 @@ struct IcebergScanMetrics {
}
impl IcebergScanMetrics {
- fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
+ fn new(metrics: &ExecutionPlanMetricsSet) -> Self {
Self {
- baseline: BaselineMetrics::new(metrics, partition),
- num_splits: MetricBuilder::new(metrics).counter("num_splits",
partition),
+ baseline: BaselineMetrics::new(metrics, 0),
+ num_splits: MetricBuilder::new(metrics).counter("num_splits", 0),
}
}
}
@@ -311,11 +301,11 @@ where
impl DisplayAs for IcebergScanExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
- let num_tasks: usize = self.file_task_groups.iter().map(|g| g.len()).sum();
write!(
f,
"IcebergScanExec: metadata_location={}, num_tasks={}",
- self.metadata_location, num_tasks
+ self.metadata_location,
+ self.tasks.len()
)
}
}
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index d7ab018ab..2c3d00a23 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1151,33 +1151,28 @@ impl PhysicalPlanner {
))
}
OpStruct::IcebergScan(scan) => {
- let required_schema: SchemaRef =
-
convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+ // Extract common data and single partition's file tasks
+ // Per-partition injection happens in Scala before sending to
native
+ let common = scan
+ .common
+ .as_ref()
+ .ok_or_else(|| GeneralError("IcebergScan missing common
data".into()))?;
- let catalog_properties: HashMap<String, String> = scan
+ let required_schema =
+
convert_spark_types_to_arrow_schema(common.required_schema.as_slice());
+ let catalog_properties: HashMap<String, String> = common
.catalog_properties
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
-
- let metadata_location = scan.metadata_location.clone();
-
- debug_assert!(
- !scan.file_partitions.is_empty(),
- "IcebergScan must have at least one file partition. This
indicates a bug in Scala serialization."
- );
-
- let tasks = parse_file_scan_tasks(
- scan,
- &scan.file_partitions[self.partition as
usize].file_scan_tasks,
- )?;
- let file_task_groups = vec![tasks];
+ let metadata_location = common.metadata_location.clone();
+ let tasks = parse_file_scan_tasks_from_common(common,
&scan.file_scan_tasks)?;
let iceberg_scan = IcebergScanExec::new(
metadata_location,
required_schema,
catalog_properties,
- file_task_groups,
+ tasks,
)?;
Ok((
@@ -2762,15 +2757,14 @@ fn partition_data_to_struct(
/// Each task contains a residual predicate that is used for row-group level
filtering
/// during Parquet scanning.
///
-/// This function uses deduplication pools from the IcebergScan to avoid
redundant parsing
-/// of schemas, partition specs, partition types, name mappings, and other
repeated data.
-fn parse_file_scan_tasks(
- proto_scan: &spark_operator::IcebergScan,
+/// This function uses deduplication pools from the IcebergScanCommon to avoid
redundant
+/// parsing of schemas, partition specs, partition types, name mappings, and
other repeated data.
+fn parse_file_scan_tasks_from_common(
+ proto_common: &spark_operator::IcebergScanCommon,
proto_tasks: &[spark_operator::IcebergFileScanTask],
) -> Result<Vec<iceberg::scan::FileScanTask>, ExecutionError> {
- // Build caches upfront: for 10K tasks with 1 schema, this parses the
schema
- // once instead of 10K times, eliminating redundant JSON deserialization
- let schema_cache: Vec<Arc<iceberg::spec::Schema>> = proto_scan
+ // Parse each unique schema once, not once per task
+ let schema_cache: Vec<Arc<iceberg::spec::Schema>> = proto_common
.schema_pool
.iter()
.map(|json| {
@@ -2783,7 +2777,7 @@ fn parse_file_scan_tasks(
})
.collect::<Result<Vec<_>, _>>()?;
- let partition_spec_cache: Vec<Option<Arc<iceberg::spec::PartitionSpec>>> =
proto_scan
+ let partition_spec_cache: Vec<Option<Arc<iceberg::spec::PartitionSpec>>> =
proto_common
.partition_spec_pool
.iter()
.map(|json| {
@@ -2793,7 +2787,7 @@ fn parse_file_scan_tasks(
})
.collect();
- let name_mapping_cache: Vec<Option<Arc<iceberg::spec::NameMapping>>> =
proto_scan
+ let name_mapping_cache: Vec<Option<Arc<iceberg::spec::NameMapping>>> =
proto_common
.name_mapping_pool
.iter()
.map(|json| {
@@ -2803,7 +2797,7 @@ fn parse_file_scan_tasks(
})
.collect();
- let delete_files_cache: Vec<Vec<iceberg::scan::FileScanTaskDeleteFile>> =
proto_scan
+ let delete_files_cache: Vec<Vec<iceberg::scan::FileScanTaskDeleteFile>> =
proto_common
.delete_files_pool
.iter()
.map(|list| {
@@ -2815,7 +2809,7 @@ fn parse_file_scan_tasks(
"EQUALITY_DELETES" =>
iceberg::spec::DataContentType::EqualityDeletes,
other => {
return Err(GeneralError(format!(
- "Invalid delete content type '{}'. This
indicates a bug in Scala serialization.",
+ "Invalid delete content type '{}'",
other
)))
}
@@ -2836,7 +2830,6 @@ fn parse_file_scan_tasks(
})
.collect::<Result<Vec<_>, _>>()?;
- // Partition data pool is in protobuf messages
let results: Result<Vec<_>, _> = proto_tasks
.iter()
.map(|proto_task| {
@@ -2870,7 +2863,7 @@ fn parse_file_scan_tasks(
};
let bound_predicate = if let Some(idx) = proto_task.residual_idx {
- proto_scan
+ proto_common
.residual_pool
.get(idx as usize)
.and_then(convert_spark_expr_to_predicate)
@@ -2890,24 +2883,22 @@ fn parse_file_scan_tasks(
};
let partition = if let Some(partition_data_idx) =
proto_task.partition_data_idx {
- // Get partition data from protobuf pool
- let partition_data_proto = proto_scan
+ let partition_data_proto = proto_common
.partition_data_pool
.get(partition_data_idx as usize)
.ok_or_else(|| {
ExecutionError::GeneralError(format!(
"Invalid partition_data_idx: {} (pool size: {})",
partition_data_idx,
- proto_scan.partition_data_pool.len()
+ proto_common.partition_data_pool.len()
))
})?;
- // Convert protobuf PartitionData to iceberg Struct
match partition_data_to_struct(partition_data_proto) {
Ok(s) => Some(s),
Err(e) => {
return Err(ExecutionError::GeneralError(format!(
- "Failed to deserialize partition data from
protobuf: {}",
+ "Failed to deserialize partition data: {}",
e
)))
}
@@ -2926,14 +2917,14 @@ fn parse_file_scan_tasks(
.and_then(|idx| name_mapping_cache.get(idx as usize))
.and_then(|opt| opt.clone());
- let project_field_ids = proto_scan
+ let project_field_ids = proto_common
.project_field_ids_pool
.get(proto_task.project_field_ids_idx as usize)
.ok_or_else(|| {
ExecutionError::GeneralError(format!(
"Invalid project_field_ids_idx: {} (pool size: {})",
proto_task.project_field_ids_idx,
- proto_scan.project_field_ids_pool.len()
+ proto_common.project_field_ids_pool.len()
))
})?
.field_ids
diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs
index 6dfe546ac..a55657b7a 100644
--- a/native/proto/src/lib.rs
+++ b/native/proto/src/lib.rs
@@ -34,6 +34,7 @@ pub mod spark_partitioning {
// Include generated modules from .proto files.
#[allow(missing_docs)]
+#[allow(clippy::large_enum_variant)]
pub mod spark_operator {
include!(concat!("generated", "/spark.spark_operator.rs"));
}
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 73c087cf3..78f118e6d 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -156,28 +156,34 @@ message PartitionData {
repeated PartitionValue values = 1;
}
-message IcebergScan {
- // Schema to read
- repeated SparkStructField required_schema = 1;
-
+// Common data shared by all partitions in split mode (sent once, captured in
closure)
+message IcebergScanCommon {
// Catalog-specific configuration for FileIO (credentials, S3/GCS config,
etc.)
- map<string, string> catalog_properties = 2;
-
- // Pre-planned file scan tasks grouped by Spark partition
- repeated IcebergFilePartition file_partitions = 3;
+ map<string, string> catalog_properties = 1;
// Table metadata file path for FileIO initialization
- string metadata_location = 4;
+ string metadata_location = 2;
+
+ // Schema to read
+ repeated SparkStructField required_schema = 3;
- // Deduplication pools - shared data referenced by index from tasks
- repeated string schema_pool = 5;
- repeated string partition_type_pool = 6;
- repeated string partition_spec_pool = 7;
- repeated string name_mapping_pool = 8;
- repeated ProjectFieldIdList project_field_ids_pool = 9;
- repeated PartitionData partition_data_pool = 10;
- repeated DeleteFileList delete_files_pool = 11;
- repeated spark.spark_expression.Expr residual_pool = 12;
+ // Deduplication pools (must contain all entries for cross-partition
deduplication)
+ repeated string schema_pool = 4;
+ repeated string partition_type_pool = 5;
+ repeated string partition_spec_pool = 6;
+ repeated string name_mapping_pool = 7;
+ repeated ProjectFieldIdList project_field_ids_pool = 8;
+ repeated PartitionData partition_data_pool = 9;
+ repeated DeleteFileList delete_files_pool = 10;
+ repeated spark.spark_expression.Expr residual_pool = 11;
+}
+
+message IcebergScan {
+ // Common data shared across partitions (pools, metadata, catalog props)
+ IcebergScanCommon common = 1;
+
+ // Single partition's file scan tasks
+ repeated IcebergFileScanTask file_scan_tasks = 2;
}
// Helper message for deduplicating field ID lists
@@ -190,11 +196,6 @@ message DeleteFileList {
repeated IcebergDeleteFile delete_files = 1;
}
-// Groups FileScanTasks for a single Spark partition
-message IcebergFilePartition {
- repeated IcebergFileScanTask file_scan_tasks = 1;
-}
-
// Iceberg FileScanTask containing data file, delete files, and residual filter
message IcebergFileScanTask {
// Data file path (e.g.,
s3://bucket/warehouse/db/table/data/00000-0-abc.parquet)
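
The deduplication pools above follow a simple intern-by-index pattern: each unique schema, partition spec, or name-mapping JSON string is stored once in IcebergScanCommon, and every IcebergFileScanTask refers to it by pool index. A minimal, self-contained Scala sketch of that pattern (names below are illustrative and not part of this commit):

    import scala.collection.mutable

    object PoolDedupSketch {
      def main(args: Array[String]): Unit = {
        // Pool of unique values plus a value -> index map, mirroring how
        // schema_pool / partition_spec_pool entries are assigned indices.
        val pool = mutable.ArrayBuffer[String]()
        val poolIndex = mutable.HashMap[String, Int]()

        def intern(value: String): Int =
          poolIndex.getOrElseUpdate(value, {
            val idx = pool.size
            pool += value
            idx
          })

        // 10,000 tasks sharing one schema serialize a single pool entry
        // plus 10,000 small integer indices instead of 10,000 JSON copies.
        val schemaJson = """{"type":"struct","fields":[]}"""
        val schemaIdxPerTask = (1 to 10000).map(_ => intern(schemaJson))

        assert(pool.size == 1)
        assert(schemaIdxPerTask.forall(_ == 0))
        println(s"pool entries: ${pool.size}, task references: ${schemaIdxPerTask.size}")
      }
    }
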
diff --git a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
index 2d772063e..c5b655405 100644
--- a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
+++ b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala
@@ -734,7 +734,7 @@ case class CometIcebergNativeScanMetadata(
table: Any,
metadataLocation: String,
nameMapping: Option[String],
- tasks: java.util.List[_],
+ @transient tasks: java.util.List[_],
scanSchema: Any,
tableSchema: Any,
globalFieldIdMapping: Map[String, Int],
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
index 29555a61e..ebb521730 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
@@ -28,12 +28,12 @@ import scala.jdk.CollectionConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
GenericInternalRow, PlanExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
DynamicPruningExpression, Expression, GenericInternalRow, PlanExpression}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{sideBySide, ArrayBasedMapData,
GenericArrayData, MetadataColumnHelper}
import
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec}
-import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{FileSourceScanExec, InSubqueryExec,
SparkPlan, SubqueryAdaptiveBroadcastExec}
import org.apache.spark.sql.execution.datasources.HadoopFsRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -51,12 +51,15 @@ import org.apache.comet.objectstore.NativeConfig
import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet}
import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled,
isEncryptionConfigSupported}
import org.apache.comet.serde.operator.CometNativeScan
-import org.apache.comet.shims.{CometTypeShim, ShimFileFormat}
+import org.apache.comet.shims.{CometTypeShim, ShimFileFormat,
ShimSubqueryBroadcast}
/**
* Spark physical optimizer rule for replacing Spark scans with Comet scans.
*/
-case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with
CometTypeShim {
+case class CometScanRule(session: SparkSession)
+ extends Rule[SparkPlan]
+ with CometTypeShim
+ with ShimSubqueryBroadcast {
import CometScanRule._
@@ -341,10 +344,6 @@ case class CometScanRule(session: SparkSession) extends
Rule[SparkPlan] with Com
case _
if scanExec.scan.getClass.getName ==
"org.apache.iceberg.spark.source.SparkBatchQueryScan" =>
- if (scanExec.runtimeFilters.exists(isDynamicPruningFilter)) {
- return withInfo(scanExec, "Dynamic Partition Pruning is not
supported")
- }
-
val fallbackReasons = new ListBuffer[String]()
// Native Iceberg scan requires both configs to be enabled
@@ -635,10 +634,47 @@ case class CometScanRule(session: SparkSession) extends
Rule[SparkPlan] with Com
!hasUnsupportedDeletes
}
+ // Check that all DPP subqueries use InSubqueryExec which we know how
to handle.
+ // Future Spark versions might introduce new subquery types we haven't
tested.
+ val dppSubqueriesSupported = {
+ val unsupportedSubqueries = scanExec.runtimeFilters.collect {
+ case DynamicPruningExpression(e) if
!e.isInstanceOf[InSubqueryExec] =>
+ e.getClass.getSimpleName
+ }
+ // Check for multi-index DPP which we don't support yet.
+ // SPARK-46946 changed SubqueryAdaptiveBroadcastExec from index: Int
to indices: Seq[Int]
+ // as a preparatory refactor for future features (Null Safe Equality
DPP, multiple
+ // equality predicates). Currently indices always has one element,
but future Spark
+ // versions might use multiple indices.
+ val multiIndexDpp = scanExec.runtimeFilters.exists {
+ case DynamicPruningExpression(e: InSubqueryExec) =>
+ e.plan match {
+ case sab: SubqueryAdaptiveBroadcastExec =>
+ getSubqueryBroadcastIndices(sab).length > 1
+ case _ => false
+ }
+ case _ => false
+ }
+ if (unsupportedSubqueries.nonEmpty) {
+ fallbackReasons +=
+ s"Unsupported DPP subquery types:
${unsupportedSubqueries.mkString(", ")}. " +
+ "CometIcebergNativeScanExec only supports InSubqueryExec for
DPP"
+ false
+ } else if (multiIndexDpp) {
+ // See SPARK-46946 for context on multi-index DPP
+ fallbackReasons +=
+ "Multi-index DPP (indices.length > 1) is not yet supported. " +
+ "See SPARK-46946 for context."
+ false
+ } else {
+ true
+ }
+ }
+
if (schemaSupported && fileIOCompatible && formatVersionSupported &&
allParquetFiles &&
allSupportedFilesystems && partitionTypesSupported &&
complexTypePredicatesSupported && transformFunctionsSupported &&
- deleteFileTypesSupported) {
+ deleteFileTypesSupported && dppSubqueriesSupported) {
CometBatchScanExec(
scanExec.clone().asInstanceOf[BatchScanExec],
runtimeFilters = scanExec.runtimeFilters,
diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
index 0ad82af8f..957f62103 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala
@@ -28,10 +28,11 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeExec}
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec,
DataSourceRDD, DataSourceRDDPartition}
import org.apache.spark.sql.types._
import org.apache.comet.ConfigEntry
-import org.apache.comet.iceberg.IcebergReflection
+import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata,
IcebergReflection}
import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass}
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.OperatorOuterClass.{Operator, SparkStructField}
@@ -309,7 +310,7 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
contentScanTaskClass: Class[_],
fileScanTaskClass: Class[_],
taskBuilder: OperatorOuterClass.IcebergFileScanTask.Builder,
- icebergScanBuilder: OperatorOuterClass.IcebergScan.Builder,
+ commonBuilder: OperatorOuterClass.IcebergScanCommon.Builder,
partitionTypeToPoolIndex: mutable.HashMap[String, Int],
partitionSpecToPoolIndex: mutable.HashMap[String, Int],
partitionDataToPoolIndex: mutable.HashMap[String, Int]): Unit = {
@@ -334,7 +335,7 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
val specIdx = partitionSpecToPoolIndex.getOrElseUpdate(
partitionSpecJson, {
val idx = partitionSpecToPoolIndex.size
- icebergScanBuilder.addPartitionSpecPool(partitionSpecJson)
+ commonBuilder.addPartitionSpecPool(partitionSpecJson)
idx
})
taskBuilder.setPartitionSpecIdx(specIdx)
@@ -415,7 +416,7 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
val typeIdx = partitionTypeToPoolIndex.getOrElseUpdate(
partitionTypeJson, {
val idx = partitionTypeToPoolIndex.size
- icebergScanBuilder.addPartitionTypePool(partitionTypeJson)
+ commonBuilder.addPartitionTypePool(partitionTypeJson)
idx
})
taskBuilder.setPartitionTypeIdx(typeIdx)
@@ -470,7 +471,7 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
val partitionDataIdx = partitionDataToPoolIndex.getOrElseUpdate(
partitionDataKey, {
val idx = partitionDataToPoolIndex.size
- icebergScanBuilder.addPartitionDataPool(partitionDataProto)
+ commonBuilder.addPartitionDataPool(partitionDataProto)
idx
})
taskBuilder.setPartitionDataIdx(partitionDataIdx)
@@ -671,17 +672,59 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
}
/**
- * Serializes a CometBatchScanExec wrapping an Iceberg SparkBatchQueryScan
to protobuf.
+ * Converts a CometBatchScanExec to a minimal placeholder IcebergScan
operator.
*
- * Uses pre-extracted metadata from CometScanRule to avoid redundant
reflection operations. All
- * reflection and validation was done during planning, so serialization
failures here would
- * indicate a programming error rather than an expected fallback condition.
+ * Returns a placeholder operator with only metadata_location for matching
during partition
+ * injection. All other fields (catalog properties, required schema, pools,
partition data) are
+ * set by serializePartitions() at execution time after DPP resolves.
*/
override def convert(
scan: CometBatchScanExec,
builder: Operator.Builder,
childOp: Operator*): Option[OperatorOuterClass.Operator] = {
+
+ val metadata = scan.nativeIcebergScanMetadata.getOrElse {
+ throw new IllegalStateException(
+ "Programming error: CometBatchScanExec.nativeIcebergScanMetadata is
None. " +
+ "Metadata should have been extracted in CometScanRule.")
+ }
+
val icebergScanBuilder = OperatorOuterClass.IcebergScan.newBuilder()
+ val commonBuilder = OperatorOuterClass.IcebergScanCommon.newBuilder()
+
+ // Only set metadata_location - used for matching in PlanDataInjector.
+ // All other fields (catalog_properties, required_schema, pools) are set by
+ // serializePartitions() at execution time, so setting them here would be
wasted work.
+ commonBuilder.setMetadataLocation(metadata.metadataLocation)
+
+ icebergScanBuilder.setCommon(commonBuilder.build())
+ // partition field intentionally empty - will be populated at execution
time
+
+ builder.clearChildren()
+ Some(builder.setIcebergScan(icebergScanBuilder).build())
+ }
+
+ /**
+ * Serializes partitions from inputRDD at execution time.
+ *
+ * Called after doPrepare() has resolved DPP subqueries. Builds pools and
per-partition data in
+ * one pass from the DPP-filtered partitions.
+ *
+ * @param scanExec
+ * The BatchScanExec whose inputRDD contains the DPP-filtered partitions
+ * @param output
+ * The output attributes for the scan
+ * @param metadata
+ * Pre-extracted Iceberg metadata from CometScanRule
+ * @return
+ * Tuple of (commonBytes, perPartitionBytes) for native execution
+ */
+ def serializePartitions(
+ scanExec: BatchScanExec,
+ output: Seq[Attribute],
+ metadata: CometIcebergNativeScanMetadata): (Array[Byte],
Array[Array[Byte]]) = {
+
+ val commonBuilder = OperatorOuterClass.IcebergScanCommon.newBuilder()
// Deduplication structures - map unique values to pool indices
val schemaToPoolIndex = mutable.HashMap[AnyRef, Int]()
@@ -689,300 +732,225 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
val partitionSpecToPoolIndex = mutable.HashMap[String, Int]()
val nameMappingToPoolIndex = mutable.HashMap[String, Int]()
val projectFieldIdsToPoolIndex = mutable.HashMap[Seq[Int], Int]()
- val partitionDataToPoolIndex = mutable.HashMap[String, Int]() // Base64
bytes -> pool index
+ val partitionDataToPoolIndex = mutable.HashMap[String, Int]()
val deleteFilesToPoolIndex =
mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int]()
val residualToPoolIndex = mutable.HashMap[Option[Expr], Int]()
- var totalTasks = 0
+ val perPartitionBuilders =
mutable.ArrayBuffer[OperatorOuterClass.IcebergScan]()
- // Get pre-extracted metadata from planning phase
- // If metadata is None, this is a programming error - metadata should have
been extracted
- // in CometScanRule before creating CometBatchScanExec
- val metadata = scan.nativeIcebergScanMetadata.getOrElse {
- throw new IllegalStateException(
- "Programming error: CometBatchScanExec.nativeIcebergScanMetadata is
None. " +
- "Metadata should have been extracted in CometScanRule.")
- }
-
- // Use pre-extracted metadata (no reflection needed)
- icebergScanBuilder.setMetadataLocation(metadata.metadataLocation)
+ var totalTasks = 0
+ commonBuilder.setMetadataLocation(metadata.metadataLocation)
metadata.catalogProperties.foreach { case (key, value) =>
- icebergScanBuilder.putCatalogProperties(key, value)
+ commonBuilder.putCatalogProperties(key, value)
}
- // Set required_schema from output
- scan.output.foreach { attr =>
+ output.foreach { attr =>
val field = SparkStructField
.newBuilder()
.setName(attr.name)
.setNullable(attr.nullable)
serializeDataType(attr.dataType).foreach(field.setDataType)
- icebergScanBuilder.addRequiredSchema(field.build())
+ commonBuilder.addRequiredSchema(field.build())
}
- // Extract FileScanTasks from the InputPartitions in the RDD
- try {
- scan.wrapped.inputRDD match {
- case rdd: org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
=>
- val partitions = rdd.partitions
- partitions.foreach { partition =>
- val partitionBuilder =
OperatorOuterClass.IcebergFilePartition.newBuilder()
+ // Load Iceberg classes once (avoid repeated class loading in loop)
+ // scalastyle:off classforname
+ val contentScanTaskClass =
Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK)
+ val fileScanTaskClass =
Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK)
+ val contentFileClass =
Class.forName(IcebergReflection.ClassNames.CONTENT_FILE)
+ val schemaParserClass =
Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER)
+ val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA)
+ // scalastyle:on classforname
- val inputPartitions = partition
-
.asInstanceOf[org.apache.spark.sql.execution.datasources.v2.DataSourceRDDPartition]
- .inputPartitions
+ // Cache method lookups (avoid repeated getMethod in loop)
+ val fileMethod = contentScanTaskClass.getMethod("file")
+ val startMethod = contentScanTaskClass.getMethod("start")
+ val lengthMethod = contentScanTaskClass.getMethod("length")
+ val residualMethod = contentScanTaskClass.getMethod("residual")
+ val taskSchemaMethod = fileScanTaskClass.getMethod("schema")
+ val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass)
+ toJsonMethod.setAccessible(true)
+
+ // Access inputRDD - safe now, DPP is resolved
+ scanExec.inputRDD match {
+ case rdd: DataSourceRDD =>
+ val partitions = rdd.partitions
+ partitions.foreach { partition =>
+ val partitionBuilder = OperatorOuterClass.IcebergScan.newBuilder()
+
+ val inputPartitions = partition
+ .asInstanceOf[DataSourceRDDPartition]
+ .inputPartitions
+
+ inputPartitions.foreach { inputPartition =>
+ val inputPartClass = inputPartition.getClass
- inputPartitions.foreach { inputPartition =>
- val inputPartClass = inputPartition.getClass
+ try {
+ val taskGroupMethod =
inputPartClass.getDeclaredMethod("taskGroup")
+ taskGroupMethod.setAccessible(true)
+ val taskGroup = taskGroupMethod.invoke(inputPartition)
- try {
- val taskGroupMethod =
inputPartClass.getDeclaredMethod("taskGroup")
- taskGroupMethod.setAccessible(true)
- val taskGroup = taskGroupMethod.invoke(inputPartition)
+ val taskGroupClass = taskGroup.getClass
+ val tasksMethod = taskGroupClass.getMethod("tasks")
+ val tasksCollection =
+
tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]]
- val taskGroupClass = taskGroup.getClass
- val tasksMethod = taskGroupClass.getMethod("tasks")
- val tasksCollection =
-
tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]]
+ tasksCollection.asScala.foreach { task =>
+ totalTasks += 1
- tasksCollection.asScala.foreach { task =>
- totalTasks += 1
+ val taskBuilder =
OperatorOuterClass.IcebergFileScanTask.newBuilder()
- try {
- val taskBuilder =
OperatorOuterClass.IcebergFileScanTask.newBuilder()
-
- // scalastyle:off classforname
- val contentScanTaskClass =
-
Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK)
- val fileScanTaskClass =
-
Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK)
- val contentFileClass =
- Class.forName(IcebergReflection.ClassNames.CONTENT_FILE)
- // scalastyle:on classforname
-
- val fileMethod = contentScanTaskClass.getMethod("file")
- val dataFile = fileMethod.invoke(task)
-
- val filePathOpt =
- IcebergReflection.extractFileLocation(contentFileClass,
dataFile)
-
- filePathOpt match {
- case Some(filePath) =>
- taskBuilder.setDataFilePath(filePath)
- case None =>
- val msg =
- "Iceberg reflection failure: Cannot extract file
path from data file"
- logError(msg)
- throw new RuntimeException(msg)
- }
+ val dataFile = fileMethod.invoke(task)
- val startMethod = contentScanTaskClass.getMethod("start")
- val start = startMethod.invoke(task).asInstanceOf[Long]
- taskBuilder.setStart(start)
-
- val lengthMethod = contentScanTaskClass.getMethod("length")
- val length = lengthMethod.invoke(task).asInstanceOf[Long]
- taskBuilder.setLength(length)
-
- try {
- // Equality deletes require the full table schema to
resolve field IDs,
- // even for columns not in the projection. Schema
evolution requires
- // using the snapshot's schema to correctly read old
data files.
- // These requirements conflict, so we choose based on
delete presence.
-
- val taskSchemaMethod =
fileScanTaskClass.getMethod("schema")
- val taskSchema = taskSchemaMethod.invoke(task)
-
- val deletes =
- IcebergReflection.getDeleteFilesFromTask(task,
fileScanTaskClass)
- val hasDeletes = !deletes.isEmpty
-
- // Schema to pass to iceberg-rust's FileScanTask.
- // This is used by RecordBatchTransformer for field type
lookups (e.g., in
- // constants_map) and default value generation. The
actual projection is
- // controlled by project_field_ids.
- //
- // Schema selection logic:
- // 1. If hasDeletes=true: Use taskSchema (file-specific
schema) because
- // delete files reference specific schema versions and
we need exact schema
- // matching for MOR.
- // 2. Else if scanSchema contains columns not in
tableSchema: Use scanSchema
- // because this is a VERSION AS OF query reading a
historical snapshot with
- // different schema (e.g., after column drop, scanSchema
has old columns
- // that tableSchema doesn't)
- // 3. Else: Use tableSchema because scanSchema is the
query OUTPUT schema
- // (e.g., for aggregates like "SELECT count(*)",
scanSchema only has
- // aggregate fields and doesn't contain partition
columns needed by
- // constants_map)
- val schema: AnyRef =
- if (hasDeletes) {
- taskSchema
- } else {
- // Check if scanSchema has columns that tableSchema
doesn't have
- // (VERSION AS OF case)
- val scanSchemaFieldIds = IcebergReflection
- .buildFieldIdMapping(metadata.scanSchema)
- .values
- .toSet
- val tableSchemaFieldIds = IcebergReflection
- .buildFieldIdMapping(metadata.tableSchema)
- .values
- .toSet
- val hasHistoricalColumns =
- scanSchemaFieldIds.exists(id =>
!tableSchemaFieldIds.contains(id))
-
- if (hasHistoricalColumns) {
- // VERSION AS OF: scanSchema has columns that
current table doesn't have
- metadata.scanSchema.asInstanceOf[AnyRef]
- } else {
- // Regular query: use tableSchema for partition
field lookups
- metadata.tableSchema.asInstanceOf[AnyRef]
- }
- }
-
- // scalastyle:off classforname
- val schemaParserClass =
-
Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER)
- val schemaClass =
Class.forName(IcebergReflection.ClassNames.SCHEMA)
- // scalastyle:on classforname
- val toJsonMethod = schemaParserClass.getMethod("toJson",
schemaClass)
- toJsonMethod.setAccessible(true)
-
- // Use object identity for deduplication: Iceberg Schema
objects are immutable
- // and reused across tasks, making identity-based
deduplication safe
- val schemaIdx = schemaToPoolIndex.getOrElseUpdate(
- schema, {
- val idx = schemaToPoolIndex.size
- val schemaJson = toJsonMethod.invoke(null,
schema).asInstanceOf[String]
- icebergScanBuilder.addSchemaPool(schemaJson)
- idx
- })
- taskBuilder.setSchemaIdx(schemaIdx)
-
- // Build field ID mapping from the schema we're using
- val nameToFieldId =
IcebergReflection.buildFieldIdMapping(schema)
-
- // Extract project_field_ids for scan.output columns.
- // For schema evolution: try task schema first, then
fall back to
- // global scan schema (pre-extracted in metadata).
- val projectFieldIds = scan.output.flatMap { attr =>
- nameToFieldId
- .get(attr.name)
- .orElse(metadata.globalFieldIdMapping.get(attr.name))
- .orElse {
- logWarning(
- s"Column '${attr.name}' not found in task or
scan schema," +
- "skipping projection")
- None
- }
- }
-
- // Deduplicate project field IDs
- val projectFieldIdsIdx =
projectFieldIdsToPoolIndex.getOrElseUpdate(
- projectFieldIds, {
- val idx = projectFieldIdsToPoolIndex.size
- val listBuilder =
OperatorOuterClass.ProjectFieldIdList.newBuilder()
- projectFieldIds.foreach(id =>
listBuilder.addFieldIds(id))
-
icebergScanBuilder.addProjectFieldIdsPool(listBuilder.build())
- idx
- })
- taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx)
- } catch {
- case e: Exception =>
- val msg =
- "Iceberg reflection failure: " +
- "Failed to extract schema from FileScanTask: " +
- s"${e.getMessage}"
- logError(msg)
- throw new RuntimeException(msg, e)
- }
+ val filePathOpt =
+ IcebergReflection.extractFileLocation(contentFileClass,
dataFile)
- // Deduplicate delete files
- val deleteFilesList =
- extractDeleteFilesList(task, contentFileClass,
fileScanTaskClass)
- if (deleteFilesList.nonEmpty) {
- val deleteFilesIdx =
deleteFilesToPoolIndex.getOrElseUpdate(
- deleteFilesList, {
- val idx = deleteFilesToPoolIndex.size
- val listBuilder =
OperatorOuterClass.DeleteFileList.newBuilder()
- deleteFilesList.foreach(df =>
listBuilder.addDeleteFiles(df))
-
icebergScanBuilder.addDeleteFilesPool(listBuilder.build())
- idx
- })
- taskBuilder.setDeleteFilesIdx(deleteFilesIdx)
- }
+ filePathOpt match {
+ case Some(filePath) =>
+ taskBuilder.setDataFilePath(filePath)
+ case None =>
+ val msg =
+ "Iceberg reflection failure: Cannot extract file path
from data file"
+ logError(msg)
+ throw new RuntimeException(msg)
+ }
- // Extract and deduplicate residual expression
- val residualExprOpt =
- try {
- val residualMethod =
contentScanTaskClass.getMethod("residual")
- val residualExpr = residualMethod.invoke(task)
-
- val catalystExpr =
convertIcebergExpression(residualExpr, scan.output)
-
- catalystExpr.flatMap { expr =>
- exprToProto(expr, scan.output, binding = false)
- }
- } catch {
- case e: Exception =>
- logWarning(
- "Failed to extract residual expression from
FileScanTask: " +
- s"${e.getMessage}")
- None
- }
-
- residualExprOpt.foreach { residualExpr =>
- val residualIdx = residualToPoolIndex.getOrElseUpdate(
- Some(residualExpr), {
- val idx = residualToPoolIndex.size
- icebergScanBuilder.addResidualPool(residualExpr)
- idx
- })
- taskBuilder.setResidualIdx(residualIdx)
+ val start = startMethod.invoke(task).asInstanceOf[Long]
+ taskBuilder.setStart(start)
+
+ val length = lengthMethod.invoke(task).asInstanceOf[Long]
+ taskBuilder.setLength(length)
+
+ val taskSchema = taskSchemaMethod.invoke(task)
+
+ val deletes =
+ IcebergReflection.getDeleteFilesFromTask(task,
fileScanTaskClass)
+ val hasDeletes = !deletes.isEmpty
+
+ val schema: AnyRef =
+ if (hasDeletes) {
+ taskSchema
+ } else {
+ val scanSchemaFieldIds = IcebergReflection
+ .buildFieldIdMapping(metadata.scanSchema)
+ .values
+ .toSet
+ val tableSchemaFieldIds = IcebergReflection
+ .buildFieldIdMapping(metadata.tableSchema)
+ .values
+ .toSet
+ val hasHistoricalColumns =
+ scanSchemaFieldIds.exists(id =>
!tableSchemaFieldIds.contains(id))
+
+ if (hasHistoricalColumns) {
+ metadata.scanSchema.asInstanceOf[AnyRef]
+ } else {
+ metadata.tableSchema.asInstanceOf[AnyRef]
}
+ }
- // Serialize partition spec and data (field definitions,
transforms, values)
- serializePartitionData(
- task,
- contentScanTaskClass,
- fileScanTaskClass,
- taskBuilder,
- icebergScanBuilder,
- partitionTypeToPoolIndex,
- partitionSpecToPoolIndex,
- partitionDataToPoolIndex)
-
- // Deduplicate name mapping
- metadata.nameMapping.foreach { nm =>
- val nmIdx = nameMappingToPoolIndex.getOrElseUpdate(
- nm, {
- val idx = nameMappingToPoolIndex.size
- icebergScanBuilder.addNameMappingPool(nm)
- idx
- })
- taskBuilder.setNameMappingIdx(nmIdx)
+ val schemaIdx = schemaToPoolIndex.getOrElseUpdate(
+ schema, {
+ val idx = schemaToPoolIndex.size
+ val schemaJson = toJsonMethod.invoke(null,
schema).asInstanceOf[String]
+ commonBuilder.addSchemaPool(schemaJson)
+ idx
+ })
+ taskBuilder.setSchemaIdx(schemaIdx)
+
+ val nameToFieldId =
IcebergReflection.buildFieldIdMapping(schema)
+
+ val projectFieldIds = output.flatMap { attr =>
+ nameToFieldId
+ .get(attr.name)
+ .orElse(metadata.globalFieldIdMapping.get(attr.name))
+ .orElse {
+ logWarning(s"Column '${attr.name}' not found in task or
scan schema, " +
+ "skipping projection")
+ None
}
+ }
- partitionBuilder.addFileScanTasks(taskBuilder.build())
+ val projectFieldIdsIdx =
projectFieldIdsToPoolIndex.getOrElseUpdate(
+ projectFieldIds, {
+ val idx = projectFieldIdsToPoolIndex.size
+ val listBuilder =
OperatorOuterClass.ProjectFieldIdList.newBuilder()
+ projectFieldIds.foreach(id => listBuilder.addFieldIds(id))
+ commonBuilder.addProjectFieldIdsPool(listBuilder.build())
+ idx
+ })
+ taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx)
+
+ val deleteFilesList =
+ extractDeleteFilesList(task, contentFileClass,
fileScanTaskClass)
+ if (deleteFilesList.nonEmpty) {
+ val deleteFilesIdx = deleteFilesToPoolIndex.getOrElseUpdate(
+ deleteFilesList, {
+ val idx = deleteFilesToPoolIndex.size
+ val listBuilder =
OperatorOuterClass.DeleteFileList.newBuilder()
+ deleteFilesList.foreach(df =>
listBuilder.addDeleteFiles(df))
+ commonBuilder.addDeleteFilesPool(listBuilder.build())
+ idx
+ })
+ taskBuilder.setDeleteFilesIdx(deleteFilesIdx)
+ }
+
+ val residualExprOpt =
+ try {
+ val residualExpr = residualMethod.invoke(task)
+ val catalystExpr = convertIcebergExpression(residualExpr,
output)
+ catalystExpr.flatMap { expr =>
+ exprToProto(expr, output, binding = false)
+ }
+ } catch {
+ case e: Exception =>
+ logWarning(
+ "Failed to extract residual expression from
FileScanTask: " +
+ s"${e.getMessage}")
+ None
}
+
+ residualExprOpt.foreach { residualExpr =>
+ val residualIdx = residualToPoolIndex.getOrElseUpdate(
+ Some(residualExpr), {
+ val idx = residualToPoolIndex.size
+ commonBuilder.addResidualPool(residualExpr)
+ idx
+ })
+ taskBuilder.setResidualIdx(residualIdx)
+ }
+
+ serializePartitionData(
+ task,
+ contentScanTaskClass,
+ fileScanTaskClass,
+ taskBuilder,
+ commonBuilder,
+ partitionTypeToPoolIndex,
+ partitionSpecToPoolIndex,
+ partitionDataToPoolIndex)
+
+ metadata.nameMapping.foreach { nm =>
+ val nmIdx = nameMappingToPoolIndex.getOrElseUpdate(
+ nm, {
+ val idx = nameMappingToPoolIndex.size
+ commonBuilder.addNameMappingPool(nm)
+ idx
+ })
+ taskBuilder.setNameMappingIdx(nmIdx)
}
+
+ partitionBuilder.addFileScanTasks(taskBuilder.build())
}
}
-
- val builtPartition = partitionBuilder.build()
- icebergScanBuilder.addFilePartitions(builtPartition)
}
- case _ =>
- }
- } catch {
- case e: Exception =>
- // CometScanRule already validated this scan should use native
execution.
- // Failure here is a programming error, not a graceful fallback
scenario.
- throw new IllegalStateException(
- s"Native Iceberg scan serialization failed unexpectedly:
${e.getMessage}",
- e)
+
+ perPartitionBuilders += partitionBuilder.build()
+ }
+ case _ =>
+ throw new IllegalStateException("Expected DataSourceRDD from
BatchScanExec")
}
// Log deduplication summary
@@ -999,7 +967,6 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
val avgDedup = if (totalTasks == 0) {
"0.0"
} else {
- // Filter out empty pools - they shouldn't count as 100% dedup
val nonEmptyPools = allPoolSizes.filter(_ > 0)
if (nonEmptyPools.isEmpty) {
"0.0"
@@ -1009,8 +976,7 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
}
}
- // Calculate partition data pool size in bytes (protobuf format)
- val partitionDataPoolBytes =
icebergScanBuilder.getPartitionDataPoolList.asScala
+ val partitionDataPoolBytes = commonBuilder.getPartitionDataPoolList.asScala
.map(_.getSerializedSize)
.sum
@@ -1021,8 +987,10 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
s"$partitionDataPoolBytes bytes (protobuf)")
}
- builder.clearChildren()
- Some(builder.setIcebergScan(icebergScanBuilder).build())
+ val commonBytes = commonBuilder.build().toByteArray
+ val perPartitionBytes = perPartitionBuilders.map(_.toByteArray).toArray
+
+ (commonBytes, perPartitionBytes)
}
override def createExec(nativeOp: Operator, op: CometBatchScanExec):
CometNativeExec = {
@@ -1035,10 +1003,11 @@ object CometIcebergNativeScan extends
CometOperatorSerde[CometBatchScanExec] wit
"Metadata should have been extracted in CometScanRule.")
}
- // Extract metadataLocation from the native operator
- val metadataLocation = nativeOp.getIcebergScan.getMetadataLocation
+ // Extract metadataLocation from the native operator's common data
+ val metadataLocation =
nativeOp.getIcebergScan.getCommon.getMetadataLocation
- // Create the CometIcebergNativeScanExec using the companion object's
apply method
+ // Pass BatchScanExec reference for deferred serialization (DPP support)
+ // Serialization happens at execution time after doPrepare() resolves DPP
subqueries
CometIcebergNativeScanExec(nativeOp, op.wrapped, op.session,
metadataLocation, metadata)
}
}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
index 2fd7f12c2..ad0c4f2af 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
@@ -19,39 +19,168 @@
package org.apache.spark.sql.comet
-import org.apache.spark.{Partition, SparkContext, TaskContext}
-import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.ScalarSubquery
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.SerializableConfiguration
+
+import org.apache.comet.CometExecIterator
+import org.apache.comet.serde.OperatorOuterClass
+
+/**
+ * Partition that carries per-partition planning data, avoiding closure
capture of all partitions.
+ */
+private[spark] class CometExecPartition(
+ override val index: Int,
+ val inputPartitions: Array[Partition],
+ val planDataByKey: Map[String, Array[Byte]])
+ extends Partition
/**
- * A RDD that executes Spark SQL query in Comet native execution to generate
ColumnarBatch.
+ * Unified RDD for Comet native execution.
+ *
+ * Solves the closure capture problem: instead of capturing all partitions'
data in the closure
+ * (which gets serialized to every task), each Partition object carries only
its own data.
+ *
+ * Handles three cases:
+ * - With inputs + per-partition data: injects planning data into operator
tree
+ * - With inputs + no per-partition data: just zips inputs (no injection
overhead)
+ * - No inputs: uses numPartitions to create partitions
+ *
+ * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in
+ * CometIcebergNativeScanExec.serializedPartitionData before this RDD is
created. It also handles
+ * ScalarSubquery expressions by registering them with CometScalarSubquery
before execution.
*/
private[spark] class CometExecRDD(
sc: SparkContext,
- partitionNum: Int,
- var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
- extends RDD[ColumnarBatch](sc, Nil) {
+ var inputRDDs: Seq[RDD[ColumnarBatch]],
+ commonByKey: Map[String, Array[Byte]],
+ @transient perPartitionByKey: Map[String, Array[Array[Byte]]],
+ serializedPlan: Array[Byte],
+ defaultNumPartitions: Int,
+ numOutputCols: Int,
+ nativeMetrics: CometMetricNode,
+ subqueries: Seq[ScalarSubquery],
+ broadcastedHadoopConfForEncryption:
Option[Broadcast[SerializableConfiguration]] = None,
+ encryptedFilePaths: Seq[String] = Seq.empty)
+ extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new
OneToOneDependency(rdd))) {
- override def compute(s: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
- f(Seq.empty, partitionNum, s.index)
+ // Determine partition count: from inputs if available, otherwise from
parameter
+ private val numPartitions: Int = if (inputRDDs.nonEmpty) {
+ inputRDDs.head.partitions.length
+ } else if (perPartitionByKey.nonEmpty) {
+ perPartitionByKey.values.head.length
+ } else {
+ defaultNumPartitions
}
+ // Validate all per-partition arrays have the same length to prevent
+ // ArrayIndexOutOfBoundsException in getPartitions (e.g., from broadcast
scans with
+ // different partition counts after DPP filtering)
+ require(
+ perPartitionByKey.values.forall(_.length == numPartitions),
+ s"All per-partition arrays must have length $numPartitions, but found: " +
+ perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}"
}.mkString(", "))
+
override protected def getPartitions: Array[Partition] = {
- Array.tabulate(partitionNum)(i =>
- new Partition {
- override def index: Int = i
- })
+ (0 until numPartitions).map { idx =>
+ val inputParts = inputRDDs.map(_.partitions(idx)).toArray
+ val planData = perPartitionByKey.map { case (key, arr) => key ->
arr(idx) }
+ new CometExecPartition(idx, inputParts, planData)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
+ val partition = split.asInstanceOf[CometExecPartition]
+
+ val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd,
part) =>
+ rdd.iterator(part, context)
+ }
+
+ // Only inject if we have per-partition planning data
+ val actualPlan = if (commonByKey.nonEmpty) {
+ val basePlan = OperatorOuterClass.Operator.parseFrom(serializedPlan)
+ val injected =
+ PlanDataInjector.injectPlanData(basePlan, commonByKey,
partition.planDataByKey)
+ PlanDataInjector.serializeOperator(injected)
+ } else {
+ serializedPlan
+ }
+
+ val it = new CometExecIterator(
+ CometExec.newIterId,
+ inputs,
+ numOutputCols,
+ actualPlan,
+ nativeMetrics,
+ numPartitions,
+ partition.index,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
+
+ // Register ScalarSubqueries so native code can look them up
+ subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub))
+
+ Option(context).foreach { ctx =>
+ ctx.addTaskCompletionListener[Unit] { _ =>
+ it.close()
+ subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id,
sub))
+ }
+ }
+
+ it
+ }
+
+ // Duplicates logic from Spark's
ZippedPartitionsBaseRDD.getPreferredLocations
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ if (inputRDDs == null || inputRDDs.isEmpty) return Nil
+
+ val idx = split.index
+ val prefs = inputRDDs.map(rdd =>
rdd.preferredLocations(rdd.partitions(idx)))
+ // Prefer nodes where all inputs are local; fall back to any input's
preferred location
+ val intersection = prefs.reduce((a, b) => a.intersect(b))
+ if (intersection.nonEmpty) intersection else prefs.flatten.distinct
+ }
+
+ override def clearDependencies(): Unit = {
+ super.clearDependencies()
+ inputRDDs = null
}
}
object CometExecRDD {
- def apply(sc: SparkContext, partitionNum: Int)(
- f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
- : RDD[ColumnarBatch] =
- withScope(sc) {
- new CometExecRDD(sc, partitionNum, f)
- }
- private[spark] def withScope[U](sc: SparkContext)(body: => U): U =
- RDDOperationScope.withScope[U](sc)(body)
+ /**
+ * Creates an RDD for native execution with optional per-partition planning
data.
+ */
+ // scalastyle:off
+ def apply(
+ sc: SparkContext,
+ inputRDDs: Seq[RDD[ColumnarBatch]],
+ commonByKey: Map[String, Array[Byte]],
+ perPartitionByKey: Map[String, Array[Array[Byte]]],
+ serializedPlan: Array[Byte],
+ numPartitions: Int,
+ numOutputCols: Int,
+ nativeMetrics: CometMetricNode,
+ subqueries: Seq[ScalarSubquery],
+ broadcastedHadoopConfForEncryption:
Option[Broadcast[SerializableConfiguration]] = None,
+ encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = {
+ // scalastyle:on
+
+ new CometExecRDD(
+ sc,
+ inputRDDs,
+ commonByKey,
+ perPartitionByKey,
+ serializedPlan,
+ numPartitions,
+ numOutputCols,
+ nativeMetrics,
+ subqueries,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
+ }
}
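
The closure-capture fix above boils down to slicing "one byte array per partition, keyed by scan" into per-partition maps so each task only ships its own slice. A minimal sketch of that slicing, with no Spark dependency (the names below are illustrative, not the actual Comet API):

    object PerPartitionSliceSketch {
      final case class PartitionSlice(index: Int, planDataByKey: Map[String, Array[Byte]])

      // Validate that every scan contributed the same number of partitions,
      // then give each partition only its own plan data instead of the full arrays.
      def slice(perPartitionByKey: Map[String, Array[Array[Byte]]]): Array[PartitionSlice] = {
        val numPartitions = perPartitionByKey.values.headOption.map(_.length).getOrElse(0)
        require(
          perPartitionByKey.values.forall(_.length == numPartitions),
          s"all per-partition arrays must have length $numPartitions")
        Array.tabulate(numPartitions) { idx =>
          PartitionSlice(idx, perPartitionByKey.map { case (key, arr) => key -> arr(idx) })
        }
      }

      def main(args: Array[String]): Unit = {
        val perPartitionByKey = Map(
          "s3://bucket/table/metadata.json" ->
            Array(Array[Byte](1, 2), Array[Byte](3), Array[Byte](4, 5, 6)))
        slice(perPartitionByKey).foreach { s =>
          println(s"partition ${s.index} ships ${s.planDataByKey.values.map(_.length).sum} bytes")
        }
      }
    }
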
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
index 223ae4fbb..36085b632 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala
@@ -21,18 +21,23 @@ package org.apache.spark.sql.comet
import scala.jdk.CollectionConverters._
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast,
DynamicPruningExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning,
Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning,
UnknownPartitioning}
+import org.apache.spark.sql.execution.{InSubqueryExec,
SubqueryAdaptiveBroadcastExec}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.AccumulatorV2
import com.google.common.base.Objects
import org.apache.comet.iceberg.CometIcebergNativeScanMetadata
import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.operator.CometIcebergNativeScan
+import org.apache.comet.shims.ShimSubqueryBroadcast
/**
* Native Iceberg scan operator that delegates file reading to iceberg-rust.
@@ -41,6 +46,10 @@ import org.apache.comet.serde.OperatorOuterClass.Operator
* execution. Iceberg's catalog and planning run in Spark to produce
FileScanTasks, which are
* serialized to protobuf for the native side to execute using iceberg-rust's
FileIO and
* ArrowReader. This provides better performance than reading through Spark's
abstraction layers.
+ *
+ * Supports Dynamic Partition Pruning (DPP) by deferring partition
serialization to execution
+ * time. The doPrepare() method waits for DPP subqueries to resolve, then lazy
+ * serializedPartitionData serializes the DPP-filtered partitions from
inputRDD.
*/
case class CometIcebergNativeScanExec(
override val nativeOp: Operator,
@@ -48,16 +57,136 @@ case class CometIcebergNativeScanExec(
@transient override val originalPlan: BatchScanExec,
override val serializedPlanOpt: SerializedPlan,
metadataLocation: String,
- numPartitions: Int,
@transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata)
- extends CometLeafExec {
+ extends CometLeafExec
+ with ShimSubqueryBroadcast {
override val supportsColumnar: Boolean = true
override val nodeName: String = "CometIcebergNativeScan"
- override lazy val outputPartitioning: Partitioning =
- UnknownPartitioning(numPartitions)
+ /**
+ * Prepare DPP subquery plans. Called by Spark's prepare() before
doExecuteColumnar().
+ *
+ * This follows Spark's convention of preparing subqueries in doPrepare()
rather than
+ * doExecuteColumnar(). While the actual waiting for DPP results happens
later in
+ * serializedPartitionData, calling prepare() here ensures subquery plans
are set up before
+ * execution begins.
+ */
+ override protected def doPrepare(): Unit = {
+ originalPlan.runtimeFilters.foreach {
+ case DynamicPruningExpression(e: InSubqueryExec) =>
+ e.plan.prepare()
+ case _ =>
+ }
+ super.doPrepare()
+ }
+
+ /**
+ * Lazy partition serialization - deferred until execution time for DPP
support.
+ *
+ * Entry points: This lazy val may be triggered from either
doExecuteColumnar() (via
+ * commonData/perPartitionData) or capturedMetricValues (for Iceberg
metrics). Lazy val
+ * semantics ensure single evaluation regardless of entry point.
+ *
+ * DPP (Dynamic Partition Pruning) Flow:
+ *
+ * {{{
+ * Planning time:
+ * CometIcebergNativeScanExec created
+ * - serializedPartitionData not evaluated (lazy)
+ * - No partition serialization yet
+ *
+ * Execution time:
+ * 1. Spark calls prepare() on the plan tree
+ * - doPrepare() calls e.plan.prepare() for each DPP filter
+ * - Subquery plans are set up (but not yet executed)
+ *
+ * 2. Spark calls doExecuteColumnar() (or metrics are accessed)
+ * - Accesses perPartitionData (or capturedMetricValues)
+ * - Forces serializedPartitionData evaluation (here)
+ * - Waits for DPP values (updateResult or reflection)
+ * - Calls serializePartitions with DPP-filtered inputRDD
+ * - Only matching partitions are serialized
+ * }}}
+ */
+ @transient private lazy val serializedPartitionData: (Array[Byte],
Array[Array[Byte]]) = {
+ // Ensure DPP subqueries are resolved before accessing inputRDD.
+ originalPlan.runtimeFilters.foreach {
+ case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty =>
+ e.plan match {
+ case sab: SubqueryAdaptiveBroadcastExec =>
+ // SubqueryAdaptiveBroadcastExec.executeCollect() throws, so we
call
+ // child.executeCollect() directly. We use the index from SAB to
find the
+ // right buildKey, then locate that key's column in child.output.
+ val rows = sab.child.executeCollect()
+ val indices = getSubqueryBroadcastIndices(sab)
+
+ // SPARK-46946 changed index: Int to indices: Seq[Int] as a
preparatory refactor
+ // for future features (Null Safe Equality DPP, multiple equality
predicates).
+ // Currently indices always has one element. CometScanRule checks
for multi-index
+ // DPP and falls back, so this assertion should never fail.
+ assert(
+ indices.length == 1,
+ s"Multi-index DPP not supported: indices=$indices. See
SPARK-46946.")
+ val buildKeyIndex = indices.head
+ val buildKey = sab.buildKeys(buildKeyIndex)
+
+ // Find column index in child.output by matching buildKey's exprId
+ val colIndex = buildKey match {
+ case attr: Attribute =>
+ sab.child.output.indexWhere(_.exprId == attr.exprId)
+ // DPP may cast partition column to match join key type
+ case Cast(attr: Attribute, _, _, _) =>
+ sab.child.output.indexWhere(_.exprId == attr.exprId)
+ case _ => buildKeyIndex
+ }
+ if (colIndex < 0) {
+ throw new IllegalStateException(
+ s"DPP build key '$buildKey' not found in
${sab.child.output.map(_.name)}")
+ }
+
+ setInSubqueryResult(e, rows.map(_.get(colIndex, e.child.dataType)))
+ case _ =>
+ e.updateResult()
+ }
+ case _ =>
+ }
+
+ CometIcebergNativeScan.serializePartitions(originalPlan, output,
nativeIcebergScanMetadata)
+ }
+
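The single-evaluation guarantee relied on above comes from Scala's lazy val semantics. A
minimal standalone sketch of the same pattern (hypothetical names, not Comet code; @transient
means the value is recomputed rather than shipped if the object is serialized):

    object LazyPlanningSketch {
      // compute() stands in for the expensive partition serialization above.
      class DeferredPlanning(compute: () => (Array[Byte], Array[Array[Byte]])) extends Serializable {
        @transient private lazy val data: (Array[Byte], Array[Array[Byte]]) = compute()

        def commonData: Array[Byte] = data._1              // entry point 1
        def perPartitionData: Array[Array[Byte]] = data._2 // entry point 2
      }

      def main(args: Array[String]): Unit = {
        var calls = 0
        val d = new DeferredPlanning(() => {
          calls += 1
          (Array.emptyByteArray, Array(Array.emptyByteArray))
        })
        d.perPartitionData // first access forces evaluation
        d.commonData       // second access reuses the cached result
        assert(calls == 1) // evaluated once, whichever accessor ran first
      }
    }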
+ /**
+ * Sets InSubqueryExec's private result field via reflection.
+ *
+ * Reflection is required because:
+ * - SubqueryAdaptiveBroadcastExec.executeCollect() throws
UnsupportedOperationException
+ * - InSubqueryExec has no public setter for result, only updateResult()
which calls
+ * executeCollect()
+ * - We can't replace e.plan since it's a val
+ */
+ private def setInSubqueryResult(e: InSubqueryExec, result: Array[_]): Unit =
{
+ val fields = e.getClass.getDeclaredFields
+ // Field name is mangled by Scala compiler, e.g.
"org$apache$...$InSubqueryExec$$result"
+ val resultField = fields
+ .find(f => f.getName.endsWith("$result") &&
!f.getName.contains("Broadcast"))
+ .getOrElse {
+ throw new IllegalStateException(
+ s"Cannot find 'result' field in ${e.getClass.getName}. " +
+ "Spark version may be incompatible with Comet's DPP
implementation.")
+ }
+ resultField.setAccessible(true)
+ resultField.set(e, result)
+ }
+
+ def commonData: Array[Byte] = serializedPartitionData._1
+ def perPartitionData: Array[Array[Byte]] = serializedPartitionData._2
+
+ // numPartitions for execution - derived from actual DPP-filtered partitions
+ // Only accessed during execution, not planning
+ def numPartitions: Int = perPartitionData.length
+
+ override lazy val outputPartitioning: Partitioning =
UnknownPartitioning(numPartitions)
override lazy val outputOrdering: Seq[SortOrder] = Nil
@@ -95,17 +224,34 @@ case class CometIcebergNativeScanExec(
}
}
- private val capturedMetricValues: Seq[MetricValue] = {
- originalPlan.metrics
- .filterNot { case (name, _) =>
- // Filter out metrics that are now runtime metrics incremented on the
native side
- name == "numOutputRows" || name == "numDeletes" || name == "numSplits"
- }
- .map { case (name, metric) =>
- val mappedType = mapMetricType(name, metric.metricType)
- MetricValue(name, metric.value, mappedType)
- }
- .toSeq
+ /**
+ * Captures Iceberg planning metrics for display in Spark UI.
+ *
+ * This lazy val intentionally triggers serializedPartitionData evaluation
because Iceberg
+ * populates metrics during planning (when inputRDD is accessed). Both this
and
+ * doExecuteColumnar() may trigger serializedPartitionData, but lazy val
semantics ensure it's
+ * evaluated only once.
+ */
+ @transient private lazy val capturedMetricValues: Seq[MetricValue] = {
+ // Guard against null originalPlan (from doCanonicalize)
+ if (originalPlan == null) {
+ Seq.empty
+ } else {
+ // Trigger serializedPartitionData to ensure Iceberg planning has run and
+ // metrics are populated
+ val _ = serializedPartitionData
+
+ originalPlan.metrics
+ .filterNot { case (name, _) =>
+ // Filter out metrics that are now runtime metrics incremented on
the native side
+ name == "numOutputRows" || name == "numDeletes" || name ==
"numSplits"
+ }
+ .map { case (name, metric) =>
+ val mappedType = mapMetricType(name, metric.metricType)
+ MetricValue(name, metric.value, mappedType)
+ }
+ .toSeq
+ }
}
/**
@@ -146,62 +292,88 @@ case class CometIcebergNativeScanExec(
baseMetrics ++ icebergMetrics + ("num_splits" -> numSplitsMetric)
}
+ /** Executes using CometExecRDD - planning data is computed lazily on first
access. */
+ override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val nativeMetrics = CometMetricNode.fromCometPlan(this)
+ val serializedPlan = CometExec.serializeNativePlan(nativeOp)
+ CometExecRDD(
+ sparkContext,
+ inputRDDs = Seq.empty,
+ commonByKey = Map(metadataLocation -> commonData),
+ perPartitionByKey = Map(metadataLocation -> perPartitionData),
+ serializedPlan = serializedPlan,
+ numPartitions = perPartitionData.length,
+ numOutputCols = output.length,
+ nativeMetrics = nativeMetrics,
+ subqueries = Seq.empty)
+ }
+
+ /**
+ * Override convertBlock to preserve @transient fields. The parent
implementation uses
+ * makeCopy() which loses transient fields.
+ */
+ override def convertBlock(): CometIcebergNativeScanExec = {
+ // Serialize the native plan if not already done
+ val newSerializedPlan = if (serializedPlanOpt.isEmpty) {
+ val bytes = CometExec.serializeNativePlan(nativeOp)
+ SerializedPlan(Some(bytes))
+ } else {
+ serializedPlanOpt
+ }
+
+ // Create new instance preserving transient fields
+ CometIcebergNativeScanExec(
+ nativeOp,
+ output,
+ originalPlan,
+ newSerializedPlan,
+ metadataLocation,
+ nativeIcebergScanMetadata)
+ }
+
override protected def doCanonicalize(): CometIcebergNativeScanExec = {
CometIcebergNativeScanExec(
nativeOp,
output.map(QueryPlan.normalizeExpressions(_, output)),
- originalPlan.doCanonicalize(),
+ null, // Don't need originalPlan for canonicalization
SerializedPlan(None),
metadataLocation,
- numPartitions,
- nativeIcebergScanMetadata)
+ null
+ ) // Don't need metadata for canonicalization
}
- override def stringArgs: Iterator[Any] =
- Iterator(output, s"$metadataLocation, ${originalPlan.scan.description()}",
numPartitions)
+ override def stringArgs: Iterator[Any] = {
+ // Use metadata task count to avoid triggering serializedPartitionData
during planning
+ val hasMeta = nativeIcebergScanMetadata != null &&
nativeIcebergScanMetadata.tasks != null
+ val taskCount = if (hasMeta) nativeIcebergScanMetadata.tasks.size() else 0
+ val scanDesc = if (originalPlan != null) originalPlan.scan.description()
else "canonicalized"
+ // Include runtime filters (DPP) in string representation
+ val runtimeFiltersStr = if (originalPlan != null &&
originalPlan.runtimeFilters.nonEmpty) {
+ s", runtimeFilters=${originalPlan.runtimeFilters.mkString("[", ", ",
"]")}"
+ } else {
+ ""
+ }
+ Iterator(output, s"$metadataLocation, $scanDesc$runtimeFiltersStr",
taskCount)
+ }
override def equals(obj: Any): Boolean = {
obj match {
case other: CometIcebergNativeScanExec =>
this.metadataLocation == other.metadataLocation &&
this.output == other.output &&
- this.serializedPlanOpt == other.serializedPlanOpt &&
- this.numPartitions == other.numPartitions
+ this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}
override def hashCode(): Int =
- Objects.hashCode(
- metadataLocation,
- output.asJava,
- serializedPlanOpt,
- numPartitions: java.lang.Integer)
+ Objects.hashCode(metadataLocation, output.asJava, serializedPlanOpt)
}
object CometIcebergNativeScanExec {
- /**
- * Creates a CometIcebergNativeScanExec from a Spark BatchScanExec.
- *
- * Determines the number of partitions from Iceberg's output partitioning:
- * - KeyGroupedPartitioning: Use Iceberg's partition count
- * - Other cases: Use the number of InputPartitions from Iceberg's planning
- *
- * @param nativeOp
- * The serialized native operator
- * @param scanExec
- * The original Spark BatchScanExec
- * @param session
- * The SparkSession
- * @param metadataLocation
- * Path to table metadata file
- * @param nativeIcebergScanMetadata
- * Pre-extracted Iceberg metadata from planning phase
- * @return
- * A new CometIcebergNativeScanExec
- */
+ /** Creates a CometIcebergNativeScanExec with deferred partition
serialization. */
def apply(
nativeOp: Operator,
scanExec: BatchScanExec,
@@ -209,21 +381,12 @@ object CometIcebergNativeScanExec {
metadataLocation: String,
nativeIcebergScanMetadata: CometIcebergNativeScanMetadata):
CometIcebergNativeScanExec = {
- // Determine number of partitions from Iceberg's output partitioning
- val numParts = scanExec.outputPartitioning match {
- case p: KeyGroupedPartitioning =>
- p.numPartitions
- case _ =>
- scanExec.inputRDD.getNumPartitions
- }
-
val exec = CometIcebergNativeScanExec(
nativeOp,
scanExec.output,
scanExec,
SerializedPlan(None),
metadataLocation,
- numParts,
nativeIcebergScanMetadata)
scanExec.logicalLink.foreach(exec.setLogicalLink)
diff --git
a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
deleted file mode 100644
index fdf8bf393..000000000
--- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.comet
-
-import org.apache.spark.{Partition, SparkContext, TaskContext}
-import org.apache.spark.rdd.{RDD, RDDOperationScope, ZippedPartitionsBaseRDD,
ZippedPartitionsPartition}
-import org.apache.spark.sql.vectorized.ColumnarBatch
-
-/**
- * Similar to Spark `ZippedPartitionsRDD[1-4]` classes, this class is used to
zip partitions of
- * the multiple RDDs into a single RDD. Spark `ZippedPartitionsRDD[1-4]`
classes only support at
- * most 4 RDDs. This class is used to support more than 4 RDDs. This
ZipPartitionsRDD is used to
- * zip the input sources of the Comet physical plan. So it only zips
partitions of ColumnarBatch.
- */
-private[spark] class ZippedPartitionsRDD(
- sc: SparkContext,
- var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
- var zipRdds: Seq[RDD[ColumnarBatch]],
- preservesPartitioning: Boolean = false)
- extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds,
preservesPartitioning) {
-
- // We need to get the number of partitions in `compute` but
`getNumPartitions` is not available
- // on the executors. So we need to capture it here.
- private val numParts: Int = this.getNumPartitions
-
- override def compute(s: Partition, context: TaskContext):
Iterator[ColumnarBatch] = {
- val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
- val iterators =
- zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2),
context))
- f(iterators, numParts, s.index)
- }
-
- override def clearDependencies(): Unit = {
- super.clearDependencies()
- zipRdds = null
- f = null
- }
-}
-
-object ZippedPartitionsRDD {
- def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
- f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
- : RDD[ColumnarBatch] =
- withScope(sc) {
- new ZippedPartitionsRDD(sc, f, rdds)
- }
-
- private[spark] def withScope[U](sc: SparkContext)(body: => U): U =
- RDDOperationScope.withScope[U](sc)(body)
-}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 6f33467ef..eba74c9e2 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -25,7 +25,6 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
-import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -59,6 +58,139 @@ import
org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregat
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto,
supportedSortType}
import org.apache.comet.serde.operator.CometSink
+/**
+ * Trait for injecting per-partition planning data into operator nodes.
+ *
+ * Implementations handle specific operator types (e.g., Iceberg scans, Delta
scans).
+ */
+private[comet] trait PlanDataInjector {
+
+ /** Check if this injector can handle the given operator. */
+ def canInject(op: Operator): Boolean
+
+ /** Extract the key used to look up planning data for this operator. */
+ def getKey(op: Operator): Option[String]
+
+ /** Inject common + partition data into the operator node. */
+ def inject(op: Operator, commonBytes: Array[Byte], partitionBytes:
Array[Byte]): Operator
+}
+
+/**
+ * Registry and utilities for injecting per-partition planning data into
operator trees.
+ */
+private[comet] object PlanDataInjector {
+
+ // Registry of injectors for different operator types
+ private val injectors: Seq[PlanDataInjector] = Seq(
+ IcebergPlanDataInjector
+ // Future: DeltaPlanDataInjector, HudiPlanDataInjector, etc.
+ )
+
+ /**
+ * Injects planning data into an Operator tree by finding nodes that need
injection and applying
+ * the appropriate injector.
+ *
+ * Supports joins over multiple tables by matching each operator with its
corresponding data
+ * based on a key (e.g., metadata_location for Iceberg).
+ */
+ def injectPlanData(
+ op: Operator,
+ commonByKey: Map[String, Array[Byte]],
+ partitionByKey: Map[String, Array[Byte]]): Operator = {
+ val builder = op.toBuilder
+
+ // Try each injector to see if it can handle this operator
+ for (injector <- injectors if injector.canInject(op)) {
+ injector.getKey(op) match {
+ case Some(key) =>
+ (commonByKey.get(key), partitionByKey.get(key)) match {
+ case (Some(commonBytes), Some(partitionBytes)) =>
+ val injectedOp = injector.inject(op, commonBytes, partitionBytes)
+ // Copy the injected operator's fields to our builder
+ builder.clear()
+ builder.mergeFrom(injectedOp)
+ case _ =>
+ throw new CometRuntimeException(s"Missing planning data for key:
$key")
+ }
+ case None => // No key, skip injection
+ }
+ }
+
+ // Recursively process children
+ builder.clearChildren()
+ op.getChildrenList.asScala.foreach { child =>
+ builder.addChildren(injectPlanData(child, commonByKey, partitionByKey))
+ }
+
+ builder.build()
+ }
+
+ def serializeOperator(op: Operator): Array[Byte] = {
+ val size = op.getSerializedSize
+ val bytes = new Array[Byte](size)
+ val codedOutput = CodedOutputStream.newInstance(bytes)
+ op.writeTo(codedOutput)
+ codedOutput.checkNoSpaceLeft()
+ bytes
+ }
+}
+
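As a usage illustration only (the real call site is CometExecRDD, which performs this per
partition index), a caller pairs a scan's common bytes with one partition's task bytes and
re-serializes the injected plan. The helper below is a sketch; its name and parameters are
hypothetical, and it assumes the same Operator import used elsewhere in this file:

    // Sketch: inject one partition's planning data into a plan and serialize it for the native side.
    def injectForPartition(
        rootOperator: Operator,
        metadataLocation: String,
        commonBytes: Array[Byte],
        partitionTaskBytes: Array[Byte]): Array[Byte] = {
      val injected = PlanDataInjector.injectPlanData(
        rootOperator,
        commonByKey = Map(metadataLocation -> commonBytes),
        partitionByKey = Map(metadataLocation -> partitionTaskBytes))
      PlanDataInjector.serializeOperator(injected)
    }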
+/**
+ * Injector for Iceberg scan operators.
+ */
+private[comet] object IcebergPlanDataInjector extends PlanDataInjector {
+ import java.nio.ByteBuffer
+ import java.util.{LinkedHashMap, Map => JMap}
+
+ private final val maxCacheEntries = 16
+
+ // Cache parsed IcebergScanCommon so that Iceberg tables with large numbers of partitions
+ // (thousands or more) do not reparse the same commonBytes for every partition.
+ // IcebergPlanDataInjector is a singleton, so we use an LRU cache to eventually evict old
+ // IcebergScanCommon objects. 16 entries is a reasonable starting point since these objects
+ // are not large. A LinkedHashMap with accessOrder=true provides LRU ordering, and the
+ // Collections.synchronizedMap wrapper makes access to it thread-safe.
+ private val commonCache = java.util.Collections.synchronizedMap(
+ new LinkedHashMap[ByteBuffer, OperatorOuterClass.IcebergScanCommon](4,
0.75f, true) {
+ override def removeEldestEntry(
+ eldest: JMap.Entry[ByteBuffer,
OperatorOuterClass.IcebergScanCommon]): Boolean = {
+ size() > maxCacheEntries
+ }
+ })
+
+ override def canInject(op: Operator): Boolean =
+ op.hasIcebergScan &&
+ op.getIcebergScan.getFileScanTasksCount == 0 &&
+ op.getIcebergScan.hasCommon
+
+ override def getKey(op: Operator): Option[String] =
+ Some(op.getIcebergScan.getCommon.getMetadataLocation)
+
+ override def inject(
+ op: Operator,
+ commonBytes: Array[Byte],
+ partitionBytes: Array[Byte]): Operator = {
+ val scan = op.getIcebergScan
+
+ // Cache the parsed common data to avoid deserializing on every partition
+ val cacheKey = ByteBuffer.wrap(commonBytes)
+ val common = commonCache.synchronized {
+ Option(commonCache.get(cacheKey)).getOrElse {
+ val parsed =
OperatorOuterClass.IcebergScanCommon.parseFrom(commonBytes)
+ commonCache.put(cacheKey, parsed)
+ parsed
+ }
+ }
+
+ val tasksOnly = OperatorOuterClass.IcebergScan.parseFrom(partitionBytes)
+
+ val scanBuilder = scan.toBuilder
+ scanBuilder.setCommon(common)
+ scanBuilder.addAllFileScanTasks(tasksOnly.getFileScanTasksList)
+
+ op.toBuilder.setIcebergScan(scanBuilder).build()
+ }
+}
+
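To make the access-ordered eviction of the cache above concrete, here is a tiny standalone
sketch (illustrative capacity of 2, not the Comet constant):

    import java.util.{Collections, LinkedHashMap, Map => JMap}

    object LruSketch {
      private final val maxEntries = 2 // illustrative only
      // accessOrder = true: get() moves an entry to the most-recently-used end;
      // removeEldestEntry evicts from the least-recently-used end on insert.
      private val cache = Collections.synchronizedMap(
        new LinkedHashMap[String, String](4, 0.75f, true) {
          override def removeEldestEntry(eldest: JMap.Entry[String, String]): Boolean =
            size() > maxEntries
        })

      def main(args: Array[String]): Unit = {
        cache.put("a", "1")
        cache.put("b", "2")
        cache.get("a")      // "a" becomes most recently used
        cache.put("c", "3") // evicts "b", the least recently used entry
        assert(cache.containsKey("a") && !cache.containsKey("b"))
      }
    }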
/**
* A Comet physical operator
*/
@@ -105,6 +237,15 @@ abstract class CometExec extends CometPlan {
}
}
}
+
+ /** Collects all ScalarSubquery expressions from a plan tree. */
+ protected def collectSubqueries(sparkPlan: SparkPlan): Seq[ScalarSubquery] =
{
+ val childSubqueries = sparkPlan.children.flatMap(collectSubqueries)
+ val planSubqueries = sparkPlan.expressions.flatMap {
+ _.collect { case sub: ScalarSubquery => sub }
+ }
+ childSubqueries ++ planSubqueries
+ }
}
object CometExec {
@@ -290,32 +431,8 @@ abstract class CometNativeExec extends CometExec {
case None => (None, Seq.empty)
}
- def createCometExecIter(
- inputs: Seq[Iterator[ColumnarBatch]],
- numParts: Int,
- partitionIndex: Int): CometExecIterator = {
- val it = new CometExecIterator(
- CometExec.newIterId,
- inputs,
- output.length,
- serializedPlanCopy,
- nativeMetrics,
- numParts,
- partitionIndex,
- broadcastedHadoopConfForEncryption,
- encryptedFilePaths)
-
- setSubqueries(it.id, this)
-
- Option(TaskContext.get()).foreach { context =>
- context.addTaskCompletionListener[Unit] { _ =>
- it.close()
- cleanSubqueries(it.id, this)
- }
- }
-
- it
- }
+ // Find planning data within this stage (stops at shuffle boundaries).
+ val (commonByKey, perPartitionByKey) = findAllPlanData(this)
// Collect the input ColumnarBatches from the child operators and
create a CometExecIterator
// to execute the native plan.
@@ -395,12 +512,20 @@ abstract class CometNativeExec extends CometExec {
throw new CometRuntimeException(s"No input for CometNativeExec:\n
$this")
}
- if (inputs.nonEmpty) {
- ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
- } else {
- val partitionNum = firstNonBroadcastPlanNumPartitions
- CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
- }
+ // Unified RDD creation - CometExecRDD handles all cases
+ val subqueries = collectSubqueries(this)
+ CometExecRDD(
+ sparkContext,
+ inputs.toSeq,
+ commonByKey,
+ perPartitionByKey,
+ serializedPlanCopy,
+ firstNonBroadcastPlanNumPartitions,
+ output.length,
+ nativeMetrics,
+ subqueries,
+ broadcastedHadoopConfForEncryption,
+ encryptedFilePaths)
}
}
@@ -440,6 +565,49 @@ abstract class CometNativeExec extends CometExec {
}
}
+ /**
+ * Find all plan nodes with per-partition planning data in the plan tree.
Returns two maps keyed
+ * by a unique identifier: one for common data (shared across partitions)
and one for
+ * per-partition data.
+ *
+ * Currently supports Iceberg scans (keyed by metadata_location). Additional
scan types can be
+ * added by extending this method.
+ *
+ * Stops at stage boundaries (shuffle exchanges, etc.) because partition
indices are only valid
+ * within the same stage.
+ *
+ * @return
+ * (commonByKey, perPartitionByKey) - common data is shared, per-partition
varies
+ */
+ private def findAllPlanData(
+ plan: SparkPlan): (Map[String, Array[Byte]], Map[String,
Array[Array[Byte]]]) = {
+ plan match {
+ // Found an Iceberg scan with planning data
+ case iceberg: CometIcebergNativeScanExec
+ if iceberg.commonData.nonEmpty && iceberg.perPartitionData.nonEmpty
=>
+ (
+ Map(iceberg.metadataLocation -> iceberg.commonData),
+ Map(iceberg.metadataLocation -> iceberg.perPartitionData))
+
+ // Broadcast stages are boundaries - don't collect per-partition data
from inside them.
+ // After DPP filtering, broadcast scans may have different partition
counts than the
+ // probe side, causing ArrayIndexOutOfBoundsException in
CometExecRDD.getPartitions.
+ case _: BroadcastQueryStageExec | _: CometBroadcastExchangeExec =>
+ (Map.empty, Map.empty)
+
+ // Stage boundaries - stop searching (partition indices won't align
after these)
+ case _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _:
CometShuffleExchangeExec |
+ _: CometUnionExec | _: CometTakeOrderedAndProjectExec | _:
CometCoalesceExec |
+ _: ReusedExchangeExec | _: CometSparkToColumnarExec =>
+ (Map.empty, Map.empty)
+
+ // Continue searching through other operators, combining results from
all children
+ case _ =>
+ val results = plan.children.map(findAllPlanData)
+ (results.flatMap(_._1).toMap, results.flatMap(_._2).toMap)
+ }
+ }
+
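For illustration, the shape returned for a stage that joins two Iceberg tables (the paths and
byte arrays below are placeholders, not values from this commit):

    val stub = Array.emptyByteArray
    // One shared "common" blob per table...
    val commonByKey: Map[String, Array[Byte]] = Map(
      "s3://bucket/tbl_a/metadata/v3.metadata.json" -> stub,
      "s3://bucket/tbl_b/metadata/v7.metadata.json" -> stub)
    // ...and one task blob per output partition of each table.
    val perPartitionByKey: Map[String, Array[Array[Byte]]] = Map(
      "s3://bucket/tbl_a/metadata/v3.metadata.json" -> Array(stub, stub),
      "s3://bucket/tbl_b/metadata/v7.metadata.json" -> Array(stub, stub))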
/**
* Converts this native Comet operator and its children into a native block
which can be
* executed as a whole (i.e., in a single JNI call) from the native side.
diff --git
a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala
b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala
new file mode 100644
index 000000000..1ff093504
--- /dev/null
+++
b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+
+trait ShimSubqueryBroadcast {
+
+ /**
+ * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x
has `index: Int`,
+ * Spark 4.x has `indices: Seq[Int]`.
+ */
+ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec):
Seq[Int] = {
+ Seq(sab.index)
+ }
+}
diff --git
a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala
b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala
new file mode 100644
index 000000000..1ff093504
--- /dev/null
+++
b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+
+trait ShimSubqueryBroadcast {
+
+ /**
+ * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x
has `index: Int`,
+ * Spark 4.x has `indices: Seq[Int]`.
+ */
+ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec):
Seq[Int] = {
+ Seq(sab.index)
+ }
+}
diff --git
a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala
new file mode 100644
index 000000000..417dfd46b
--- /dev/null
+++
b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec
+
+trait ShimSubqueryBroadcast {
+
+ /**
+ * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x
has `index: Int`,
+ * Spark 4.x has `indices: Seq[Int]`.
+ */
+ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec):
Seq[Int] = {
+ sab.indices
+ }
+}
diff --git
a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala
b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala
index f3c8a8b2a..033b634e0 100644
--- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala
@@ -25,8 +25,10 @@ import java.nio.file.Files
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.comet.CometIcebergNativeScanExec
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StringType, TimestampType}
import org.apache.comet.iceberg.RESTCatalogHelper
+import org.apache.comet.testing.{FuzzDataGenerator, SchemaGenOptions}
/**
* Test suite for native Iceberg scan using FileScanTasks and iceberg-rust.
@@ -2291,10 +2293,89 @@ class CometIcebergNativeSuite extends CometTestBase
with RESTCatalogHelper {
}
file.delete()
}
+
deleteRecursively(dir)
}
}
+ test("runtime filtering - multiple DPP filters on two partition columns") {
+ assume(icebergAvailable, "Iceberg not available")
+ withTempIcebergDir { warehouseDir =>
+ val dimDir = new File(warehouseDir, "dim_parquet")
+ withSQLConf(
+ "spark.sql.catalog.runtime_cat" ->
"org.apache.iceberg.spark.SparkCatalog",
+ "spark.sql.catalog.runtime_cat.type" -> "hadoop",
+ "spark.sql.catalog.runtime_cat.warehouse" ->
warehouseDir.getAbsolutePath,
+ "spark.sql.autoBroadcastJoinThreshold" -> "1KB",
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") {
+
+ // Create table partitioned by TWO columns: (data, bucket(8, id))
+ // This mimics Iceberg's testMultipleRuntimeFilters
+ spark.sql("""
+ CREATE TABLE runtime_cat.db.multi_dpp_fact (
+ id BIGINT,
+ data STRING,
+ date DATE,
+ ts TIMESTAMP
+ ) USING iceberg
+ PARTITIONED BY (data, bucket(8, id))
+ """)
+
+ // Insert data - 99 rows with varying data and id values
+ val df = spark
+ .range(1, 100)
+ .selectExpr(
+ "id",
+ "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS STRING)
as data",
+ "DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) as date",
+ "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS
TIMESTAMP) as ts")
+ df.coalesce(1)
+ .write
+ .format("iceberg")
+ .option("fanout-enabled", "true")
+ .mode("append")
+ .saveAsTable("runtime_cat.db.multi_dpp_fact")
+
+ // Create dimension table with specific id=1, data='1970-01-02'
+ spark
+ .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02"),
"1970-01-02")))
+ .toDF("id", "date", "data")
+ .write
+ .parquet(dimDir.getAbsolutePath)
+
spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("dim")
+
+ // Join on BOTH partition columns - this creates TWO DPP filters
+ val query =
+ """SELECT /*+ BROADCAST(d) */ f.*
+ |FROM runtime_cat.db.multi_dpp_fact f
+ |JOIN dim d ON f.id = d.id AND f.data = d.data
+ |WHERE d.date = DATE '1970-01-02'""".stripMargin
+
+ // Verify plan has 2 dynamic pruning expressions
+ val df2 = spark.sql(query)
+ val planStr = df2.queryExecution.executedPlan.toString
+ // Count "dynamicpruningexpression(" to avoid matching
"dynamicpruning#N" references
+ val dppCount =
"dynamicpruningexpression\\(".r.findAllIn(planStr).length
+ assert(dppCount == 2, s"Expected 2 DPP expressions but found $dppCount
in:\n$planStr")
+
+ // Verify native Iceberg scan is used and DPP actually pruned
partitions
+ val (_, cometPlan) = checkSparkAnswer(query)
+ val icebergScans = collectIcebergNativeScans(cometPlan)
+ assert(
+ icebergScans.nonEmpty,
+ s"Expected CometIcebergNativeScanExec but found none.
Plan:\n$cometPlan")
+ // With 4 data values x 8 buckets = up to 32 partitions total
+ // DPP on (data='1970-01-02', bucket(id=1)) should prune to 1
+ val numPartitions = icebergScans.head.numPartitions
+ assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but
got $numPartitions")
+
+ spark.sql("DROP TABLE runtime_cat.db.multi_dpp_fact")
+ }
+ }
+ }
+
test("runtime filtering - join with dynamic partition pruning") {
assume(icebergAvailable, "Iceberg not available")
withTempIcebergDir { warehouseDir =>
@@ -2303,11 +2384,14 @@ class CometIcebergNativeSuite extends CometTestBase
with RESTCatalogHelper {
"spark.sql.catalog.runtime_cat" ->
"org.apache.iceberg.spark.SparkCatalog",
"spark.sql.catalog.runtime_cat.type" -> "hadoop",
"spark.sql.catalog.runtime_cat.warehouse" ->
warehouseDir.getAbsolutePath,
+ // Prevent fact table from being broadcast (force dimension to be
broadcast)
+ "spark.sql.autoBroadcastJoinThreshold" -> "1KB",
CometConf.COMET_ENABLED.key -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") {
- // Create partitioned Iceberg table (fact table)
+ // Create partitioned Iceberg table (fact table) with 3 partitions
+ // Add enough data to prevent broadcast
spark.sql("""
CREATE TABLE runtime_cat.db.fact_table (
id BIGINT,
@@ -2323,7 +2407,11 @@ class CometIcebergNativeSuite extends CometTestBase with
RESTCatalogHelper {
(1, 'a', DATE '1970-01-01'),
(2, 'b', DATE '1970-01-02'),
(3, 'c', DATE '1970-01-02'),
- (4, 'd', DATE '1970-01-03')
+ (4, 'd', DATE '1970-01-03'),
+ (5, 'e', DATE '1970-01-01'),
+ (6, 'f', DATE '1970-01-02'),
+ (7, 'g', DATE '1970-01-03'),
+ (8, 'h', DATE '1970-01-01')
""")
// Create dimension table (Parquet) in temp directory
@@ -2335,8 +2423,9 @@ class CometIcebergNativeSuite extends CometTestBase with
RESTCatalogHelper {
spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("dim")
// This join should trigger dynamic partition pruning
+ // Use BROADCAST hint to force dimension table to be broadcast
val query =
- """SELECT f.* FROM runtime_cat.db.fact_table f
+ """SELECT /*+ BROADCAST(d) */ f.* FROM runtime_cat.db.fact_table f
|JOIN dim d ON f.date = d.date AND d.id = 1
|ORDER BY f.id""".stripMargin
@@ -2348,16 +2437,88 @@ class CometIcebergNativeSuite extends CometTestBase
with RESTCatalogHelper {
planStr.contains("dynamicpruning"),
s"Expected dynamic pruning in plan but got:\n$planStr")
- // Check results match Spark
- // Note: AQE re-plans after subquery executes, converting
dynamicpruningexpression(...)
- // to dynamicpruningexpression(true), which allows native Iceberg scan
to proceed.
- // This is correct behavior - no actual subquery to wait for after AQE
re-planning.
- // However, the rest of the still contains non-native operators
because CometExecRule
- // doesn't run again.
- checkSparkAnswer(df)
+ // Should now use native Iceberg scan with DPP
+ checkIcebergNativeScan(query)
+
+ // Verify DPP actually pruned partitions (should only scan 1 of 3
partitions)
+ val (_, cometPlan) = checkSparkAnswer(query)
+ val icebergScans = collectIcebergNativeScans(cometPlan)
+ assert(
+ icebergScans.nonEmpty,
+ s"Expected CometIcebergNativeScanExec but found none.
Plan:\n$cometPlan")
+ val numPartitions = icebergScans.head.numPartitions
+ assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but
got $numPartitions")
spark.sql("DROP TABLE runtime_cat.db.fact_table")
}
}
}
+
+ // Regression test for a user-reported issue
+ test("double partitioning with range filter on top-level partition") {
+ assume(icebergAvailable, "Iceberg not available")
+
+ // Generate Iceberg table without Comet enabled
+ withTempIcebergDir { warehouseDir =>
+ withSQLConf(
+ "spark.sql.catalog.test_cat" ->
"org.apache.iceberg.spark.SparkCatalog",
+ "spark.sql.catalog.test_cat.type" -> "hadoop",
+ "spark.sql.catalog.test_cat.warehouse" -> warehouseDir.getAbsolutePath,
+ "spark.sql.files.maxRecordsPerFile" -> "50") {
+
+ // timestamp + geohash with multi-column partitioning
+ spark.sql("""
+ CREATE TABLE test_cat.db.geolocation_trips (
+ outputTimestamp TIMESTAMP,
+ geohash7 STRING,
+ tripId STRING
+ ) USING iceberg
+ PARTITIONED BY (hours(outputTimestamp), truncate(3, geohash7))
+ TBLPROPERTIES (
+ 'format-version' = '2',
+ 'write.distribution-mode' = 'range',
+ 'write.target-file-size-bytes' = '1073741824'
+ )
+ """)
+ val schema = FuzzDataGenerator.generateSchema(
+ SchemaGenOptions(primitiveTypes = Seq(TimestampType, StringType,
StringType)))
+
+ val random = new scala.util.Random(42)
+ // Set baseDate to match our filter range (around 2024-01-01)
+ val options = testing.DataGenOptions(
+ allowNull = false,
+ baseDate = 1704067200000L
+ ) // 2024-01-01 00:00:00
+
+ val df = FuzzDataGenerator
+ .generateDataFrame(random, spark, schema, 1000, options)
+ .toDF("outputTimestamp", "geohash7", "tripId")
+
+ df.writeTo("test_cat.db.geolocation_trips").append()
+ }
+
+ // Query using Comet native Iceberg scan
+ withSQLConf(
+ "spark.sql.catalog.test_cat" ->
"org.apache.iceberg.spark.SparkCatalog",
+ "spark.sql.catalog.test_cat.type" -> "hadoop",
+ "spark.sql.catalog.test_cat.warehouse" -> warehouseDir.getAbsolutePath,
+ CometConf.COMET_ENABLED.key -> "true",
+ CometConf.COMET_EXEC_ENABLED.key -> "true",
+ CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") {
+
+ // Filter for a range that does not align with hour boundaries
+ // Partitioning is hours(outputTimestamp), so a filter range that falls mid-hour forces a
+ // residual filter
+ val startMs = 1704067200000L + 30 * 60 * 1000L // 2024-01-01 00:30:00 (30 min into the hour)
+ val endMs = 1704078000000L - 15 * 60 * 1000L // 2024-01-01 02:45:00 (15 min before the hour boundary)
+
+ checkIcebergNativeScan(s"""
+ SELECT COUNT(DISTINCT(tripId)) FROM test_cat.db.geolocation_trips
+ WHERE timestamp_millis($startMs) <= outputTimestamp
+ AND outputTimestamp < timestamp_millis($endMs)
+ """)
+
+ spark.sql("DROP TABLE test_cat.db.geolocation_trips")
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]