This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c2b1236b5 feat(spark): implement array_repeat function (#19702)
5c2b1236b5 is described below

commit 5c2b1236b52ed410306ed0c2dea8bd9247127e6f
Author: cht42 <[email protected]>
AuthorDate: Sat Jan 10 06:17:01 2026 +0400

    feat(spark): implement array_repeat function (#19702)
    
    ## Which issue does this PR close?
    
    <!--
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax. For example
    `Closes #123` indicates that this PR will close issue #123.
    -->
    
    - Part of #15914
    - Closes #19701
    
    ## Rationale for this change
    
    <!--
    Why are you proposing this change? If this is already explained clearly
    in the issue then this section is not needed.
    Explaining clearly why changes are proposed helps reviewers understand
    your changes and offer better suggestions for fixes.
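    -->
    
    `datafusion-spark` did not yet implement Spark's `array_repeat` (its
    sqllogictest was commented out).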
    
    ## What changes are included in this PR?
    
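    Implements Spark's `array_repeat` function in `datafusion-spark`, and moves
    the NULL-propagation helpers previously private to Spark `concat` into a
    shared `null_utils` module.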
    
    
    
    ## Are these changes tested?
    
    Yes, with new sqllogictest cases in `spark/array/array_repeat.slt`.
    
    ## Are there any user-facing changes?
    
    Yes: `array_repeat` is now available as a Spark function.
---
 datafusion/spark/src/function/array/mod.rs         |   9 +-
 datafusion/spark/src/function/array/repeat.rs      | 128 +++++++++++++++++++++
 datafusion/spark/src/function/mod.rs               |   1 +
 datafusion/spark/src/function/null_utils.rs        | 122 ++++++++++++++++++++
 datafusion/spark/src/function/string/concat.rs     | 110 ++----------------
 .../test_files/spark/array/array_repeat.slt        |  77 +++++++++++--
 6 files changed, 333 insertions(+), 114 deletions(-)

diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs
index 01056ba952..7140653510 100644
--- a/datafusion/spark/src/function/array/mod.rs
+++ b/datafusion/spark/src/function/array/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub mod repeat;
 pub mod shuffle;
 pub mod spark_array;
 
@@ -24,6 +25,7 @@ use std::sync::Arc;
 
 make_udf_function!(spark_array::SparkArray, array);
 make_udf_function!(shuffle::SparkShuffle, shuffle);
+make_udf_function!(repeat::SparkArrayRepeat, array_repeat);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -34,8 +36,13 @@ pub mod expr_fn {
         "Returns a random permutation of the given array.",
         args
     ));
+    export_functions!((
+        array_repeat,
+        "Returns an array containing `element` repeated `count` times.",
+        element count
+    ));
 }
 
+/// Returns the Spark array scalar functions defined in this module
+/// (`array`, `shuffle` and `array_repeat`).
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![array(), shuffle()]
+    vec![array(), shuffle(), array_repeat()]
 }
diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs
new file mode 100644
index 0000000000..7543300a91
--- /dev/null
+++ b/datafusion/spark/src/function/array/repeat.rs
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{Result, ScalarValue, exec_err};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use datafusion_functions_nested::repeat::ArrayRepeat;
+use std::any::Any;
+use std::sync::Arc;
+
+use crate::function::null_utils::{
+    NullMaskResolution, apply_null_mask, compute_null_mask,
+};
+
+/// Spark-compatible `array_repeat` expression.
+///
+/// It wraps DataFusion's `array_repeat` and differs in NULL handling: this
+/// implementation returns NULL when any input (element or count) is NULL.
+///
+/// NOTE(review): Spark's `ArrayRepeat` appears to return NULL only when
+/// `count` is NULL; a NULL `element` yields an array of NULLs
+/// (`array_repeat(NULL, 2)` -> `[NULL, NULL]`). Verify the intended semantics.
+///
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#array_repeat>
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkArrayRepeat {
+    signature: Signature,
+}
+
+impl Default for SparkArrayRepeat {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkArrayRepeat {
+    /// Creates the UDF with a user-defined signature; argument types are
+    /// resolved by `coerce_types`.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkArrayRepeat {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_repeat"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns `List<element type>` with nullable items.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // Destructure rather than index so an unexpected arity is reported as
+        // an error instead of a panic.
+        let [element_type, _] = take_function_args(self.name(), arg_types)?;
+        Ok(DataType::List(Arc::new(Field::new_list_field(
+            element_type.clone(),
+            true,
+        ))))
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_array_repeat(args)
+    }
+
+    /// Keeps the element type unchanged and widens `count` to `Int64` (signed
+    /// integers) or `UInt64` (unsigned integers).
+    ///
+    /// An untyped `NULL` count is coerced to `Int64` so that
+    /// `array_repeat(x, NULL)` plans and evaluates to NULL, as in Spark,
+    /// instead of failing with "count must be an integer type".
+    ///
+    /// # Errors
+    /// Fails unless exactly two arguments are given and `count` is an integer
+    /// (or NULL) type.
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
+
+        let second = match second_type {
+            DataType::Null
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64 => DataType::Int64,
+            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+                DataType::UInt64
+            }
+            _ => return exec_err!("count must be an integer type"),
+        };
+
+        Ok(vec![first_type.clone(), second])
+    }
+}
+
+/// Spark-specific wrapper around DataFusion's `array_repeat` that returns
+/// NULL if any argument is NULL, whereas DataFusion's `array_repeat` ignores
+/// NULLs.
+///
+/// NOTE(review): Spark's `ArrayRepeat` appears to return NULL only when
+/// `count` is NULL; a NULL `element` produces an array of NULLs (e.g.
+/// `array_repeat(NULL, 2)` -> `[NULL, NULL]`). Computing the mask from
+/// `&arg_values[1..]` only, with matching `array_repeat.slt` expectations,
+/// would follow that. Confirm against Spark before changing.
+fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+    let ScalarFunctionArgs {
+        args: arg_values,
+        arg_fields,
+        number_rows,
+        return_field,
+        config_options,
+    } = args;
+    // `return_field` is moved into the delegated call below, so keep a copy of
+    // its type for building NULL results.
+    let return_type = return_field.data_type().clone();
+
+    // Step 1: Check for NULL mask in incoming args
+    let null_mask = compute_null_mask(&arg_values, number_rows)?;
+
+    // All arguments are scalars and at least one is NULL: the whole result is
+    // a NULL scalar of the return type, no need to call the delegate.
+    if matches!(null_mask, NullMaskResolution::ReturnNull) {
+        return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?));
+    }
+
+    // Step 2: Delegate to DataFusion's array_repeat
+    let array_repeat_func = ArrayRepeat::new();
+    let func_args = ScalarFunctionArgs {
+        args: arg_values,
+        arg_fields,
+        number_rows,
+        return_field,
+        config_options,
+    };
+    let result = array_repeat_func.invoke_with_args(func_args)?;
+
+    // Step 3: Apply NULL mask to result
+    apply_null_mask(result, null_mask, &return_type)
+}
diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs
index 3f4f94cfaa..d5dd60c354 100644
--- a/datafusion/spark/src/function/mod.rs
+++ b/datafusion/spark/src/function/mod.rs
@@ -33,6 +33,7 @@ pub mod lambda;
 pub mod map;
 pub mod math;
 pub mod misc;
+mod null_utils;
 pub mod predicate;
 pub mod string;
 pub mod r#struct;
diff --git a/datafusion/spark/src/function/null_utils.rs b/datafusion/spark/src/function/null_utils.rs
new file mode 100644
index 0000000000..b25dc07d0e
--- /dev/null
+++ b/datafusion/spark/src/function/null_utils.rs
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::Array;
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::DataType;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::ColumnarValue;
+use std::sync::Arc;
+
+/// How NULL arguments should be reflected in a function's result; produced by
+/// [`compute_null_mask`] and consumed by [`apply_null_mask`].
+pub(crate) enum NullMaskResolution {
+    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
+    ReturnNull,
+    /// No null mask needed (e.g., all scalar inputs are non-NULL)
+    NoMask,
+    /// Null mask to apply for arrays: rows that are null in this buffer must
+    /// be NULL in the array result
+    Apply(NullBuffer),
+}
+
+/// Computes how NULL inputs should propagate to the result of a function
+/// that returns NULL whenever any of its arguments is NULL.
+///
+/// * All arguments are scalars and at least one is NULL:
+///   [`NullMaskResolution::ReturnNull`].
+/// * All arguments are scalars and none is NULL (including when `args` is
+///   empty): [`NullMaskResolution::NoMask`].
+/// * Otherwise scalars are broadcast to the length of the first array
+///   argument and the null buffers of all arguments are combined with
+///   `NullBuffer::union`, giving [`NullMaskResolution::Apply`], or
+///   [`NullMaskResolution::NoMask`] when no argument has a null buffer.
+///
+/// NOTE(review): arrow's `Array::nulls()` is believed to return `None` for a
+/// `NullArray` (`DataType::Null`), including a `ScalarValue::Null` broadcast
+/// by `to_array_of_size`; such arguments then do not contribute to the mask
+/// (`logical_nulls()` would). Verify whether untyped NULL arguments can reach
+/// this function.
+///
+/// # Errors
+/// Propagates any error from `ScalarValue::to_array_of_size`.
+pub(crate) fn compute_null_mask(
+    args: &[ColumnarValue],
+    number_rows: usize,
+) -> Result<NullMaskResolution> {
+    // Check if all arguments are scalars
+    let all_scalars = args
+        .iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
+
+    if all_scalars {
+        // For scalars, check if any is NULL
+        for arg in args {
+            if let ColumnarValue::Scalar(scalar) = arg
+                && scalar.is_null()
+            {
+                return Ok(NullMaskResolution::ReturnNull);
+            }
+        }
+        // No NULLs in scalars
+        Ok(NullMaskResolution::NoMask)
+    } else {
+        // For arrays, compute NULL mask for each row using NullBuffer::union
+        // NOTE(review): at least one argument is an array on this branch, so
+        // the `number_rows` fallback is never taken; array lengths are not
+        // checked against each other.
+        let array_len = args
+            .iter()
+            .find_map(|arg| match arg {
+                ColumnarValue::Array(array) => Some(array.len()),
+                _ => None,
+            })
+            .unwrap_or(number_rows);
+
+        // Convert all scalars to arrays for uniform processing
+        let arrays: Result<Vec<_>> = args
+            .iter()
+            .map(|arg| match arg {
+                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
+                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
+            })
+            .collect();
+        let arrays = arrays?;
+
+        // Use NullBuffer::union to combine all null buffers; a row is null in
+        // the result if it is null in any argument.
+        let combined_nulls = arrays
+            .iter()
+            .map(|arr| arr.nulls())
+            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
+
+        match combined_nulls {
+            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
+            None => Ok(NullMaskResolution::NoMask),
+        }
+    }
+}
+
+/// Applies a [`NullMaskResolution`] computed by [`compute_null_mask`] to the
+/// result of the delegated function.
+///
+/// * `ReturnNull` on a scalar result yields a NULL scalar of `return_type`.
+/// * `Apply` on an array result unions the mask with the array's own null
+///   buffer, so a row is NULL if it was NULL in either.
+/// * `NoMask` returns `result` unchanged.
+///
+/// # Errors
+/// Fails if no NULL scalar can be built for `return_type`
+/// (`ScalarValue::try_from`) or if rebuilding the array data fails.
+pub(crate) fn apply_null_mask(
+    result: ColumnarValue,
+    null_mask: NullMaskResolution,
+    return_type: &DataType,
+) -> Result<ColumnarValue> {
+    match (result, null_mask) {
+        // Scalar with ReturnNull mask means return NULL of the correct type
+        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
+            Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?))
+        }
+        // Scalar without mask, return as-is
+        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar),
+        // Array with NULL mask - use NullBuffer::union to combine nulls
+        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => {
+            // Combine the result's existing nulls with our computed null mask
+            let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
+
+            // Create new array with combined nulls. Only the top-level validity
+            // changes; child data/offsets of newly-NULL rows are left in place.
+            let new_array = array
+                .into_data()
+                .into_builder()
+                .nulls(combined_nulls)
+                .build()?;
+
+            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
+                new_array,
+            ))))
+        }
+        // Array without NULL mask, return as-is
+        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array),
+        // Remaining combinations (an array result with `ReturnNull`, or a
+        // scalar result with `Apply`) are returned unchanged; despite its name,
+        // `scalar` here may be an array.
+        // NOTE(review): callers presumably return early on `ReturnNull`, so
+        // these combinations are not expected in practice.
+        (scalar, _) => Ok(scalar),
+    }
+}
diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs
index 8e97e591fc..f3dae22866 100644
--- a/datafusion/spark/src/function/string/concat.rs
+++ b/datafusion/spark/src/function/string/concat.rs
@@ -15,8 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::Array;
-use arrow::buffer::NullBuffer;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::arrow::datatypes::FieldRef;
 use datafusion_common::{Result, ScalarValue};
@@ -29,6 +27,10 @@ use datafusion_functions::string::concat::ConcatFunc;
 use std::any::Any;
 use std::sync::Arc;
 
+use crate::function::null_utils::{
+    NullMaskResolution, apply_null_mask, compute_null_mask,
+};
+
 /// Spark-compatible `concat` expression
 /// <https://spark.apache.org/docs/latest/api/sql/index.html#concat>
 ///
@@ -94,16 +96,6 @@ impl ScalarUDFImpl for SparkConcat {
     }
 }
 
-/// Represents the null state for Spark concat
-enum NullMaskResolution {
-    /// Return NULL as the result (e.g., scalar inputs with at least one NULL)
-    ReturnNull,
-    /// No null mask needed (e.g., all scalar inputs are non-NULL)
-    NoMask,
-    /// Null mask to apply for arrays
-    Apply(NullBuffer),
-}
-
 /// Concatenates strings, returning NULL if any input is NULL
 /// This is a Spark-specific wrapper around DataFusion's concat that returns 
NULL
 /// if any argument is NULL (Spark behavior), whereas DataFusion's concat 
ignores NULLs.
@@ -133,6 +125,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
 
     // Step 2: Delegate to DataFusion's concat
     let concat_func = ConcatFunc::new();
+    let return_type = return_field.data_type().clone();
     let func_args = ScalarFunctionArgs {
         args: arg_values,
         arg_fields,
@@ -143,103 +136,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
     let result = concat_func.invoke_with_args(func_args)?;
 
     // Step 3: Apply NULL mask to result
-    apply_null_mask(result, null_mask)
-}
-
-/// Compute NULL mask for the arguments using NullBuffer::union
-fn compute_null_mask(
-    args: &[ColumnarValue],
-    number_rows: usize,
-) -> Result<NullMaskResolution> {
-    // Check if all arguments are scalars
-    let all_scalars = args
-        .iter()
-        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
-
-    if all_scalars {
-        // For scalars, check if any is NULL
-        for arg in args {
-            if let ColumnarValue::Scalar(scalar) = arg
-                && scalar.is_null()
-            {
-                return Ok(NullMaskResolution::ReturnNull);
-            }
-        }
-        // No NULLs in scalars
-        Ok(NullMaskResolution::NoMask)
-    } else {
-        // For arrays, compute NULL mask for each row using NullBuffer::union
-        let array_len = args
-            .iter()
-            .find_map(|arg| match arg {
-                ColumnarValue::Array(array) => Some(array.len()),
-                _ => None,
-            })
-            .unwrap_or(number_rows);
-
-        // Convert all scalars to arrays for uniform processing
-        let arrays: Result<Vec<_>> = args
-            .iter()
-            .map(|arg| match arg {
-                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
-                ColumnarValue::Scalar(scalar) => 
scalar.to_array_of_size(array_len),
-            })
-            .collect();
-        let arrays = arrays?;
-
-        // Use NullBuffer::union to combine all null buffers
-        let combined_nulls = arrays
-            .iter()
-            .map(|arr| arr.nulls())
-            .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
-
-        match combined_nulls {
-            Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
-            None => Ok(NullMaskResolution::NoMask),
-        }
-    }
-}
-
-/// Apply NULL mask to the result using NullBuffer::union
-fn apply_null_mask(
-    result: ColumnarValue,
-    null_mask: NullMaskResolution,
-) -> Result<ColumnarValue> {
-    match (result, null_mask) {
-        // Scalar with ReturnNull mask means return NULL
-        (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => {
-            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)))
-        }
-        // Scalar without mask, return as-is
-        (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => 
Ok(scalar),
-        // Array with NULL mask - use NullBuffer::union to combine nulls
-        (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => 
{
-            // Combine the result's existing nulls with our computed null mask
-            let combined_nulls = NullBuffer::union(array.nulls(), 
Some(&null_mask));
-
-            // Create new array with combined nulls
-            let new_array = array
-                .into_data()
-                .into_builder()
-                .nulls(combined_nulls)
-                .build()?;
-
-            Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array(
-                new_array,
-            ))))
-        }
-        // Array without NULL mask, return as-is
-        (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => 
Ok(array),
-        // Edge cases that shouldn't happen in practice
-        (scalar, _) => Ok(scalar),
-    }
+    apply_null_mask(result, null_mask, &return_type)
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::function::utils::test::test_scalar_function;
-    use arrow::array::StringArray;
+    use arrow::array::{Array, StringArray};
     use arrow::datatypes::{DataType, Field};
     use datafusion_common::Result;
     use datafusion_expr::ReturnFieldArgs;
diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt
index 544c39608f..04926e4c11 100644
--- a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt
+++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt
@@ -15,13 +15,70 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# This file was originally created by a porting script from:
-#   
https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function
-# This file is part of the implementation of the datafusion-spark function 
library.
-# For more information, please see:
-#   https://github.com/apache/datafusion/issues/15914
-
-## Original Query: SELECT array_repeat('123', 2);
-## PySpark 3.5.5 Result: {'array_repeat(123, 2)': ['123', '123'], 
'typeof(array_repeat(123, 2))': 'array<string>', 'typeof(123)': 'string', 
'typeof(2)': 'int'}
-#query
-#SELECT array_repeat('123'::string, 2::int);
+
+query ?
+SELECT array_repeat('123', 2);
+----
+[123, 123]
+
+query ?
+SELECT array_repeat('123', 0);
+----
+[]
+
+query ?
+SELECT array_repeat('123', -1);
+----
+[]
+
+query ?
+SELECT array_repeat(['123'], 2);
+----
+[[123], [123]]
+
+query ?
+SELECT array_repeat(NULL, 2);
+----
+NULL
+
+query ?
+SELECT array_repeat([NULL], 2);
+----
+[[NULL], [NULL]]
+
+query ?
+SELECT array_repeat(['123', NULL], 2);
+----
+[[123, NULL], [123, NULL]]
+
+query ?
+SELECT array_repeat('123', CAST(NULL AS INT));
+----
+NULL
+
+query ?
+SELECT array_repeat(column1, column2)
+FROM VALUES
+('123', 2),
+('123', 0),
+('123', -1),
+(NULL, 1),
+('123', NULL);
+----
+[123, 123]
+[]
+[]
+NULL
+NULL
+
+
+query ?
+SELECT array_repeat(column1, column2)
+FROM VALUES
+(['123'], 2),
+([], 2),
+([NULL], 2);
+----
+[[123], [123]]
+[[], []]
+[[NULL], [NULL]]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to