This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a06c3e24 refactor(arrow,datafusion): Reuse PartitionValueCalculator 
in RecordBatchPartitionSplitter (#1781)
0a06c3e24 is described below

commit 0a06c3e241a297998ebdcdf7c20f22d03c4e23f2
Author: Shawn Chang <[email protected]>
AuthorDate: Tue Oct 28 02:28:38 2025 -0700

    refactor(arrow,datafusion): Reuse PartitionValueCalculator in 
RecordBatchPartitionSplitter (#1781)
    
    ## Which issue does this PR close?
    
    - Closes #1786
    - Covered some of the changes from the previous draft: #1769
    
    ## What changes are included in this PR?
    - Move PartitionValueCalculator to core/arrow so it can be reused by
    RecordBatchPartitionSplitter
    - Allow skipping partition value calculation in partition splitter for
    projected batches
    - Return <PartitionKey, RecordBatch> rather than <Struct, RecordBatch>
    pairs in RecordBatchPartitionSplitter::split
    
    
    ## Are these changes tested?
    Added unit tests
---
 crates/iceberg/src/arrow/mod.rs                    |   7 +-
 .../src/arrow/partition_value_calculator.rs        | 254 ++++++++++++++
 .../src/arrow/record_batch_partition_splitter.rs   | 377 +++++++++++++++------
 .../datafusion/src/physical_plan/project.rs        | 190 +++--------
 4 files changed, 580 insertions(+), 248 deletions(-)

diff --git a/crates/iceberg/src/arrow/mod.rs b/crates/iceberg/src/arrow/mod.rs
index 28116a4b5..c091c4517 100644
--- a/crates/iceberg/src/arrow/mod.rs
+++ b/crates/iceberg/src/arrow/mod.rs
@@ -35,4 +35,9 @@ mod value;
 
 pub use reader::*;
 pub use value::*;
-pub(crate) mod record_batch_partition_splitter;
+/// Partition value calculator for computing partition values
+pub mod partition_value_calculator;
+pub use partition_value_calculator::*;
+/// Record batch partition splitter for partitioned tables
+pub mod record_batch_partition_splitter;
+pub use record_batch_partition_splitter::*;
diff --git a/crates/iceberg/src/arrow/partition_value_calculator.rs 
b/crates/iceberg/src/arrow/partition_value_calculator.rs
new file mode 100644
index 000000000..140950345
--- /dev/null
+++ b/crates/iceberg/src/arrow/partition_value_calculator.rs
@@ -0,0 +1,254 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Partition value calculation for Iceberg tables.
+//!
+//! This module provides utilities for calculating partition values from 
record batches
+//! based on a partition specification.
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch, StructArray};
+use arrow_schema::DataType;
+
+use super::record_batch_projector::RecordBatchProjector;
+use super::type_to_arrow_type;
+use crate::spec::{PartitionSpec, Schema, StructType, Type};
+use crate::transform::{BoxedTransformFunction, create_transform_function};
+use crate::{Error, ErrorKind, Result};
+
+/// Calculator for partition values in Iceberg tables.
+///
+/// This struct handles the projection of source columns and application of
+/// partition transforms to compute partition values for a given record batch.
+#[derive(Debug)]
+pub struct PartitionValueCalculator {
+    projector: RecordBatchProjector,
+    transform_functions: Vec<BoxedTransformFunction>,
+    partition_type: StructType,
+    partition_arrow_type: DataType,
+}
+
+impl PartitionValueCalculator {
+    /// Create a new PartitionValueCalculator.
+    ///
+    /// # Arguments
+    ///
+    /// * `partition_spec` - The partition specification
+    /// * `table_schema` - The Iceberg table schema
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `PartitionValueCalculator` instance or an error if 
initialization fails.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - The partition spec is unpartitioned
+    /// - Transform function creation fails
+    /// - Projector initialization fails
+    pub fn try_new(partition_spec: &PartitionSpec, table_schema: &Schema) -> 
Result<Self> {
+        if partition_spec.is_unpartitioned() {
+            return Err(Error::new(
+                ErrorKind::DataInvalid,
+                "Cannot create partition calculator for unpartitioned table",
+            ));
+        }
+
+        // Create transform functions for each partition field
+        let transform_functions: Vec<BoxedTransformFunction> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| create_transform_function(&pf.transform))
+            .collect::<Result<Vec<_>>>()?;
+
+        // Extract source field IDs for projection
+        let source_field_ids: Vec<i32> = partition_spec
+            .fields()
+            .iter()
+            .map(|pf| pf.source_id)
+            .collect();
+
+        // Create projector for extracting source columns
+        let projector = RecordBatchProjector::from_iceberg_schema(
+            Arc::new(table_schema.clone()),
+            &source_field_ids,
+        )?;
+
+        // Get partition type information
+        let partition_type = partition_spec.partition_type(table_schema)?;
+        let partition_arrow_type = 
type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
+
+        Ok(Self {
+            projector,
+            transform_functions,
+            partition_type,
+            partition_arrow_type,
+        })
+    }
+
+    /// Get the partition type as an Iceberg StructType.
+    pub fn partition_type(&self) -> &StructType {
+        &self.partition_type
+    }
+
+    /// Get the partition type as an Arrow DataType.
+    pub fn partition_arrow_type(&self) -> &DataType {
+        &self.partition_arrow_type
+    }
+
+    /// Calculate partition values for a record batch.
+    ///
+    /// This method:
+    /// 1. Projects the source columns from the batch
+    /// 2. Applies partition transforms to each source column
+    /// 3. Constructs a StructArray containing the partition values
+    ///
+    /// # Arguments
+    ///
+    /// * `batch` - The record batch to calculate partition values for
+    ///
+    /// # Returns
+    ///
+    /// Returns an ArrayRef containing a StructArray of partition values, or 
an error if calculation fails.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - Column projection fails
+    /// - Transform application fails
+    /// - StructArray construction fails
+    pub fn calculate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+        // Project source columns from the batch
+        let source_columns = self.projector.project_column(batch.columns())?;
+
+        // Get expected struct fields for the result
+        let expected_struct_fields = match &self.partition_arrow_type {
+            DataType::Struct(fields) => fields.clone(),
+            _ => {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    "Expected partition type must be a struct",
+                ));
+            }
+        };
+
+        // Apply transforms to each source column
+        let mut partition_values = 
Vec::with_capacity(self.transform_functions.len());
+        for (source_column, transform_fn) in 
source_columns.iter().zip(&self.transform_functions) {
+            let partition_value = 
transform_fn.transform(source_column.clone())?;
+            partition_values.push(partition_value);
+        }
+
+        // Construct the StructArray
+        let struct_array = StructArray::try_new(expected_struct_fields, 
partition_values, None)
+            .map_err(|e| {
+                Error::new(
+                    ErrorKind::DataInvalid,
+                    format!("Failed to create partition struct array: {}", e),
+                )
+            })?;
+
+        Ok(Arc::new(struct_array))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow_array::{Int32Array, RecordBatch, StringArray};
+    use arrow_schema::{Field, Schema as ArrowSchema};
+
+    use super::*;
+    use crate::spec::{NestedField, PartitionSpecBuilder, PrimitiveType, 
Transform};
+
+    #[test]
+    fn test_partition_calculator_identity_transform() {
+        let table_schema = Schema::builder()
+            .with_schema_id(0)
+            .with_fields(vec![
+                NestedField::required(1, "id", 
Type::Primitive(PrimitiveType::Int)).into(),
+                NestedField::required(2, "name", 
Type::Primitive(PrimitiveType::String)).into(),
+            ])
+            .build()
+            .unwrap();
+
+        let partition_spec = 
PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+            .add_partition_field("id", "id_partition", Transform::Identity)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, 
&table_schema).unwrap();
+
+        // Verify partition type
+        assert_eq!(calculator.partition_type().fields().len(), 1);
+        assert_eq!(calculator.partition_type().fields()[0].name, 
"id_partition");
+
+        // Create test batch
+        let arrow_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let batch = RecordBatch::try_new(arrow_schema, vec![
+            Arc::new(Int32Array::from(vec![10, 20, 30])),
+            Arc::new(StringArray::from(vec!["a", "b", "c"])),
+        ])
+        .unwrap();
+
+        // Calculate partition values
+        let result = calculator.calculate(&batch).unwrap();
+        let struct_array = 
result.as_any().downcast_ref::<StructArray>().unwrap();
+
+        let id_partition = struct_array
+            .column_by_name("id_partition")
+            .unwrap()
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+
+        assert_eq!(id_partition.value(0), 10);
+        assert_eq!(id_partition.value(1), 20);
+        assert_eq!(id_partition.value(2), 30);
+    }
+
+    #[test]
+    fn test_partition_calculator_unpartitioned_error() {
+        let table_schema = Schema::builder()
+            .with_schema_id(0)
+            .with_fields(vec![
+                NestedField::required(1, "id", 
Type::Primitive(PrimitiveType::Int)).into(),
+            ])
+            .build()
+            .unwrap();
+
+        let partition_spec = 
PartitionSpecBuilder::new(Arc::new(table_schema.clone()))
+            .build()
+            .unwrap();
+
+        let result = PartitionValueCalculator::try_new(&partition_spec, 
&table_schema);
+        assert!(result.is_err());
+        assert!(
+            result
+                .unwrap_err()
+                .to_string()
+                .contains("unpartitioned table")
+        );
+    }
+}
diff --git a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs 
b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
index 704a4e9c1..66371fac1 100644
--- a/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
+++ b/crates/iceberg/src/arrow/record_batch_partition_splitter.rs
@@ -19,137 +19,169 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StructArray};
-use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
 use arrow_select::filter::filter_record_batch;
-use itertools::Itertools;
-use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
 use super::arrow_struct_to_literal;
-use super::record_batch_projector::RecordBatchProjector;
-use crate::arrow::type_to_arrow_type;
-use crate::spec::{Literal, PartitionSpecRef, SchemaRef, Struct, StructType, 
Type};
-use crate::transform::{BoxedTransformFunction, create_transform_function};
+use super::partition_value_calculator::PartitionValueCalculator;
+use crate::spec::{Literal, PartitionKey, PartitionSpecRef, SchemaRef, 
StructType};
 use crate::{Error, ErrorKind, Result};
 
+/// Column name for the projected partition values struct
+pub const PROJECTED_PARTITION_VALUE_COLUMN: &str = "_partition";
+
 /// The splitter used to split the record batch into multiple record batches 
by the partition spec.
 /// 1. It will project and transform the input record batch based on the 
partition spec, get the partitioned record batch.
 /// 2. Split the input record batch into multiple record batches based on the 
partitioned record batch.
+///
+/// # Partition Value Modes
+///
+/// The splitter supports two modes for obtaining partition values:
+/// - **Computed mode** (`calculator` is `Some`): Computes partition values 
from source columns using transforms
+/// - **Pre-computed mode** (`calculator` is `None`): Expects a `_partition` 
column in the input batch
 // # TODO
 // Remove this after partition writer supported.
 #[allow(dead_code)]
 pub struct RecordBatchPartitionSplitter {
     schema: SchemaRef,
     partition_spec: PartitionSpecRef,
-    projector: RecordBatchProjector,
-    transform_functions: Vec<BoxedTransformFunction>,
-
+    calculator: Option<PartitionValueCalculator>,
     partition_type: StructType,
-    partition_arrow_type: DataType,
 }
 
 // # TODO
 // Remove this after partition writer supported.
 #[allow(dead_code)]
 impl RecordBatchPartitionSplitter {
+    /// Create a new RecordBatchPartitionSplitter.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    /// * `calculator` - Optional calculator for computing partition values 
from source columns.
+    ///   - `Some(calculator)`: Compute partition values from source columns 
using transforms
+    ///   - `None`: Expect a pre-computed `_partition` column in the input 
batch
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if 
initialization fails.
     pub fn new(
-        input_schema: ArrowSchemaRef,
         iceberg_schema: SchemaRef,
         partition_spec: PartitionSpecRef,
+        calculator: Option<PartitionValueCalculator>,
     ) -> Result<Self> {
-        let projector = RecordBatchProjector::new(
-            input_schema,
-            &partition_spec
-                .fields()
-                .iter()
-                .map(|field| field.source_id)
-                .collect::<Vec<_>>(),
-            // The source columns, selected by ids, must be a primitive type 
and cannot be contained in a map or list, but may be nested in a struct.
-            // ref: https://iceberg.apache.org/spec/#partitioning
-            |field| {
-                if !field.data_type().is_primitive() {
-                    return Ok(None);
-                }
-                field
-                    .metadata()
-                    .get(PARQUET_FIELD_ID_META_KEY)
-                    .map(|s| {
-                        s.parse::<i64>()
-                            .map_err(|e| Error::new(ErrorKind::Unexpected, 
e.to_string()))
-                    })
-                    .transpose()
-            },
-            |_| true,
-        )?;
-        let transform_functions = partition_spec
-            .fields()
-            .iter()
-            .map(|field| create_transform_function(&field.transform))
-            .collect::<Result<Vec<_>>>()?;
-
         let partition_type = partition_spec.partition_type(&iceberg_schema)?;
-        let partition_arrow_type = 
type_to_arrow_type(&Type::Struct(partition_type.clone()))?;
 
         Ok(Self {
             schema: iceberg_schema,
             partition_spec,
-            projector,
-            transform_functions,
+            calculator,
             partition_type,
-            partition_arrow_type,
         })
     }
 
-    fn partition_columns_to_struct(&self, partition_columns: Vec<ArrayRef>) -> 
Result<Vec<Struct>> {
-        let arrow_struct_array = {
-            let partition_arrow_fields = {
-                let DataType::Struct(fields) = &self.partition_arrow_type else 
{
-                    return Err(Error::new(
-                        ErrorKind::DataInvalid,
-                        "The partition arrow type is not a struct type",
-                    ));
-                };
-                fields.clone()
-            };
-            Arc::new(StructArray::try_new(
-                partition_arrow_fields,
-                partition_columns,
-                None,
-            )?) as ArrayRef
-        };
-        let struct_array = {
-            let struct_array = arrow_struct_to_literal(&arrow_struct_array, 
&self.partition_type)?;
+    /// Create a new RecordBatchPartitionSplitter with computed partition 
values.
+    ///
+    /// This is a convenience method that creates a calculator and initializes 
the splitter
+    /// to compute partition values from source columns.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if 
initialization fails.
+    pub fn new_with_computed_values(
+        iceberg_schema: SchemaRef,
+        partition_spec: PartitionSpecRef,
+    ) -> Result<Self> {
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, 
&iceberg_schema)?;
+        Self::new(iceberg_schema, partition_spec, Some(calculator))
+    }
+
+    /// Create a new RecordBatchPartitionSplitter expecting pre-computed 
partition values.
+    ///
+    /// This is a convenience method that initializes the splitter to expect a 
`_partition`
+    /// column in the input batches.
+    ///
+    /// # Arguments
+    ///
+    /// * `iceberg_schema` - The Iceberg schema reference
+    /// * `partition_spec` - The partition specification reference
+    ///
+    /// # Returns
+    ///
+    /// Returns a new `RecordBatchPartitionSplitter` instance or an error if 
initialization fails.
+    pub fn new_with_precomputed_values(
+        iceberg_schema: SchemaRef,
+        partition_spec: PartitionSpecRef,
+    ) -> Result<Self> {
+        Self::new(iceberg_schema, partition_spec, None)
+    }
+
+    /// Split the record batch into multiple record batches based on the 
partition spec.
+    pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(PartitionKey, 
RecordBatch)>> {
+        let partition_structs = if let Some(calculator) = &self.calculator {
+            // Compute partition values from source columns using calculator
+            let partition_array = calculator.calculate(batch)?;
+            let struct_array = arrow_struct_to_literal(&partition_array, 
&self.partition_type)?;
+
             struct_array
                 .into_iter()
                 .map(|s| {
-                    if let Some(s) = s {
-                        if let Literal::Struct(s) = s {
-                            Ok(s)
-                        } else {
-                            Err(Error::new(
-                                ErrorKind::DataInvalid,
-                                "The struct is not a struct literal",
-                            ))
-                        }
+                    if let Some(Literal::Struct(s)) = s {
+                        Ok(s)
                     } else {
-                        Err(Error::new(ErrorKind::DataInvalid, "The struct is 
null"))
+                        Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            "Partition value is not a struct literal or is 
null",
+                        ))
                     }
                 })
                 .collect::<Result<Vec<_>>>()?
-        };
+        } else {
+            // Extract partition values from pre-computed partition column
+            let partition_column = batch
+                .column_by_name(PROJECTED_PARTITION_VALUE_COLUMN)
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Partition column '{}' not found in batch",
+                            PROJECTED_PARTITION_VALUE_COLUMN
+                        ),
+                    )
+                })?;
 
-        Ok(struct_array)
-    }
+            let partition_struct_array = partition_column
+                .as_any()
+                .downcast_ref::<StructArray>()
+                .ok_or_else(|| {
+                    Error::new(
+                        ErrorKind::DataInvalid,
+                        "Partition column is not a StructArray",
+                    )
+                })?;
 
-    /// Split the record batch into multiple record batches based on the 
partition spec.
-    pub fn split(&self, batch: &RecordBatch) -> Result<Vec<(Struct, 
RecordBatch)>> {
-        let source_columns = self.projector.project_column(batch.columns())?;
-        let partition_columns = source_columns
-            .into_iter()
-            .zip_eq(self.transform_functions.iter())
-            .map(|(source_column, transform_function)| 
transform_function.transform(source_column))
-            .collect::<Result<Vec<_>>>()?;
+            let arrow_struct_array = Arc::new(partition_struct_array.clone()) 
as ArrayRef;
+            let struct_array = arrow_struct_to_literal(&arrow_struct_array, 
&self.partition_type)?;
 
-        let partition_structs = 
self.partition_columns_to_struct(partition_columns)?;
+            struct_array
+                .into_iter()
+                .map(|s| {
+                    if let Some(Literal::Struct(s)) = s {
+                        Ok(s)
+                    } else {
+                        Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            "Partition value is not a struct literal or is 
null",
+                        ))
+                    }
+                })
+                .collect::<Result<Vec<_>>>()?
+        };
 
         // Group the batch by row value.
         let mut group_ids = HashMap::new();
@@ -172,8 +204,15 @@ impl RecordBatchPartitionSplitter {
                 filter.into()
             };
 
+            // Create PartitionKey from the partition struct
+            let partition_key = PartitionKey::new(
+                self.partition_spec.as_ref().clone(),
+                self.schema.clone(),
+                row,
+            );
+
             // filter the RecordBatch
-            partition_batches.push((row, filter_record_batch(batch, 
&filter_array)?));
+            partition_batches.push((partition_key, filter_record_batch(batch, 
&filter_array)?));
         }
 
         Ok(partition_batches)
@@ -185,11 +224,13 @@ mod tests {
     use std::sync::Arc;
 
     use arrow_array::{Int32Array, RecordBatch, StringArray};
+    use arrow_schema::DataType;
+    use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 
     use super::*;
     use crate::arrow::schema_to_arrow_schema;
     use crate::spec::{
-        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Transform,
+        NestedField, PartitionSpecBuilder, PrimitiveLiteral, Schema, Struct, 
Transform, Type,
         UnboundPartitionField,
     };
 
@@ -227,14 +268,14 @@ mod tests {
                 .build()
                 .unwrap(),
         );
-        let input_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
         let partition_splitter =
-            RecordBatchPartitionSplitter::new(input_schema.clone(), 
schema.clone(), partition_spec)
+            
RecordBatchPartitionSplitter::new_with_computed_values(schema.clone(), 
partition_spec)
                 .expect("Failed to create splitter");
 
+        let arrow_schema = Arc::new(schema_to_arrow_schema(&schema).unwrap());
         let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
         let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", 
"g"]);
-        let batch = RecordBatch::try_new(input_schema.clone(), vec![
+        let batch = RecordBatch::try_new(arrow_schema.clone(), vec![
             Arc::new(id_array),
             Arc::new(data_array),
         ])
@@ -243,8 +284,8 @@ mod tests {
         let mut partitioned_batches = partition_splitter
             .split(&batch)
             .expect("Failed to split RecordBatch");
-        partitioned_batches.sort_by_key(|(row, _)| {
-            if let PrimitiveLiteral::Int(i) = row.fields()[0]
+        partitioned_batches.sort_by_key(|(partition_key, _)| {
+            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
                 .as_ref()
                 .unwrap()
                 .as_primitive_literal()
@@ -260,7 +301,7 @@ mod tests {
             // check the first partition
             let expected_id_array = Int32Array::from(vec![1, 1, 1]);
             let expected_data_array = StringArray::from(vec!["a", "c", "g"]);
-            let expected_batch = RecordBatch::try_new(input_schema.clone(), 
vec![
+            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), 
vec![
                 Arc::new(expected_id_array),
                 Arc::new(expected_data_array),
             ])
@@ -271,7 +312,7 @@ mod tests {
             // check the second partition
             let expected_id_array = Int32Array::from(vec![2, 2]);
             let expected_data_array = StringArray::from(vec!["b", "e"]);
-            let expected_batch = RecordBatch::try_new(input_schema.clone(), 
vec![
+            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), 
vec![
                 Arc::new(expected_id_array),
                 Arc::new(expected_data_array),
             ])
@@ -282,7 +323,7 @@ mod tests {
             // check the third partition
             let expected_id_array = Int32Array::from(vec![3, 3]);
             let expected_data_array = StringArray::from(vec!["d", "f"]);
-            let expected_batch = RecordBatch::try_new(input_schema.clone(), 
vec![
+            let expected_batch = RecordBatch::try_new(arrow_schema.clone(), 
vec![
                 Arc::new(expected_id_array),
                 Arc::new(expected_data_array),
             ])
@@ -292,7 +333,7 @@ mod tests {
 
         let partition_values = partitioned_batches
             .iter()
-            .map(|(row, _)| row.clone())
+            .map(|(partition_key, _)| partition_key.data().clone())
             .collect::<Vec<_>>();
         // check partition value is struct(1), struct(2), struct(3)
         assert_eq!(partition_values, vec![
@@ -301,4 +342,144 @@ mod tests {
             Struct::from_iter(vec![Some(Literal::int(3))]),
         ]);
     }
+
+    #[test]
+    fn test_record_batch_partition_split_with_partition_column() {
+        use arrow_array::StructArray;
+        use arrow_schema::{Field, Schema as ArrowSchema};
+
+        let schema = Arc::new(
+            Schema::builder()
+                .with_fields(vec![
+                    NestedField::required(
+                        1,
+                        "id",
+                        Type::Primitive(crate::spec::PrimitiveType::Int),
+                    )
+                    .into(),
+                    NestedField::required(
+                        2,
+                        "name",
+                        Type::Primitive(crate::spec::PrimitiveType::String),
+                    )
+                    .into(),
+                ])
+                .build()
+                .unwrap(),
+        );
+        let partition_spec = Arc::new(
+            PartitionSpecBuilder::new(schema.clone())
+                .with_spec_id(1)
+                .add_unbound_field(UnboundPartitionField {
+                    source_id: 1,
+                    field_id: None,
+                    name: "id_bucket".to_string(),
+                    transform: Transform::Identity,
+                })
+                .unwrap()
+                .build()
+                .unwrap(),
+        );
+
+        // Create input schema with _partition column
+        // Note: partition field IDs start from 1000 by default
+        let partition_field = Field::new("id_bucket", DataType::Int32, 
false).with_metadata(
+            HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), 
"1000".to_string())]),
+        );
+        let partition_struct_field = Field::new(
+            PROJECTED_PARTITION_VALUE_COLUMN,
+            DataType::Struct(vec![partition_field.clone()].into()),
+            false,
+        );
+
+        let input_schema = Arc::new(ArrowSchema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+            partition_struct_field,
+        ]));
+
+        // Create splitter expecting pre-computed partition column
+        let partition_splitter = 
RecordBatchPartitionSplitter::new_with_precomputed_values(
+            schema.clone(),
+            partition_spec,
+        )
+        .expect("Failed to create splitter");
+
+        // Create test data with pre-computed partition column
+        let id_array = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+        let data_array = StringArray::from(vec!["a", "b", "c", "d", "e", "f", 
"g"]);
+
+        // Create partition column (same values as id for Identity transform)
+        let partition_values = Int32Array::from(vec![1, 2, 1, 3, 2, 3, 1]);
+        let partition_struct = StructArray::from(vec![(
+            Arc::new(partition_field),
+            Arc::new(partition_values) as ArrayRef,
+        )]);
+
+        let batch = RecordBatch::try_new(input_schema.clone(), vec![
+            Arc::new(id_array),
+            Arc::new(data_array),
+            Arc::new(partition_struct),
+        ])
+        .expect("Failed to create RecordBatch");
+
+        // Split using the pre-computed partition column
+        let mut partitioned_batches = partition_splitter
+            .split(&batch)
+            .expect("Failed to split RecordBatch");
+
+        partitioned_batches.sort_by_key(|(partition_key, _)| {
+            if let PrimitiveLiteral::Int(i) = partition_key.data().fields()[0]
+                .as_ref()
+                .unwrap()
+                .as_primitive_literal()
+                .unwrap()
+            {
+                i
+            } else {
+                panic!("The partition value is not a int");
+            }
+        });
+
+        assert_eq!(partitioned_batches.len(), 3);
+
+        // Helper to extract id and name values from a batch
+        let extract_values = |batch: &RecordBatch| -> (Vec<i32>, Vec<String>) {
+            let id_col = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap();
+            let name_col = batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap();
+            (
+                id_col.values().to_vec(),
+                name_col.iter().map(|s| s.unwrap().to_string()).collect(),
+            )
+        };
+
+        // Verify partition 1: id=1, names=["a", "c", "g"]
+        let (key, batch) = &partitioned_batches[0];
+        assert_eq!(key.data(), 
&Struct::from_iter(vec![Some(Literal::int(1))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![1, 1, 1]);
+        assert_eq!(names, vec!["a", "c", "g"]);
+
+        // Verify partition 2: id=2, names=["b", "e"]
+        let (key, batch) = &partitioned_batches[1];
+        assert_eq!(key.data(), 
&Struct::from_iter(vec![Some(Literal::int(2))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![2, 2]);
+        assert_eq!(names, vec!["b", "e"]);
+
+        // Verify partition 3: id=3, names=["d", "f"]
+        let (key, batch) = &partitioned_batches[2];
+        assert_eq!(key.data(), 
&Struct::from_iter(vec![Some(Literal::int(3))]));
+        let (ids, names) = extract_values(batch);
+        assert_eq!(ids, vec![3, 3]);
+        assert_eq!(names, vec!["d", "f"]);
+    }
 }
diff --git a/crates/integrations/datafusion/src/physical_plan/project.rs 
b/crates/integrations/datafusion/src/physical_plan/project.rs
index 4bfe8192b..17492176a 100644
--- a/crates/integrations/datafusion/src/physical_plan/project.rs
+++ b/crates/integrations/datafusion/src/physical_plan/project.rs
@@ -19,24 +19,19 @@
 
 use std::sync::Arc;
 
-use datafusion::arrow::array::{ArrayRef, RecordBatch, StructArray};
+use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::{DataType, Schema as ArrowSchema};
 use datafusion::common::Result as DFResult;
-use datafusion::error::DataFusionError;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_expr::expressions::Column;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::{ColumnarValue, ExecutionPlan};
-use iceberg::arrow::record_batch_projector::RecordBatchProjector;
-use iceberg::spec::{PartitionSpec, Schema};
+use iceberg::arrow::{PROJECTED_PARTITION_VALUE_COLUMN, PartitionValueCalculator};
+use iceberg::spec::PartitionSpec;
 use iceberg::table::Table;
-use iceberg::transform::BoxedTransformFunction;
 
 use crate::to_datafusion_error;
 
-/// Column name for the combined partition values struct
-const PARTITION_VALUES_COLUMN: &str = "_partition";
-
 /// Extends an ExecutionPlan with partition value calculations for Iceberg tables.
 ///
 /// This function takes an input ExecutionPlan and extends it with an additional column
@@ -65,12 +60,9 @@ pub fn project_with_partition(
     let input_schema = input.schema();
     // TODO: Validate that input_schema matches the Iceberg table schema.
     // See: https://github.com/apache/iceberg-rust/issues/1752
-    let partition_type = build_partition_type(partition_spec, table_schema.as_ref())?;
-    let calculator = PartitionValueCalculator::new(
-        partition_spec.as_ref().clone(),
-        table_schema.as_ref().clone(),
-        partition_type,
-    )?;
+    let calculator =
+        PartitionValueCalculator::try_new(partition_spec.as_ref(), table_schema.as_ref())
+            .map_err(to_datafusion_error)?;
 
     let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
         Vec::with_capacity(input_schema.fields().len() + 1);
@@ -80,8 +72,8 @@ pub fn project_with_partition(
         projection_exprs.push((column_expr, field.name().clone()));
     }
 
-    let partition_expr = Arc::new(PartitionExpr::new(calculator));
-    projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
+    let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec.clone()));
+    projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
 
     let projection = ProjectionExec::try_new(projection_exprs, input)?;
     Ok(Arc::new(projection))
@@ -91,21 +83,24 @@ pub fn project_with_partition(
 #[derive(Debug, Clone)]
 struct PartitionExpr {
     calculator: Arc<PartitionValueCalculator>,
+    partition_spec: Arc<PartitionSpec>,
 }
 
 impl PartitionExpr {
-    fn new(calculator: PartitionValueCalculator) -> Self {
+    fn new(calculator: PartitionValueCalculator, partition_spec: Arc<PartitionSpec>) -> Self {
         Self {
             calculator: Arc::new(calculator),
+            partition_spec,
         }
     }
 }
 
 // Manual PartialEq/Eq implementations for pointer-based equality
-// (two PartitionExpr are equal if they share the same calculator instance)
+// (two PartitionExpr are equal if they share the same calculator and partition_spec instances)
 impl PartialEq for PartitionExpr {
     fn eq(&self, other: &Self) -> bool {
         Arc::ptr_eq(&self.calculator, &other.calculator)
+            && Arc::ptr_eq(&self.partition_spec, &other.partition_spec)
     }
 }
 
@@ -117,7 +112,7 @@ impl PhysicalExpr for PartitionExpr {
     }
 
     fn data_type(&self, _input_schema: &ArrowSchema) -> DFResult<DataType> {
-        Ok(self.calculator.partition_type.clone())
+        Ok(self.calculator.partition_arrow_type().clone())
     }
 
     fn nullable(&self, _input_schema: &ArrowSchema) -> DFResult<bool> {
@@ -125,7 +120,10 @@ impl PhysicalExpr for PartitionExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
-        let array = self.calculator.calculate(batch)?;
+        let array = self
+            .calculator
+            .calculate(batch)
+            .map_err(to_datafusion_error)?;
         Ok(ColumnarValue::Array(array))
     }
 
@@ -142,7 +140,6 @@ impl PhysicalExpr for PartitionExpr {
 
     fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let field_names: Vec<String> = self
-            .calculator
             .partition_spec
             .fields()
             .iter()
@@ -155,7 +152,6 @@ impl PhysicalExpr for PartitionExpr {
 impl std::fmt::Display for PartitionExpr {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         let field_names: Vec<&str> = self
-            .calculator
             .partition_spec
             .fields()
             .iter()
@@ -167,110 +163,18 @@ impl std::fmt::Display for PartitionExpr {
 
 impl std::hash::Hash for PartitionExpr {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        // Two PartitionExpr are equal if they share the same calculator Arc
+        // Two PartitionExpr are equal if they share the same calculator and partition_spec Arcs
         Arc::as_ptr(&self.calculator).hash(state);
+        Arc::as_ptr(&self.partition_spec).hash(state);
     }
 }
 
-/// Calculator for partition values in Iceberg tables
-#[derive(Debug)]
-struct PartitionValueCalculator {
-    partition_spec: PartitionSpec,
-    partition_type: DataType,
-    projector: RecordBatchProjector,
-    transform_functions: Vec<BoxedTransformFunction>,
-}
-
-impl PartitionValueCalculator {
-    fn new(
-        partition_spec: PartitionSpec,
-        table_schema: Schema,
-        partition_type: DataType,
-    ) -> DFResult<Self> {
-        if partition_spec.is_unpartitioned() {
-            return Err(DataFusionError::Internal(
-                "Cannot create partition calculator for unpartitioned table".to_string(),
-            ));
-        }
-
-        let transform_functions: Result<Vec<BoxedTransformFunction>, _> = partition_spec
-            .fields()
-            .iter()
-            .map(|pf| iceberg::transform::create_transform_function(&pf.transform))
-            .collect();
-
-        let transform_functions = transform_functions.map_err(to_datafusion_error)?;
-
-        let source_field_ids: Vec<i32> = partition_spec
-            .fields()
-            .iter()
-            .map(|pf| pf.source_id)
-            .collect();
-
-        let projector = RecordBatchProjector::from_iceberg_schema(
-            Arc::new(table_schema.clone()),
-            &source_field_ids,
-        )
-        .map_err(to_datafusion_error)?;
-
-        Ok(Self {
-            partition_spec,
-            partition_type,
-            projector,
-            transform_functions,
-        })
-    }
-
-    fn calculate(&self, batch: &RecordBatch) -> DFResult<ArrayRef> {
-        let source_columns = self
-            .projector
-            .project_column(batch.columns())
-            .map_err(to_datafusion_error)?;
-
-        let expected_struct_fields = match &self.partition_type {
-            DataType::Struct(fields) => fields.clone(),
-            _ => {
-                return Err(DataFusionError::Internal(
-                    "Expected partition type must be a struct".to_string(),
-                ));
-            }
-        };
-
-        let mut partition_values = Vec::with_capacity(self.partition_spec.fields().len());
-
-        for (source_column, transform_fn) in source_columns.iter().zip(&self.transform_functions) {
-            let partition_value = transform_fn
-                .transform(source_column.clone())
-                .map_err(to_datafusion_error)?;
-
-            partition_values.push(partition_value);
-        }
-
-        let struct_array = StructArray::try_new(expected_struct_fields, partition_values, None)
-            .map_err(|e| DataFusionError::ArrowError(e, None))?;
-
-        Ok(Arc::new(struct_array))
-    }
-}
-
-fn build_partition_type(
-    partition_spec: &PartitionSpec,
-    table_schema: &Schema,
-) -> DFResult<DataType> {
-    let partition_struct_type = partition_spec
-        .partition_type(table_schema)
-        .map_err(to_datafusion_error)?;
-
-    iceberg::arrow::type_to_arrow_type(&iceberg::spec::Type::Struct(partition_struct_type))
-        .map_err(to_datafusion_error)
-}
-
 #[cfg(test)]
 mod tests {
-    use datafusion::arrow::array::Int32Array;
+    use datafusion::arrow::array::{ArrayRef, Int32Array, StructArray};
     use datafusion::arrow::datatypes::{Field, Fields};
     use datafusion::physical_plan::empty::EmptyExec;
-    use iceberg::spec::{NestedField, PrimitiveType, StructType, Transform, Type};
+    use iceberg::spec::{NestedField, PrimitiveType, Schema, StructType, Transform, Type};
 
     use super::*;
 
@@ -291,20 +195,11 @@ mod tests {
             .build()
             .unwrap();
 
-        let _arrow_schema = Arc::new(ArrowSchema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("name", DataType::Utf8, false),
-        ]));
-
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator = PartitionValueCalculator::new(
-            partition_spec.clone(),
-            table_schema,
-            partition_type.clone(),
-        )
-        .unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
 
-        assert_eq!(calculator.partition_type, partition_type);
+        // Verify partition type
+        assert_eq!(calculator.partition_type().fields().len(), 1);
+        assert_eq!(calculator.partition_type().fields()[0].name, "id_partition");
     }
 
     #[test]
@@ -318,11 +213,13 @@ mod tests {
             .build()
             .unwrap();
 
-        let partition_spec = iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
-            .add_partition_field("id", "id_partition", Transform::Identity)
-            .unwrap()
-            .build()
-            .unwrap();
+        let partition_spec = Arc::new(
+            iceberg::spec::PartitionSpec::builder(Arc::new(table_schema.clone()))
+                .add_partition_field("id", "id_partition", Transform::Identity)
+                .unwrap()
+                .build()
+                .unwrap(),
+        );
 
         let arrow_schema = Arc::new(ArrowSchema::new(vec![
             Field::new("id", DataType::Int32, false),
@@ -331,9 +228,7 @@ mod tests {
 
         let input = Arc::new(EmptyExec::new(arrow_schema.clone()));
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
 
         let mut projection_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> =
             Vec::with_capacity(arrow_schema.fields().len() + 1);
@@ -342,8 +237,8 @@ mod tests {
             projection_exprs.push((column_expr, field.name().clone()));
         }
 
-        let partition_expr = Arc::new(PartitionExpr::new(calculator));
-        projection_exprs.push((partition_expr, PARTITION_VALUES_COLUMN.to_string()));
+        let partition_expr = Arc::new(PartitionExpr::new(calculator, partition_spec));
+        projection_exprs.push((partition_expr, PROJECTED_PARTITION_VALUE_COLUMN.to_string()));
 
         let projection = ProjectionExec::try_new(projection_exprs, input).unwrap();
         let result = Arc::new(projection);
@@ -384,11 +279,10 @@ mod tests {
         ])
         .unwrap();
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type.clone())
-                .unwrap();
-        let expr = PartitionExpr::new(calculator);
+        let partition_spec = Arc::new(partition_spec);
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
+        let partition_type = calculator.partition_arrow_type().clone();
+        let expr = PartitionExpr::new(calculator, partition_spec);
 
         assert_eq!(expr.data_type(&arrow_schema).unwrap(), partition_type);
         assert!(!expr.nullable(&arrow_schema).unwrap());
@@ -469,9 +363,7 @@ mod tests {
         ])
         .unwrap();
 
-        let partition_type = build_partition_type(&partition_spec, &table_schema).unwrap();
-        let calculator =
-            PartitionValueCalculator::new(partition_spec, table_schema, partition_type).unwrap();
+        let calculator = PartitionValueCalculator::try_new(&partition_spec, &table_schema).unwrap();
         let array = calculator.calculate(&batch).unwrap();
 
         let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();


Reply via email to