This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e4a94243b5 Update the CONCAT scalar function to support Utf8View (#12224)
e4a94243b5 is described below

commit e4a94243b502da2ad07a358b4401052651952eea
Author: WeblWabl <[email protected]>
AuthorDate: Tue Sep 3 15:12:47 2024 -0500

    Update the CONCAT scalar function to support Utf8View (#12224)
    
    * wip
    
    * feat: Update the CONCAT scalar function to support Utf8View
    
    * fmt
    
    * fmt and add default return type for concat
    
    * fix clippy lint
    
    Signed-off-by: Devan <[email protected]>
    
    * fmt
    
    Signed-off-by: Devan <[email protected]>
    
    * add more tests for sqllogic
    
    Signed-off-by: Devan <[email protected]>
    
    * make sure no casting with LargeUtf8
    
    * fixing utf8large
    
    * fix large utf8
    
    Signed-off-by: Devan <[email protected]>
    
    * fix large utf8
    
    Signed-off-by: Devan <[email protected]>
    
    * add test
    
    Signed-off-by: Devan <[email protected]>
    
    * fmt
    
    Signed-off-by: Devan <[email protected]>
    
    * make it so Utf8View just returns Utf8
    
    Signed-off-by: Devan <[email protected]>
    
    * wip -- trying to build a stringview with columnar refs
    
    Signed-off-by: Devan <[email protected]>
    
    * built stringview builder but it does allocate a new String each iter :(
    
    Signed-off-by: Devan <[email protected]>
    
    * add some testing
    
    Signed-off-by: Devan <[email protected]>
    
    * clippy
    
    Signed-off-by: Devan <[email protected]>
    
    ---------
    
    Signed-off-by: Devan <[email protected]>
---
 datafusion/functions/src/string/common.rs          | 195 ++++++++++++++++++++-
 datafusion/functions/src/string/concat.rs          | 184 ++++++++++++++++---
 datafusion/sqllogictest/test_files/string_view.slt |  71 +++++++-
 3 files changed, 416 insertions(+), 34 deletions(-)

diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs
index 9738cb812f..6ebcc4ee6c 100644
--- a/datafusion/functions/src/string/common.rs
+++ b/datafusion/functions/src/string/common.rs
@@ -22,12 +22,11 @@ use std::sync::Arc;
 
 use arrow::array::{
     new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ArrayRef,
-    GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
-    StringBuilder, StringViewArray,
+    GenericStringArray, GenericStringBuilder, LargeStringArray, OffsetSizeTrait,
+    StringArray, StringBuilder, StringViewArray, StringViewBuilder,
 };
 use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
 use arrow::datatypes::DataType;
-
 use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::Result;
 use datafusion_common::{exec_err, ScalarValue};
@@ -249,26 +248,41 @@ where
     }
 }
 
+#[derive(Debug)]
 pub(crate) enum ColumnarValueRef<'a> {
     Scalar(&'a [u8]),
     NullableArray(&'a StringArray),
     NonNullableArray(&'a StringArray),
+    NullableLargeStringArray(&'a LargeStringArray),
+    NonNullableLargeStringArray(&'a LargeStringArray),
+    NullableStringViewArray(&'a StringViewArray),
+    NonNullableStringViewArray(&'a StringViewArray),
 }
 
 impl<'a> ColumnarValueRef<'a> {
     #[inline]
     pub fn is_valid(&self, i: usize) -> bool {
         match &self {
-            Self::Scalar(_) | Self::NonNullableArray(_) => true,
+            Self::Scalar(_)
+            | Self::NonNullableArray(_)
+            | Self::NonNullableLargeStringArray(_)
+            | Self::NonNullableStringViewArray(_) => true,
             Self::NullableArray(array) => array.is_valid(i),
+            Self::NullableStringViewArray(array) => array.is_valid(i),
+            Self::NullableLargeStringArray(array) => array.is_valid(i),
         }
     }
 
     #[inline]
     pub fn nulls(&self) -> Option<NullBuffer> {
         match &self {
-            Self::Scalar(_) | Self::NonNullableArray(_) => None,
+            Self::Scalar(_)
+            | Self::NonNullableArray(_)
+            | Self::NonNullableStringViewArray(_)
+            | Self::NonNullableLargeStringArray(_) => None,
             Self::NullableArray(array) => array.nulls().cloned(),
+            Self::NullableStringViewArray(array) => array.nulls().cloned(),
+            Self::NullableLargeStringArray(array) => array.nulls().cloned(),
         }
     }
 }
@@ -387,10 +401,30 @@ impl StringArrayBuilder {
                         .extend_from_slice(array.value(i).as_bytes());
                 }
             }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
             ColumnarValueRef::NonNullableArray(array) => {
                 self.value_buffer
                     .extend_from_slice(array.value(i).as_bytes());
             }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
         }
     }
 
@@ -416,6 +450,157 @@ impl StringArrayBuilder {
     }
 }
 
+pub(crate) struct StringViewArrayBuilder {
+    builder: StringViewBuilder,
+    block: String,
+}
+
+impl StringViewArrayBuilder {
+    pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self {
+        let builder = StringViewBuilder::with_capacity(data_capacity);
+        Self {
+            builder,
+            block: String::new(),
+        }
+    }
+
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.block.push_str(std::str::from_utf8(s).unwrap());
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(
+                        std::str::from_utf8(array.value(i).as_bytes()).unwrap(),
+                    );
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.block
+                    .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap());
+            }
+        }
+    }
+
+    pub fn append_offset(&mut self) {
+        self.builder.append_value(&self.block);
+        self.block = String::new();
+    }
+
+    pub fn finish(mut self) -> StringViewArray {
+        self.builder.finish()
+    }
+}
+
+pub(crate) struct LargeStringArrayBuilder {
+    offsets_buffer: MutableBuffer,
+    value_buffer: MutableBuffer,
+}
+
+impl LargeStringArrayBuilder {
+    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
+        let mut offsets_buffer = MutableBuffer::with_capacity(
+            (item_capacity + 1) * std::mem::size_of::<i64>(),
+        );
+        // SAFETY: the first offset value is definitely not going to exceed the bounds.
+        unsafe { offsets_buffer.push_unchecked(0_i64) };
+        Self {
+            offsets_buffer,
+            value_buffer: MutableBuffer::with_capacity(data_capacity),
+        }
+    }
+
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.value_buffer.extend_from_slice(s);
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+        }
+    }
+
+    pub fn append_offset(&mut self) {
+        let next_offset: i64 = self
+            .value_buffer
+            .len()
+            .try_into()
+            .expect("byte array offset overflow");
+        unsafe { self.offsets_buffer.push_unchecked(next_offset) };
+    }
+
+    pub fn finish(self, null_buffer: Option<NullBuffer>) -> LargeStringArray {
+        let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8)
+            .len(self.offsets_buffer.len() / std::mem::size_of::<i64>() - 1)
+            .add_buffer(self.offsets_buffer.into())
+            .add_buffer(self.value_buffer.into())
+            .nulls(null_buffer);
+        // SAFETY: all data that was appended was valid Large UTF8 and the values
+        // and offsets were created correctly
+        let array_data = unsafe { array_builder.build_unchecked() };
+        LargeStringArray::from(array_data)
+    }
+}
+
 fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
 where
     O: OffsetSizeTrait,
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index 6d15e22067..00fe69b0bd 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -15,14 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::array::{as_largestring_array, Array};
+use arrow::datatypes::DataType;
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::datatypes::DataType;
-use arrow::datatypes::DataType::Utf8;
-
-use datafusion_common::cast::as_string_array;
-use datafusion_common::{internal_err, Result, ScalarValue};
+use datafusion_common::cast::{as_string_array, as_string_view_array};
+use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
@@ -46,7 +45,10 @@ impl ConcatFunc {
     pub fn new() -> Self {
         use DataType::*;
         Self {
-            signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
+            signature: Signature::variadic(
+                vec![Utf8, Utf8View, LargeUtf8],
+                Volatility::Immutable,
+            ),
         }
     }
 }
@@ -64,13 +66,36 @@ impl ScalarUDFImpl for ConcatFunc {
         &self.signature
     }
 
-    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-        Ok(Utf8)
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        use DataType::*;
+        let mut dt = &Utf8;
+        arg_types.iter().for_each(|data_type| {
+            if data_type == &Utf8View {
+                dt = data_type;
+            }
+            if data_type == &LargeUtf8 && dt != &Utf8View {
+                dt = data_type;
+            }
+        });
+
+        Ok(dt.to_owned())
     }
 
     /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
     /// concat('abcde', 2, NULL, 22) = 'abcde222'
     fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let mut return_datatype = DataType::Utf8;
+        args.iter().for_each(|col| {
+            if col.data_type() == DataType::Utf8View {
+                return_datatype = col.data_type();
+            }
+            if col.data_type() == DataType::LargeUtf8
+                && return_datatype != DataType::Utf8View
+            {
+                return_datatype = col.data_type();
+            }
+        });
+
         let array_len = args
             .iter()
             .filter_map(|x| match x {
@@ -87,7 +112,21 @@ impl ScalarUDFImpl for ConcatFunc {
                     result.push_str(v);
                 }
             }
-            return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
+
+            return match return_datatype {
+                DataType::Utf8View => {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
+                }
+                DataType::Utf8 => {
+                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
+                }
+                DataType::LargeUtf8 => {
+                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
+                }
+                other => {
+                    plan_err!("Concat function does not support datatype of {other}")
+                }
+            };
         }
 
         // Array
@@ -103,28 +142,95 @@ impl ScalarUDFImpl for ConcatFunc {
                         columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
                     }
                 }
+                ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
+                    if let Some(s) = maybe_value {
+                        data_size += s.len() * len;
+                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
+                    }
+                }
                 ColumnarValue::Array(array) => {
-                    let string_array = as_string_array(array)?;
-                    data_size += string_array.values().len();
-                    let column = if array.is_nullable() {
-                        ColumnarValueRef::NullableArray(string_array)
-                    } else {
-                        ColumnarValueRef::NonNullableArray(string_array)
+                    match array.data_type() {
+                        DataType::Utf8 => {
+                            let string_array = as_string_array(array)?;
+
+                            data_size += string_array.values().len();
+                            let column = if array.is_nullable() {
+                                ColumnarValueRef::NullableArray(string_array)
+                            } else {
+                                ColumnarValueRef::NonNullableArray(string_array)
+                            };
+                            columns.push(column);
+                        },
+                        DataType::LargeUtf8 => {
+                            let string_array = as_largestring_array(array);
+
+                            data_size += string_array.values().len();
+                            let column = if array.is_nullable() {
+                                ColumnarValueRef::NullableLargeStringArray(string_array)
+                            } else {
+                                ColumnarValueRef::NonNullableLargeStringArray(string_array)
+                            };
+                            columns.push(column);
+                        },
+                        DataType::Utf8View => {
+                            let string_array = as_string_view_array(array)?;
+
+                            data_size += string_array.len();
+                            let column = if array.is_nullable() {
+                                ColumnarValueRef::NullableStringViewArray(string_array)
+                            } else {
+                                ColumnarValueRef::NonNullableStringViewArray(string_array)
+                            };
+                            columns.push(column);
+                        },
+                        other => {
+                            return plan_err!("Input was {other} which is not a supported datatype for concat function")
+                        }
                     };
-                    columns.push(column);
                 }
                 _ => unreachable!(),
             }
         }
 
-        let mut builder = StringArrayBuilder::with_capacity(len, data_size);
-        for i in 0..len {
-            columns
-                .iter()
-                .for_each(|column| builder.write::<true>(column, i));
-            builder.append_offset();
+        match return_datatype {
+            DataType::Utf8 => {
+                let mut builder = StringArrayBuilder::with_capacity(len, data_size);
+                for i in 0..len {
+                    columns
+                        .iter()
+                        .for_each(|column| builder.write::<true>(column, i));
+                    builder.append_offset();
+                }
+
+                let string_array = builder.finish(None);
+                Ok(ColumnarValue::Array(Arc::new(string_array)))
+            }
+            DataType::Utf8View => {
+                let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
+                for i in 0..len {
+                    columns
+                        .iter()
+                        .for_each(|column| builder.write::<true>(column, i));
+                    builder.append_offset();
+                }
+
+                let string_array = builder.finish();
+                Ok(ColumnarValue::Array(Arc::new(string_array)))
+            }
+            DataType::LargeUtf8 => {
+                let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
+                for i in 0..len {
+                    columns
+                        .iter()
+                        .for_each(|column| builder.write::<true>(column, i));
+                    builder.append_offset();
+                }
+
+                let string_array = builder.finish(None);
+                Ok(ColumnarValue::Array(Arc::new(string_array)))
+            }
+            _ => unreachable!(),
         }
-        Ok(ColumnarValue::Array(Arc::new(builder.finish(None))))
     }
 
     /// Simplify the `concat` function by
@@ -151,11 +257,11 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
     for arg in args.clone() {
         match arg {
             // filter out `null` args
-            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
+            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
             // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
             // Concatenate it with the `contiguous_scalar`.
             Expr::Literal(
-                ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)),
+                ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)),
             ) => contiguous_scalar += &v,
             Expr::Literal(x) => {
                 return internal_err!(
@@ -195,8 +301,9 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
 mod tests {
     use super::*;
     use crate::utils::test::test_function;
-    use arrow::array::Array;
+    use arrow::array::{Array, LargeStringArray, StringViewArray};
     use arrow::array::{ArrayRef, StringArray};
+    use DataType::*;
 
     #[test]
     fn test_functions() -> Result<()> {
@@ -232,6 +339,31 @@ mod tests {
             Utf8,
             StringArray
         );
+        test_function!(
+            ConcatFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("aa")),
+                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
+                ColumnarValue::Scalar(ScalarValue::from("cc")),
+            ],
+            Ok(Some("aacc")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
+        test_function!(
+            ConcatFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("aa")),
+                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
+                ColumnarValue::Scalar(ScalarValue::from("cc")),
+            ],
+            Ok(Some("aacc")),
+            &str,
+            LargeUtf8,
+            LargeStringArray
+        );
 
         Ok(())
     }
diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt
index 83c75b8df3..eb625e530b 100644
--- a/datafusion/sqllogictest/test_files/string_view.slt
+++ b/datafusion/sqllogictest/test_files/string_view.slt
@@ -768,17 +768,26 @@ logical_plan
 01)Projection: character_length(test.column1_utf8view) AS l
 02)--TableScan: test projection=[column1_utf8view]
 
-## Ensure no casts for CONCAT
-## TODO https://github.com/apache/datafusion/issues/11836
+## Ensure no casts for CONCAT Utf8View
 query TT
 EXPLAIN SELECT
   concat(column1_utf8view, column2_utf8view) as c
 FROM test;
 ----
 logical_plan
-01)Projection: concat(CAST(test.column1_utf8view AS Utf8), CAST(test.column2_utf8view AS Utf8)) AS c
+01)Projection: concat(test.column1_utf8view, test.column2_utf8view) AS c
 02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
 
+## Ensure no casts for CONCAT LargeUtf8
+query TT
+EXPLAIN SELECT
+  concat(column1_large_utf8, column2_large_utf8) as c
+FROM test;
+----
+logical_plan
+01)Projection: concat(test.column1_large_utf8, test.column2_large_utf8) AS c
+02)--TableScan: test projection=[column1_large_utf8, column2_large_utf8]
+
 ## Ensure no casts for CONCAT_WS
 ## TODO https://github.com/apache/datafusion/issues/11837
 query TT
@@ -863,6 +872,61 @@ XIANGPENG
 RAPHAEL
 NULL
 
+## Should run CONCAT successfully with utf8view
+query T
+SELECT
+  concat(column1_utf8view, column2_utf8view) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8
+query T
+SELECT
+  concat(column1_utf8, column2_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8 and utf8view
+query T
+SELECT
+  concat(column1_utf8view, column2_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8 utf8view and largeutf8
+query T
+SELECT
+  concat(column1_utf8view, column2_utf8, column2_large_utf8) as c
+FROM test;
+----
+AndrewXX
+XiangpengXiangpengXiangpeng
+RaphaelRR
+RR
+
+## Should run CONCAT successfully with utf8large
+query T
+SELECT
+  concat(column1_large_utf8, column2_large_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
 ## Ensure no casts for LPAD
 query TT
 EXPLAIN SELECT
@@ -1307,3 +1371,4 @@ select column2|| ' ' ||column3 from temp;
 ----
 rust fast
 datafusion cool
+


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to