This is an automated email from the ASF dual-hosted git repository.
richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new e98d9d35 [AURON #1888] Implement spark_partition_id() function (#1928)
e98d9d35 is described below
commit e98d9d356be3736db4594b615a71dff42a79d82e
Author: Shreyesh <[email protected]>
AuthorDate: Wed Jan 21 05:46:23 2026 -0800
[AURON #1888] Implement spark_partition_id() function (#1928)
<!--
- Start the PR title with the related issue ID, e.g. '[AURON #XXXX]
Short summary...'.
-->
# Which issue does this PR close?
Closes #1888
# Rationale for this change
This is part of the effort to add non-deterministic expressions to
Auron.
# What changes are included in this PR?
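This PR adds native support for `spark_partition_id()`: a new `SparkPartitionIdExprNode` protobuf message, a `SparkPartitionIdExpr` physical expression in `datafusion-ext-exprs`, planner wiring, and conversion of Spark's `SparkPartitionID` expression on the JVM side.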
# Are there any user-facing changes?
N/A
# How was this patch tested?
Unit tests for `SparkPartitionIdExpr` in `spark_partition_id.rs`.
---
native-engine/auron-planner/proto/auron.proto | 5 +
native-engine/auron-planner/src/planner.rs | 5 +-
native-engine/datafusion-ext-exprs/src/lib.rs | 1 +
.../datafusion-ext-exprs/src/spark_partition_id.rs | 189 +++++++++++++++++++++
.../org/apache/spark/sql/auron/ShimsImpl.scala | 8 +
.../apache/spark/sql/auron/NativeConverters.scala | 5 +
6 files changed, 212 insertions(+), 1 deletion(-)
diff --git a/native-engine/auron-planner/proto/auron.proto
b/native-engine/auron-planner/proto/auron.proto
index 788be352..99f8078b 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -113,6 +113,9 @@ message PhysicalExprNode {
// RowNum
RowNumExprNode row_num_expr = 20100;
+ // SparkPartitionID
+ SparkPartitionIdExprNode spark_partition_id_expr = 20101;
+
// BloomFilterMightContain
BloomFilterMightContainExprNode bloom_filter_might_contain_expr = 20200;
}
@@ -914,3 +917,5 @@ message ArrowType {
// }
//}
message EmptyMessage{}
+
+message SparkPartitionIdExprNode {}
diff --git a/native-engine/auron-planner/src/planner.rs
b/native-engine/auron-planner/src/planner.rs
index 8e13312b..cfab99e1 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -52,7 +52,7 @@ use datafusion::{
use datafusion_ext_exprs::{
bloom_filter_might_contain::BloomFilterMightContainExpr, cast::TryCastExpr,
get_indexed_field::GetIndexedFieldExpr, get_map_value::GetMapValueExpr,
- named_struct::NamedStructExpr, row_num::RowNumExpr,
+ named_struct::NamedStructExpr, row_num::RowNumExpr,
spark_partition_id::SparkPartitionIdExpr,
spark_scalar_subquery_wrapper::SparkScalarSubqueryWrapperExpr,
spark_udf_wrapper::SparkUDFWrapperExpr,
string_contains::StringContainsExpr,
string_ends_with::StringEndsWithExpr,
string_starts_with::StringStartsWithExpr,
@@ -962,6 +962,9 @@ impl PhysicalPlanner {
Arc::new(StringContainsExpr::new(expr, e.infix.clone()))
}
ExprType::RowNumExpr(_) => Arc::new(RowNumExpr::default()),
+ ExprType::SparkPartitionIdExpr(_) => {
+ Arc::new(SparkPartitionIdExpr::new(self.partition_id))
+ }
ExprType::BloomFilterMightContainExpr(e) =>
Arc::new(BloomFilterMightContainExpr::new(
e.uuid.clone(),
self.try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
diff --git a/native-engine/datafusion-ext-exprs/src/lib.rs
b/native-engine/datafusion-ext-exprs/src/lib.rs
index 3a685a41..bb2757f0 100644
--- a/native-engine/datafusion-ext-exprs/src/lib.rs
+++ b/native-engine/datafusion-ext-exprs/src/lib.rs
@@ -23,6 +23,7 @@ pub mod get_indexed_field;
pub mod get_map_value;
pub mod named_struct;
pub mod row_num;
+pub mod spark_partition_id;
pub mod spark_scalar_subquery_wrapper;
pub mod spark_udf_wrapper;
pub mod string_contains;
diff --git a/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
new file mode 100644
index 00000000..d34150db
--- /dev/null
+++ b/native-engine/datafusion-ext-exprs/src/spark_partition_id.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ any::Any,
+ fmt::{Debug, Display, Formatter},
+ hash::{Hash, Hasher},
+ sync::Arc,
+};
+
+use arrow::{
+ array::{Int32Array, RecordBatch},
+ datatypes::{DataType, Schema},
+};
+use datafusion::{
+ common::Result,
+ logical_expr::ColumnarValue,
+ physical_expr::{PhysicalExpr, PhysicalExprRef},
+};
+
+/// Physical expression implementing Spark's `spark_partition_id()`: it evaluates
+/// to the partition id supplied at construction, repeated once per input row.
+pub struct SparkPartitionIdExpr {
+    /// Partition id captured at construction; stored as `i32` because the
+    /// expression's output type is `Int32`.
+    partition_id: i32,
+}
+
+impl SparkPartitionIdExpr {
+    /// Creates an expression that yields `partition_id` for every input row.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `partition_id` does not fit in an `i32`. Such a value signals a
+    /// bug upstream. The previous `as` cast would instead have silently wrapped
+    /// it into a wrong (possibly negative) partition id.
+    pub fn new(partition_id: usize) -> Self {
+        let partition_id =
+            i32::try_from(partition_id).expect("spark partition id must fit in an i32");
+        Self { partition_id }
+    }
+}
+
+impl Display for SparkPartitionIdExpr {
+    // Rendered by its Spark expression name only; the partition id is not shown.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.write_str("SparkPartitionID")
+    }
+}
+
+impl Debug for SparkPartitionIdExpr {
+    // Debug output is intentionally the same text as Display.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+}
+
+// Equality and hashing are both keyed on the captured partition id, so two
+// expressions built for different partitions compare (and hash) as distinct.
+impl PartialEq for SparkPartitionIdExpr {
+    fn eq(&self, other: &Self) -> bool {
+        let Self { partition_id } = self;
+        *partition_id == other.partition_id
+    }
+}
+
+impl Eq for SparkPartitionIdExpr {}
+
+impl Hash for SparkPartitionIdExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        Hash::hash(&self.partition_id, state);
+    }
+}
+
+impl PhysicalExpr for SparkPartitionIdExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Always `Int32`, matching the integer result of Spark's
+    /// `spark_partition_id()`.
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    /// Never null: every row receives the captured partition id.
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    /// Returns an `Int32Array` holding one copy of the partition id per row of
+    /// `batch`; only the batch's row count is used, never its column values.
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let ids = Int32Array::from_value(self.partition_id, batch.num_rows());
+        Ok(ColumnarValue::Array(Arc::new(ids)))
+    }
+
+    /// Leaf expression: it has no child expressions.
+    fn children(&self) -> Vec<&PhysicalExprRef> {
+        vec![]
+    }
+
+    /// With no children there is nothing to replace, so the expression is
+    /// returned unchanged.
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<PhysicalExprRef>,
+    ) -> Result<PhysicalExprRef> {
+        Ok(self)
+    }
+
+    /// Renders the Spark SQL call this expression implements, rather than a
+    /// placeholder string that would leak into SQL-style plan output.
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.write_str("spark_partition_id()")
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{
+        array::Int32Array,
+        datatypes::{Field, Schema},
+        record_batch::RecordBatch,
+    };
+
+    use super::*;
+
+    /// Builds a single-column batch with `num_rows` rows; the values are
+    /// irrelevant because the expression only looks at the row count.
+    fn batch_with_rows(num_rows: usize) -> RecordBatch {
+        let schema = Schema::new(vec![Field::new("col", DataType::Int32, false)]);
+        RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(Int32Array::from(vec![0_i32; num_rows]))],
+        )
+        .expect("RecordBatch creation failed")
+    }
+
+    /// Evaluates `expr` on `batch`, checks that the result is a null-free
+    /// `Int32Array`, and returns its values.
+    fn evaluate_ids(expr: &SparkPartitionIdExpr, batch: &RecordBatch) -> Vec<i32> {
+        match expr.evaluate(batch).expect("evaluate failed") {
+            ColumnarValue::Array(arr) => {
+                assert_eq!(arr.null_count(), 0, "partition ids must be non-null");
+                arr.as_any()
+                    .downcast_ref::<Int32Array>()
+                    .expect("downcast failed")
+                    .values()
+                    .to_vec()
+            }
+            _ => panic!("Expected Array result"),
+        }
+    }
+
+    #[test]
+    fn test_data_type_and_nullable() {
+        let expr = SparkPartitionIdExpr::new(0);
+        let schema = Schema::empty();
+        assert_eq!(
+            expr.data_type(&schema).expect("data_type failed"),
+            DataType::Int32
+        );
+        assert!(!expr.nullable(&schema).expect("nullable failed"));
+    }
+
+    #[test]
+    fn test_evaluate_returns_constant_partition_id() {
+        let expr = SparkPartitionIdExpr::new(5);
+        assert_eq!(evaluate_ids(&expr, &batch_with_rows(3)), vec![5, 5, 5]);
+    }
+
+    #[test]
+    fn test_evaluate_different_partition_ids() {
+        let batch = batch_with_rows(2);
+        for partition_id in [0, 1, 100, 999] {
+            let expr = SparkPartitionIdExpr::new(partition_id);
+            // Comparing whole vectors also pins the row count, which a
+            // per-element loop over the result would silently skip.
+            assert_eq!(evaluate_ids(&expr, &batch), vec![partition_id as i32; 2]);
+        }
+    }
+
+    #[test]
+    fn test_evaluate_empty_batch() {
+        let expr = SparkPartitionIdExpr::new(7);
+        assert!(evaluate_ids(&expr, &batch_with_rows(0)).is_empty());
+    }
+
+    #[test]
+    fn test_equality() {
+        let expr1 = SparkPartitionIdExpr::new(5);
+        let expr2 = SparkPartitionIdExpr::new(5);
+        let expr3 = SparkPartitionIdExpr::new(3);
+
+        assert_eq!(expr1, expr2);
+        assert_ne!(expr1, expr3);
+    }
+}
diff --git
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index cb9492c9..1427e01d 100644
---
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.expressions.Like
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.expressions.SparkPartitionID
import org.apache.spark.sql.catalyst.expressions.StringSplit
import org.apache.spark.sql.catalyst.expressions.TaggingExpression
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -521,6 +522,13 @@ class ShimsImpl extends Shims with Logging {
isPruningExpr: Boolean,
fallback: Expression => pb.PhysicalExprNode):
Option[pb.PhysicalExprNode] = {
e match {
+ case _: SparkPartitionID =>
+ Some(
+ pb.PhysicalExprNode
+ .newBuilder()
+ .setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+ .build())
+
case StringSplit(str, pat @ Literal(_, StringType), Literal(-1,
IntegerType))
// native StringSplit implementation does not support regex, so only
most frequently
// used cases without regex are supported
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 13a627f2..29b9386a 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -1125,6 +1125,11 @@ object NativeConverters extends Logging {
_.setRowNumExpr(pb.RowNumExprNode.newBuilder())
}
+ case StubExpr("SparkPartitionID", _, _) =>
+ buildExprNode {
+ _.setSparkPartitionIdExpr(pb.SparkPartitionIdExprNode.newBuilder())
+ }
+
// hive UDFJson
case e
if udfJsonEnabled && (