This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 87c1c173c8 Support string concat `||` for StringViewArray  (#12063)
87c1c173c8 is described below

commit 87c1c173c8adb781d02e9907af297734f8a981ed
Author: Dharan Aditya <[email protected]>
AuthorDate: Thu Aug 22 21:25:26 2024 +0530

    Support string concat `||` for StringViewArray  (#12063)
    
    * naive impl
    
    * calc capacity
    
    * cleanup
    
    * Update test
    
    * simplify coercion logic
    
    * write some more tests
    
    * Update tests
    
    * Improve implementation and do the right thing for null
    
    * add ticket reference
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/expr-common/src/type_coercion/binary.rs | 31 +++++----
 datafusion/physical-expr/src/expressions/binary.rs | 53 +++++++---------
 .../src/expressions/binary/kernels.rs              | 33 ++++++++++
 datafusion/sqllogictest/test_files/string_view.slt | 74 ++++++++++++++++++++--
 4 files changed, 141 insertions(+), 50 deletions(-)

diff --git a/datafusion/expr-common/src/type_coercion/binary.rs 
b/datafusion/expr-common/src/type_coercion/binary.rs
index f811d3e20d..401762ad4d 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -922,26 +922,22 @@ fn dictionary_comparison_coercion(
 
 /// Coercion rules for string concat.
 /// This is a union of string coercion rules and specified rules:
-/// 1. At lease one side of lhs and rhs should be string type (Utf8 / 
LargeUtf8)
+/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8 / Utf8View)
 /// 2. Data type of the other side should be able to cast to string type
 fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> 
Option<DataType> {
     use arrow::datatypes::DataType::*;
-    match (lhs_type, rhs_type) {
-        // If Utf8View is in any side, we coerce to Utf8.
-        // Ref: https://github.com/apache/datafusion/pull/11796
-        (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) 
=> {
-            Some(Utf8)
+    string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
+        (Utf8View, from_type) | (from_type, Utf8View) => {
+            string_concat_internal_coercion(from_type, &Utf8View)
         }
-        _ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) 
{
-            (Utf8, from_type) | (from_type, Utf8) => {
-                string_concat_internal_coercion(from_type, &Utf8)
-            }
-            (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
-                string_concat_internal_coercion(from_type, &LargeUtf8)
-            }
-            _ => None,
-        }),
-    }
+        (Utf8, from_type) | (from_type, Utf8) => {
+            string_concat_internal_coercion(from_type, &Utf8)
+        }
+        (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
+            string_concat_internal_coercion(from_type, &LargeUtf8)
+        }
+        _ => None,
+    })
 }
 
 fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> 
Option<DataType> {
@@ -952,6 +948,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) 
-> Option<DataType>
     }
 }
 
+/// If `from_type` can be cast to `to_type`, return `to_type`, otherwise
+/// return `None`.
 fn string_concat_internal_coercion(
     from_type: &DataType,
     to_type: &DataType,
@@ -977,6 +975,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: 
&DataType) -> Option<DataType>
         }
         // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
         (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
+        // Utf8 coerces to Utf8
         (Utf8, Utf8) => Some(Utf8),
         _ => None,
     }
diff --git a/datafusion/physical-expr/src/expressions/binary.rs 
b/datafusion/physical-expr/src/expressions/binary.rs
index 26885ae135..b663d86142 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
 use datafusion_expr::{ColumnarValue, Operator};
 use datafusion_physical_expr_common::datum::{apply, apply_cmp, 
apply_cmp_for_nested};
 
+use crate::expressions::binary::kernels::concat_elements_utf8view;
 use kernels::{
     bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, 
bitwise_or_dyn_scalar,
     bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, 
bitwise_shift_right_dyn,
@@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr {
     }
 }
 
-/// Invoke a compute kernel on a pair of binary data arrays
-macro_rules! compute_utf8_op {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
-        let ll = $LEFT
-            .as_any()
-            .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast left side array");
-        let rr = $RIGHT
-            .as_any()
-            .downcast_ref::<$DT>()
-            .expect("compute_op failed to downcast right side array");
-        Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
-    }};
-}
-
-macro_rules! binary_string_array_op {
-    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
-        match $LEFT.data_type() {
-            DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, 
StringArray),
-            DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, 
LargeStringArray),
-            other => internal_err!(
-                "Data type {:?} not supported for binary operation '{}' on 
string arrays",
-                other, stringify!($OP)
-            ),
-        }
-    }};
-}
-
 /// Invoke a boolean kernel on a pair of arrays
 macro_rules! boolean_op {
     ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
@@ -662,7 +635,7 @@ impl BinaryExpr {
             BitwiseXor => bitwise_xor_dyn(left, right),
             BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
             BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
-            StringConcat => binary_string_array_op!(left, right, 
concat_elements),
+            StringConcat => concat_elements(left, right),
             AtArrow | ArrowAt => {
                 unreachable!("ArrowAt and AtArrow should be rewritten to 
function")
             }
@@ -670,6 +643,28 @@ impl BinaryExpr {
     }
 }
 
+/// Concatenates two string arrays element-wise, choosing the kernel from `left`'s data type.
+fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
+    let joined: ArrayRef = match left.data_type() {
+        DataType::Utf8 => {
+            let (lhs, rhs) = (left.as_string::<i32>(), right.as_string::<i32>());
+            Arc::new(concat_elements_utf8(lhs, rhs)?)
+        }
+        DataType::LargeUtf8 => {
+            let (lhs, rhs) = (left.as_string::<i64>(), right.as_string::<i64>());
+            Arc::new(concat_elements_utf8(lhs, rhs)?)
+        }
+        DataType::Utf8View => {
+            let (lhs, rhs) = (left.as_string_view(), right.as_string_view());
+            Arc::new(concat_elements_utf8view(lhs, rhs)?)
+        }
+        other => return internal_err!(
+            "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
+        ),
+    };
+    Ok(joined)
+}
+
 /// Create a binary expression whose arguments are correctly coerced.
 /// This function errors if it is not possible to coerce the arguments
 /// to computational types supported by the operator.
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs 
b/datafusion/physical-expr/src/expressions/binary/kernels.rs
index b0736e140f..1f9cfed1a4 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::DataType;
 use datafusion_common::internal_err;
 use datafusion_common::{Result, ScalarValue};
 
+use arrow_schema::ArrowError;
 use std::sync::Arc;
 
 /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, 
$RIGHT)
@@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, 
bitwise_or_scalar);
 create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar);
 create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, 
bitwise_shift_right_scalar);
 create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, 
bitwise_shift_left_scalar);
+
+/// Concatenates `left` and `right` row by row; a row is null if either input is null.
+///
+/// Rows are paired with `zip`, so the output stops at the shorter of the two inputs.
+pub fn concat_elements_utf8view(
+    left: &StringViewArray,
+    right: &StringViewArray,
+) -> std::result::Result<StringViewArray, ArrowError> {
+    // `StringViewBuilder::with_capacity` expects a number of values (rows), not
+    // a byte count, so size it by the number of rows that will be produced.
+    let capacity = left.len().min(right.len());
+    let mut result = StringViewBuilder::with_capacity(capacity);
+
+    // Avoid reallocations by writing to a reused buffer (note we
+    // could be even more efficient by creating the view directly
+    // here and avoid the buffer but that would be more complex)
+    let mut buffer = String::new();
+
+    for (left, right) in left.iter().zip(right.iter()) {
+        if let (Some(left), Some(right)) = (left, right) {
+            use std::fmt::Write;
+            buffer.clear();
+            write!(&mut buffer, "{left}{right}")
+                .expect("writing into string buffer failed");
+            result.append_value(&buffer);
+        } else {
+            // at least one of the values is null, so the output is also null
+            result.append_null()
+        }
+    }
+    Ok(result.finish())
+}
diff --git a/datafusion/sqllogictest/test_files/string_view.slt 
b/datafusion/sqllogictest/test_files/string_view.slt
index 4b4eba0522..3b3d7b88a4 100644
--- a/datafusion/sqllogictest/test_files/string_view.slt
+++ b/datafusion/sqllogictest/test_files/string_view.slt
@@ -1144,6 +1144,63 @@ FROM test;
 0
 NULL
 
+# || mixed types
+# expect all results to be the same for each row as they all have the same 
values
+query TTTTTTTT
+SELECT
+  column1_utf8view || column2_utf8view,
+  column1_utf8 || column2_utf8view,
+  column1_large_utf8 || column2_utf8view,
+  column1_dict || column2_utf8view,
+  -- reverse argument order
+  column2_utf8view || column1_utf8view,
+  column2_utf8view || column1_utf8,
+  column2_utf8view || column1_large_utf8,
+  column2_utf8view || column1_dict
+FROM test;
+----
+AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew
+XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng 
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
+RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+# || constants
+# expect all results to be the same for each row as they all have the same 
values
+query TTTTTTTT
+SELECT
+  column1_utf8view || 'foo',
+  column1_utf8 || 'foo',
+  column1_large_utf8 || 'foo',
+  column1_dict || 'foo',
+  -- reverse argument order
+  'foo' || column1_utf8view,
+  'foo' || column1_utf8,
+  'foo' || column1_large_utf8,
+  'foo' || column1_dict
+FROM test;
+----
+Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew
+Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng 
fooXiangpeng fooXiangpeng
+Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael 
fooRaphael
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+# || same type (column1 has null, so also tests NULL || NULL)
+# expect all results to be the same for each row as they all have the same 
values
+query TTT
+SELECT
+  column1_utf8view || column1_utf8view,
+  column1_utf8 || column1_utf8,
+  column1_large_utf8 || column1_large_utf8
+  -- Dictionary/Dictionary coercion doesn't work
+  -- https://github.com/apache/datafusion/issues/12101
+  --column1_dict || column1_dict
+FROM test;
+----
+AndrewAndrew AndrewAndrew AndrewAndrew
+XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
+RaphaelRaphael RaphaelRaphael RaphaelRaphael
+NULL NULL NULL
+
 statement ok
 drop table test;
 
@@ -1167,18 +1224,25 @@ select t.dt from dates t where arrow_cast('2024-01-01', 
'Utf8View') < t.dt;
 statement ok
 drop table dates;
 
+### Tests for `||` with Utf8View specifically
+
 statement ok
 create table temp as values
 ('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')),
 ('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 
'Utf8View'));
 
+query TTT
+select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) 
from temp;
+----
+Utf8 Utf8View Utf8View
+Utf8 Utf8View Utf8View
+
 query T
 select column2||' is fast' from temp;
 ----
 rust is fast
 datafusion is fast
 
-
 query T
 select column2 || ' is ' || column3 from temp;
 ----
@@ -1189,15 +1253,15 @@ query TT
 explain select column2 || 'is' || column3 from temp;
 ----
 logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3 
AS Utf8)
+01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 
|| Utf8("is") || temp.column3
 02)--TableScan: temp projection=[column2, column3]
 
-
+# should not cast the column2 to utf8
 query TT
 explain select column2||' is fast' from temp;
 ----
 logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast")
+01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" 
is fast")
 02)--TableScan: temp projection=[column2]
 
 
@@ -1211,7 +1275,7 @@ query TT
 explain select column2||column3 from temp;
 ----
 logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8)
+01)Projection: temp.column2 || temp.column3
 02)--TableScan: temp projection=[column2, column3]
 
 query T


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to