This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new e98d9d35 [AURON #1888] Implement spark_partition_id() function (#1928)
e98d9d35 is described below

commit e98d9d356be3736db4594b615a71dff42a79d82e
Author: Shreyesh <[email protected]>
AuthorDate: Wed Jan 21 05:46:23 2026 -0800

    [AURON #1888] Implement spark_partition_id() function (#1928)
    
    <!--
    - Start the PR title with the related issue ID, e.g. '[AURON #XXXX]
    Short summary...'.
    -->
    # Which issue does this PR close?
    
    Closes #1888
    
    # Rationale for this change
    This is part of the effort to add non-deterministic expressions to
    Auron.
    
    # What changes are included in this PR?
    This PR adds support for spark_partition_id() using a native expression implementation.
    
    # Are there any user-facing changes?
    N/A
    
    # How was this patch tested?
    Unit tests
---
 native-engine/auron-planner/proto/auron.proto      |   5 +
 native-engine/auron-planner/src/planner.rs         |   5 +-
 native-engine/datafusion-ext-exprs/src/lib.rs      |   1 +
 .../datafusion-ext-exprs/src/spark_partition_id.rs | 189 +++++++++++++++++++++
 .../org/apache/spark/sql/auron/ShimsImpl.scala     |   8 +
 .../apache/spark/sql/auron/NativeConverters.scala  |   5 +
 6 files changed, 212 insertions(+), 1 deletion(-)

diff --git a/native-engine/auron-planner/proto/auron.proto 
b/native-engine/auron-planner/proto/auron.proto
index 788be352..99f8078b 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -113,6 +113,9 @@ message PhysicalExprNode {
     // RowNum
     RowNumExprNode row_num_expr = 20100;
 
+    // SparkPartitionID
+    SparkPartitionIdExprNode spark_partition_id_expr = 20101;
+
     // BloomFilterMightContain
     BloomFilterMightContainExprNode bloom_filter_might_contain_expr = 20200;
   }
@@ -914,3 +917,5 @@ message ArrowType {
 //   }
 //}
 message EmptyMessage{}
+
+message SparkPartitionIdExprNode {}
diff --git a/native-engine/auron-planner/src/planner.rs 
b/native-engine/auron-planner/src/planner.rs
index 8e13312b..cfab99e1 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -52,7 +52,7 @@ use datafusion::{
 use datafusion_ext_exprs::{
     bloom_filter_might_contain::BloomFilterMightContainExpr, cast::TryCastExpr,
     get_indexed_field::GetIndexedFieldExpr, get_map_value::GetMapValueExpr,
-    named_struct::NamedStructExpr, row_num::RowNumExpr,
+    named_struct::NamedStructExpr, row_num::RowNumExpr, 
spark_partition_id::SparkPartitionIdExpr,
     spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr,
     spark_udf_wrapper::SparkUDFWrapperExpr, 
string_contains::StringContainsExpr,
     string_ends_with::StringEndsWithExpr, 
string_starts_with::StringStartsWithExpr,
@@ -962,6 +962,9 @@ impl PhysicalPlanner {
                 Arc::new(StringContainsExpr::new(expr, e.infix.clone()))
             }
             ExprType::RowNumExpr(_) => Arc::new(RowNumExpr::default()),
+            ExprType::SparkPartitionIdExpr(_) => {
+                Arc::new(SparkPartitionIdExpr::new(self.partition_id))
+            }
             ExprType::BloomFilterMightContainExpr(e) => 
Arc::new(BloomFilterMightContainExpr::new(
                 e.uuid.clone(),
                 
self.try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
diff --git a/native-engine/datafusion-ext-exprs/src/lib.rs 
b/native-engine/datafusion-ext-exprs/src/lib.rs
index 3a685a41..bb2757f0 100644
--- a/native-engine/datafusion-ext-exprs/src/lib.rs
+++ b/native-engine/datafusion-ext-exprs/src/lib.rs
@@ -23,6 +23,7 @@ pub mod get_indexed_field;
 pub mod get_map_value;
 pub mod named_struct;
 pub mod row_num;
+pub mod spark_partition_id;
 pub mod spark_scalar_subquery_wrapper;
 pub mod spark_udf_wrapper;
 pub mod string_contains;
diff --git a/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs 
b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
new file mode 100644
index 00000000..d34150db
--- /dev/null
+++ b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+use arrow::{
+    array::{Int32Array, RecordBatch},
+    datatypes::{DataType, Schema},
+};
+use datafusion::{
+    common::Result,
+    logical_expr::ColumnarValue,
+    physical_expr::{PhysicalExpr, PhysicalExprRef},
+};
+
+pub struct SparkPartitionIdExpr {
+    partition_id: i32,
+}
+
+impl SparkPartitionIdExpr {
+    pub fn new(partition_id: usize) -> Self {
+        Self {
+            partition_id: partition_id as i32,
+        }
+    }
+}
+
+impl Display for SparkPartitionIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkPartitionID")
+    }
+}
+
+impl Debug for SparkPartitionIdExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SparkPartitionID")
+    }
+}
+
+impl PartialEq for SparkPartitionIdExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.partition_id == other.partition_id
+    }
+}
+
+impl Eq for SparkPartitionIdExpr {}
+
+impl Hash for SparkPartitionIdExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.partition_id.hash(state)
+    }
+}
+
+impl PhysicalExpr for SparkPartitionIdExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let num_rows = batch.num_rows();
+        let array = Int32Array::from_value(self.partition_id, num_rows);
+        Ok(ColumnarValue::Array(Arc::new(array)))
+    }
+
+    fn children(&self) -> Vec<&PhysicalExprRef> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<PhysicalExprRef>,
+    ) -> Result<PhysicalExprRef> {
+        Ok(self)
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "fmt_sql not used")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{
+        array::Int32Array,
+        datatypes::{Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    use super::*;
+
+    #[test]
+    fn test_data_type_and_nullable() {
+        let expr = SparkPartitionIdExpr::new(0);
+        let schema = Schema::new(vec![] as Vec<Field>);
+        assert_eq!(
+            expr.data_type(&schema).expect("data_type failed"),
+            DataType::Int32
+        );
+        assert!(!expr.nullable(&schema).expect("nullable failed"));
+    }
+
+    #[test]
+    fn test_evaluate_returns_constant_partition_id() {
+        let expr = SparkPartitionIdExpr::new(5);
+        let schema = Schema::new(vec![Field::new("col", DataType::Int32, 
false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        let result = expr.evaluate(&batch).expect("evaluate failed");
+        match result {
+            ColumnarValue::Array(arr) => {
+                let int_arr = arr
+                    .as_any()
+                    .downcast_ref::<Int32Array>()
+                    .expect("downcast failed");
+                assert_eq!(int_arr.len(), 3);
+                for i in 0..3 {
+                    assert_eq!(int_arr.value(i), 5);
+                }
+            }
+            _ => panic!("Expected Array result"),
+        }
+    }
+
+    #[test]
+    fn test_evaluate_different_partition_ids() {
+        let schema = Schema::new(vec![Field::new("col", DataType::Int32, 
false)]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int32Array::from(vec![1, 2]))],
+        )
+        .expect("RecordBatch creation failed");
+
+        for partition_id in [0, 1, 100, 999] {
+            let expr = SparkPartitionIdExpr::new(partition_id);
+            let result = expr.evaluate(&batch).expect("evaluate failed");
+            match result {
+                ColumnarValue::Array(arr) => {
+                    let int_arr = arr
+                        .as_any()
+                        .downcast_ref::<Int32Array>()
+                        .expect("downcast failed");
+                    for i in 0..int_arr.len() {
+                        assert_eq!(int_arr.value(i), partition_id as i32);
+                    }
+                }
+                _ => panic!("Expected Array result"),
+            }
+        }
+    }
+
+    #[test]
+    fn test_equality() {
+        let expr1 = SparkPartitionIdExpr::new(5);
+        let expr2 = SparkPartitionIdExpr::new(5);
+        let expr3 = SparkPartitionIdExpr::new(3);
+
+        assert_eq!(expr1, expr2);
+        assert_ne!(expr1, expr3);
+    }
+}
diff --git 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index cb9492c9..1427e01d 100644
--- 
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++ 
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.expressions.Like
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions.SparkPartitionID
 import org.apache.spark.sql.catalyst.expressions.StringSplit
 import org.apache.spark.sql.catalyst.expressions.TaggingExpression
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -521,6 +522,13 @@ class ShimsImpl extends Shims with Logging {
       isPruningExpr: Boolean,
       fallback: Expression => pb.PhysicalExprNode): 
Option[pb.PhysicalExprNode] = {
     e match {
+      case _: SparkPartitionID =>
+        Some(
+          pb.PhysicalExprNode
+            .newBuilder()
+            .setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+            .build())
+
       case StringSplit(str, pat @ Literal(_, StringType), Literal(-1, 
IntegerType))
           // native StringSplit implementation does not support regex, so only 
most frequently
           // used cases without regex are supported
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 13a627f2..29b9386a 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -1125,6 +1125,11 @@ object NativeConverters extends Logging {
           _.setRowNumExpr(pb.RowNumExprNode.newBuilder())
         }
 
+      case StubExpr("SparkPartitionID", _, _) =>
+        buildExprNode {
+          _.setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+        }
+
       // hive UDFJson
       case e
           if udfJsonEnabled && (

Reply via email to