This is an automated email from the ASF dual-hosted git repository.

blaginin pushed a commit to branch annarose/dict-coercion
in repository https://gitbox.apache.org/repos/asf/datafusion-sandbox.git

commit 2f90194129d8a18eeb020b225fbf5f187e2da84d
Author: lyne <[email protected]>
AuthorDate: Thu Feb 5 09:12:17 2026 +0800

    Fix `array_repeat` handling of null count values (#20102)
    
    ## Which issue does this PR close?
    
    <!--
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax. For example
    `Closes #123` indicates that this PR will close issue #123.
    -->
    
    - Closes https://github.com/apache/datafusion/issues/20075.
    
    ## Rationale for this change
    
    The previous implementation of `array_repeat` relied on Arrow defaults
    when handling null and negative count values. As a result, null counts
    were implicitly treated as zero and returned empty arrays, which is a
    correctness issue.
    
    This PR makes the handling of these edge cases explicit and aligns the
    function with SQL null semantics.
    
    <!--
    Why are you proposing this change? If this is already explained clearly
    in the issue then this section is not needed.
    Explaining clearly why changes are proposed helps reviewers understand
    your changes and offer better suggestions for fixes.
    -->
    
    ## What changes are included in this PR?
    
    - Explicit handling of null and negative count values
    - Planner-time coercion of the count argument to `Int64`
    
    <!--
    There is no need to duplicate the description in the issue here but it
    is sometimes worth providing a summary of the individual changes in this
    PR.
    -->
    
    ## Are these changes tested?
    
    <!--
    We typically require tests for all PRs in order to:
    1. Prevent the code from being accidentally broken by subsequent changes
    2. Serve as another way to document the expected behavior of the code
    
    If tests are not included in your PR, please explain why (for example,
    are they covered by existing tests)?
    -->
    
    Yes. New sqllogictest (SLT) cases were added, and they pass.
    
    ## Are there any user-facing changes?
    
    Yes. When the count value is null, `array_repeat` now returns NULL
    instead of an empty array. Negative counts produce an empty array.
    
    <!--
    If there are user-facing changes then we may require documentation to be
    updated before approving the PR.
    -->
    
    <!--
    If there are any breaking changes to public APIs, please add the `api
    change` label.
    -->
    
    ---------
    
    Co-authored-by: Martin Grigorov <[email protected]>
    Co-authored-by: Jeffrey Vo <[email protected]>
---
 datafusion/functions-nested/src/repeat.rs    | 148 ++++++++++++++++-----------
 datafusion/sqllogictest/test_files/array.slt |  81 ++++++++++++++-
 2 files changed, 164 insertions(+), 65 deletions(-)

diff --git a/datafusion/functions-nested/src/repeat.rs 
b/datafusion/functions-nested/src/repeat.rs
index 28ec827cc..5e78a4d0f 100644
--- a/datafusion/functions-nested/src/repeat.rs
+++ b/datafusion/functions-nested/src/repeat.rs
@@ -19,21 +19,23 @@
 
 use crate::utils::make_scalar_function;
 use arrow::array::{
-    Array, ArrayRef, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, 
UInt64Array,
+    Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, 
OffsetSizeTrait,
+    UInt64Array,
 };
 use arrow::buffer::{NullBuffer, OffsetBuffer};
 use arrow::compute;
-use arrow::compute::cast;
 use arrow::datatypes::DataType;
 use arrow::datatypes::{
     DataType::{LargeList, List},
     Field,
 };
-use datafusion_common::cast::{as_large_list_array, as_list_array, 
as_uint64_array};
-use datafusion_common::{Result, exec_err, utils::take_function_args};
+use datafusion_common::cast::{as_int64_array, as_large_list_array, 
as_list_array};
+use datafusion_common::types::{NativeType, logical_int64};
+use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::{
     ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
 };
+use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
 use datafusion_macros::user_doc;
 use std::any::Any;
 use std::sync::Arc;
@@ -88,7 +90,17 @@ impl Default for ArrayRepeat {
 impl ArrayRepeat {
     pub fn new() -> Self {
         Self {
-            signature: Signature::user_defined(Volatility::Immutable),
+            signature: Signature::coercible(
+                vec![
+                    Coercion::new_exact(TypeSignatureClass::Any),
+                    Coercion::new_implicit(
+                        TypeSignatureClass::Native(logical_int64()),
+                        vec![TypeSignatureClass::Integer],
+                        NativeType::Int64,
+                    ),
+                ],
+                Volatility::Immutable,
+            ),
             aliases: vec![String::from("list_repeat")],
         }
     }
@@ -132,23 +144,6 @@ impl ScalarUDFImpl for ArrayRepeat {
         &self.aliases
     }
 
-    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
-        let [first_type, second_type] = take_function_args(self.name(), 
arg_types)?;
-
-        // Coerce the second argument to Int64/UInt64 if it's a numeric type
-        let second = match second_type {
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | 
DataType::Int64 => {
-                DataType::Int64
-            }
-            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | 
DataType::UInt64 => {
-                DataType::UInt64
-            }
-            _ => return exec_err!("count must be an integer type"),
-        };
-
-        Ok(vec![first_type.clone(), second])
-    }
-
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
@@ -156,15 +151,7 @@ impl ScalarUDFImpl for ArrayRepeat {
 
 fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
     let element = &args[0];
-    let count_array = &args[1];
-
-    let count_array = match count_array.data_type() {
-        DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
-        DataType::UInt64 => count_array,
-        _ => return exec_err!("count must be an integer type"),
-    };
-
-    let count_array = as_uint64_array(count_array)?;
+    let count_array = as_int64_array(&args[1])?;
 
     match element.data_type() {
         List(_) => {
@@ -193,21 +180,31 @@ fn array_repeat_inner(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 /// ```
 fn general_repeat<O: OffsetSizeTrait>(
     array: &ArrayRef,
-    count_array: &UInt64Array,
+    count_array: &Int64Array,
 ) -> Result<ArrayRef> {
-    // Build offsets and take_indices
-    let total_repeated_values: usize =
-        count_array.values().iter().map(|&c| c as usize).sum();
+    let total_repeated_values: usize = (0..count_array.len())
+        .map(|i| get_count_with_validity(count_array, i))
+        .sum();
+
     let mut take_indices = Vec::with_capacity(total_repeated_values);
     let mut offsets = Vec::with_capacity(count_array.len() + 1);
     offsets.push(O::zero());
     let mut running_offset = 0usize;
 
-    for (idx, &count) in count_array.values().iter().enumerate() {
-        let count = count as usize;
-        running_offset += count;
-        offsets.push(O::from_usize(running_offset).unwrap());
-        take_indices.extend(std::iter::repeat_n(idx as u64, count))
+    for idx in 0..count_array.len() {
+        let count = get_count_with_validity(count_array, idx);
+        running_offset = running_offset.checked_add(count).ok_or_else(|| {
+            DataFusionError::Execution(
+                "array_repeat: running_offset overflowed usize".to_string(),
+            )
+        })?;
+        let offset = O::from_usize(running_offset).ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "array_repeat: offset {running_offset} exceeds the maximum 
value for offset type"
+            ))
+        })?;
+        offsets.push(offset);
+        take_indices.extend(std::iter::repeat_n(idx as u64, count));
     }
 
     // Build the flattened values
@@ -222,7 +219,7 @@ fn general_repeat<O: OffsetSizeTrait>(
         Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
         OffsetBuffer::new(offsets.into()),
         repeated_values,
-        None,
+        count_array.nulls().cloned(),
     )?))
 }
 
@@ -238,23 +235,24 @@ fn general_repeat<O: OffsetSizeTrait>(
 /// ```
 fn general_list_repeat<O: OffsetSizeTrait>(
     list_array: &GenericListArray<O>,
-    count_array: &UInt64Array,
+    count_array: &Int64Array,
 ) -> Result<ArrayRef> {
-    let counts = count_array.values();
     let list_offsets = list_array.value_offsets();
 
     // calculate capacities for pre-allocation
-    let outer_total = counts.iter().map(|&c| c as usize).sum();
-    let inner_total = counts
-        .iter()
-        .enumerate()
-        .filter(|&(i, _)| !list_array.is_null(i))
-        .map(|(i, &c)| {
-            let len = list_offsets[i + 1].to_usize().unwrap()
-                - list_offsets[i].to_usize().unwrap();
-            len * (c as usize)
-        })
-        .sum();
+    let mut outer_total = 0usize;
+    let mut inner_total = 0usize;
+    for i in 0..count_array.len() {
+        let count = get_count_with_validity(count_array, i);
+        if count > 0 {
+            outer_total += count;
+            if list_array.is_valid(i) {
+                let len = list_offsets[i + 1].to_usize().unwrap()
+                    - list_offsets[i].to_usize().unwrap();
+                inner_total += len * count;
+            }
+        }
+    }
 
     // Build inner structures
     let mut inner_offsets = Vec::with_capacity(outer_total + 1);
@@ -263,17 +261,27 @@ fn general_list_repeat<O: OffsetSizeTrait>(
     let mut inner_running = 0usize;
     inner_offsets.push(O::zero());
 
-    for (row_idx, &count) in counts.iter().enumerate() {
-        let is_valid = !list_array.is_null(row_idx);
+    for row_idx in 0..count_array.len() {
+        let count = get_count_with_validity(count_array, row_idx);
+        let list_is_valid = list_array.is_valid(row_idx);
         let start = list_offsets[row_idx].to_usize().unwrap();
         let end = list_offsets[row_idx + 1].to_usize().unwrap();
         let row_len = end - start;
 
         for _ in 0..count {
-            inner_running += row_len;
-            inner_offsets.push(O::from_usize(inner_running).unwrap());
-            inner_nulls.append(is_valid);
-            if is_valid {
+            inner_running = inner_running.checked_add(row_len).ok_or_else(|| {
+                DataFusionError::Execution(
+                    "array_repeat: inner offset overflowed usize".to_string(),
+                )
+            })?;
+            let offset = O::from_usize(inner_running).ok_or_else(|| {
+                DataFusionError::Execution(format!(
+                    "array_repeat: offset {inner_running} exceeds the maximum 
value for offset type"
+                ))
+            })?;
+            inner_offsets.push(offset);
+            inner_nulls.append(list_is_valid);
+            if list_is_valid {
                 take_indices.extend(start as u64..end as u64);
             }
         }
@@ -298,8 +306,24 @@ fn general_list_repeat<O: OffsetSizeTrait>(
             list_array.data_type().to_owned(),
             true,
         )),
-        OffsetBuffer::<O>::from_lengths(counts.iter().map(|&c| c as usize)),
+        OffsetBuffer::<O>::from_lengths(
+            count_array
+                .iter()
+                .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 
}).unwrap_or(0)),
+        ),
         Arc::new(inner_list),
-        None,
+        count_array.nulls().cloned(),
     )?))
 }
+
+/// Helper function to get count from count_array at given index
+/// Return 0 for null values or non-positive count.
+#[inline]
+fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
+    if count_array.is_null(idx) {
+        0
+    } else {
+        let c = count_array.value(idx);
+        if c > 0 { c as usize } else { 0 }
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index c27433e7e..2b98ae14d 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -3256,24 +3256,99 @@ drop table array_repeat_table;
 statement ok
 drop table large_array_repeat_table;
 
-
+# array_repeat: arrays with NULL counts
 statement ok
 create table array_repeat_null_count_table
 as values
 (1, 2),
 (2, null),
-(3, 1);
+(3, 1),
+(4, -1),
+(null, null);
 
 query I?
 select column1, array_repeat(column1, column2) from 
array_repeat_null_count_table;
 ----
 1 [1, 1]
-2 []
+2 NULL
 3 [3]
+4 []
+NULL NULL
 
 statement ok
 drop table array_repeat_null_count_table
 
+# array_repeat: nested arrays with NULL counts
+statement ok
+create table array_repeat_nested_null_count_table
+as values
+([[1, 2], [3, 4]], 2),
+([[5, 6], [7, 8]], null),
+([[null, null], [9, 10]], 1),
+(null, 3),
+([[11, 12]], -1);
+
+query ??
+select column1, array_repeat(column1, column2) from 
array_repeat_nested_null_count_table;
+----
+[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
+[[5, 6], [7, 8]] NULL
+[[NULL, NULL], [9, 10]] [[[NULL, NULL], [9, 10]]]
+NULL [NULL, NULL, NULL]
+[[11, 12]] []
+
+statement ok
+drop table array_repeat_nested_null_count_table
+
+# array_repeat edge cases: empty arrays
+query ???
+select array_repeat([], 3), array_repeat([], 0), array_repeat([], null);
+----
+[[], [], []] [] NULL
+
+query ??
+select array_repeat(null::int, 0), array_repeat(null::int, null);
+----
+[] NULL
+
+# array_repeat LargeList with NULL count
+statement ok
+create table array_repeat_large_list_null_table
+as values
+(arrow_cast([1, 2, 3], 'LargeList(Int64)'), 2),
+(arrow_cast([4, 5], 'LargeList(Int64)'), null),
+(arrow_cast(null, 'LargeList(Int64)'), 3);
+
+query ??
+select column1, array_repeat(column1, column2) from 
array_repeat_large_list_null_table;
+----
+[1, 2, 3] [[1, 2, 3], [1, 2, 3]]
+[4, 5] NULL
+NULL [NULL, NULL, NULL]
+
+statement ok
+drop table array_repeat_large_list_null_table
+
+# array_repeat edge cases: LargeList nested with NULL count
+statement ok
+create table array_repeat_large_nested_null_table
+as values
+(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2),
+(arrow_cast([[5, 6], [7, 8]], 'LargeList(List(Int64))'), null),
+(arrow_cast([[null, null]], 'LargeList(List(Int64))'), 1),
+(null, 3);
+
+query ??
+select column1, array_repeat(column1, column2) from 
array_repeat_large_nested_null_table;
+----
+[[1, 2], [3, 4]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]
+[[5, 6], [7, 8]] NULL
+[[NULL, NULL]] [[[NULL, NULL]]]
+NULL [NULL, NULL, NULL]
+
+statement ok
+drop table array_repeat_large_nested_null_table
+
 ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)
 
 # test with empty array


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to