This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 87c1c173c8 Support string concat `||` for StringViewArray (#12063)
87c1c173c8 is described below
commit 87c1c173c8adb781d02e9907af297734f8a981ed
Author: Dharan Aditya <[email protected]>
AuthorDate: Thu Aug 22 21:25:26 2024 +0530
Support string concat `||` for StringViewArray (#12063)
* naive impl
* calc capacity
* cleanup
* Update test
* simplify coercion logic
* write some more tests
* Update tests
* Improve implementation and do the right thing for null
* add ticket reference
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/expr-common/src/type_coercion/binary.rs | 31 +++++----
datafusion/physical-expr/src/expressions/binary.rs | 53 +++++++---------
.../src/expressions/binary/kernels.rs | 33 ++++++++++
datafusion/sqllogictest/test_files/string_view.slt | 74 ++++++++++++++++++++--
4 files changed, 141 insertions(+), 50 deletions(-)
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs
b/datafusion/expr-common/src/type_coercion/binary.rs
index f811d3e20d..401762ad4d 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -922,26 +922,22 @@ fn dictionary_comparison_coercion(
/// Coercion rules for string concat.
/// This is a union of string coercion rules and specified rules:
-/// 1. At lease one side of lhs and rhs should be string type (Utf8 /
LargeUtf8)
+/// 1. At least one side of lhs and rhs should be string type (Utf8 /
LargeUtf8 / Utf8View)
/// 2. Data type of the other side should be able to cast to string type
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
use arrow::datatypes::DataType::*;
- match (lhs_type, rhs_type) {
- // If Utf8View is in any side, we coerce to Utf8.
- // Ref: https://github.com/apache/datafusion/pull/11796
- (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View)
=> {
- Some(Utf8)
+ string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
+ (Utf8View, from_type) | (from_type, Utf8View) => {
+ string_concat_internal_coercion(from_type, &Utf8View)
}
- _ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type)
{
- (Utf8, from_type) | (from_type, Utf8) => {
- string_concat_internal_coercion(from_type, &Utf8)
- }
- (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
- string_concat_internal_coercion(from_type, &LargeUtf8)
- }
- _ => None,
- }),
- }
+ (Utf8, from_type) | (from_type, Utf8) => {
+ string_concat_internal_coercion(from_type, &Utf8)
+ }
+ (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
+ string_concat_internal_coercion(from_type, &LargeUtf8)
+ }
+ _ => None,
+ })
}
fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
@@ -952,6 +948,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType)
-> Option<DataType>
}
}
+/// If `from_type` can be cast to `to_type`, return `to_type`, otherwise
+/// return `None`.
fn string_concat_internal_coercion(
from_type: &DataType,
to_type: &DataType,
@@ -977,6 +975,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataType>
}
// Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
(LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
+ // Utf8 coerces to Utf8
(Utf8, Utf8) => Some(Utf8),
_ => None,
}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index 26885ae135..b663d86142 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp,
apply_cmp_for_nested};
+use crate::expressions::binary::kernels::concat_elements_utf8view;
use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn,
bitwise_or_dyn_scalar,
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar,
bitwise_shift_right_dyn,
@@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr {
}
}
-/// Invoke a compute kernel on a pair of binary data arrays
-macro_rules! compute_utf8_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
- let ll = $LEFT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast left side array");
- let rr = $RIGHT
- .as_any()
- .downcast_ref::<$DT>()
- .expect("compute_op failed to downcast right side array");
- Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
- }};
-}
-
-macro_rules! binary_string_array_op {
- ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
- match $LEFT.data_type() {
- DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP,
StringArray),
- DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP,
LargeStringArray),
- other => internal_err!(
- "Data type {:?} not supported for binary operation '{}' on
string arrays",
- other, stringify!($OP)
- ),
- }
- }};
-}
-
/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
@@ -662,7 +635,7 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
- StringConcat => binary_string_array_op!(left, right,
concat_elements),
+ StringConcat => concat_elements(left, right),
AtArrow | ArrowAt => {
unreachable!("ArrowAt and AtArrow should be rewritten to
function")
}
@@ -670,6 +643,28 @@ impl BinaryExpr {
}
}
+/// Element-wise string concatenation (`||`) of two arrays.
+///
+/// The kernel is selected from the data type of `left` (`Utf8`, `LargeUtf8`
+/// or `Utf8View`), and `right` is downcast to the same array type. Any other
+/// data type on the left produces an internal error.
+fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
+    let concatenated: ArrayRef = match left.data_type() {
+        DataType::Utf8 => {
+            let (lhs, rhs) = (left.as_string::<i32>(), right.as_string::<i32>());
+            Arc::new(concat_elements_utf8(lhs, rhs)?)
+        }
+        DataType::LargeUtf8 => {
+            let (lhs, rhs) = (left.as_string::<i64>(), right.as_string::<i64>());
+            Arc::new(concat_elements_utf8(lhs, rhs)?)
+        }
+        DataType::Utf8View => {
+            let (lhs, rhs) = (left.as_string_view(), right.as_string_view());
+            Arc::new(concat_elements_utf8view(lhs, rhs)?)
+        }
+        other => {
+            return internal_err!(
+                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
+            );
+        }
+    };
+    Ok(concatenated)
+}
+
/// Create a binary expression whose arguments are correctly coerced.
/// This function errors if it is not possible to coerce the arguments
/// to computational types supported by the operator.
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs
b/datafusion/physical-expr/src/expressions/binary/kernels.rs
index b0736e140f..1f9cfed1a4 100644
--- a/datafusion/physical-expr/src/expressions/binary/kernels.rs
+++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs
@@ -27,6 +27,7 @@ use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::{Result, ScalarValue};
+use arrow_schema::ArrowError;
use std::sync::Arc;
/// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT,
$RIGHT)
@@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar,
bitwise_or_scalar);
create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar);
create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar,
bitwise_shift_right_scalar);
create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar,
bitwise_shift_left_scalar);
+
+/// Concatenates the elements of two [`StringViewArray`]s pairwise and returns
+/// the result as a new [`StringViewArray`].
+///
+/// An output row is null when the corresponding row of either input is null.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::ComputeError`] if `left` and `right` do not have the
+/// same number of rows.
+pub fn concat_elements_utf8view(
+    left: &StringViewArray,
+    right: &StringViewArray,
+) -> std::result::Result<StringViewArray, ArrowError> {
+    // Without this check `zip` below would silently truncate the output to
+    // the shorter of the two inputs.
+    if left.len() != right.len() {
+        return Err(ArrowError::ComputeError(format!(
+            "Arrays must have the same length: {} != {}",
+            left.len(),
+            right.len()
+        )));
+    }
+
+    // The builder capacity is a number of values (views), not a number of
+    // bytes, and the output has exactly one value per input row.
+    let mut result = StringViewBuilder::with_capacity(left.len());
+
+    // Avoid reallocations by writing to a reused buffer (note we
+    // could be even more efficient by creating the view directly
+    // here and avoid the buffer but that would be more complex)
+    let mut buffer = String::new();
+
+    for (left, right) in left.iter().zip(right.iter()) {
+        if let (Some(left), Some(right)) = (left, right) {
+            buffer.clear();
+            buffer.push_str(left);
+            buffer.push_str(right);
+            result.append_value(&buffer);
+        } else {
+            // at least one of the values is null, so the output is also null
+            result.append_null()
+        }
+    }
+    Ok(result.finish())
+}
diff --git a/datafusion/sqllogictest/test_files/string_view.slt
b/datafusion/sqllogictest/test_files/string_view.slt
index 4b4eba0522..3b3d7b88a4 100644
--- a/datafusion/sqllogictest/test_files/string_view.slt
+++ b/datafusion/sqllogictest/test_files/string_view.slt
@@ -1144,6 +1144,63 @@ FROM test;
0
NULL
+# || mixed types
+# expect all results to be the same for each row as they all have the same
values
+query TTTTTTTT
+SELECT
+ column1_utf8view || column2_utf8view,
+ column1_utf8 || column2_utf8view,
+ column1_large_utf8 || column2_utf8view,
+ column1_dict || column2_utf8view,
+ -- reverse argument order
+ column2_utf8view || column1_utf8view,
+ column2_utf8view || column1_utf8,
+ column2_utf8view || column1_large_utf8,
+ column2_utf8view || column1_dict
+FROM test;
+----
+AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew
+XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
+RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+# || constants
+# expect all results to be the same for each row as they all have the same
values
+query TTTTTTTT
+SELECT
+ column1_utf8view || 'foo',
+ column1_utf8 || 'foo',
+ column1_large_utf8 || 'foo',
+ column1_dict || 'foo',
+ -- reverse argument order
+ 'foo' || column1_utf8view,
+ 'foo' || column1_utf8,
+ 'foo' || column1_large_utf8,
+ 'foo' || column1_dict
+FROM test;
+----
+Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew
+Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng
fooXiangpeng fooXiangpeng
+Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael
fooRaphael
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+# || same type (column1 has null, so also tests NULL || NULL)
+# expect all results to be the same for each row as they all have the same
values
+query TTT
+SELECT
+ column1_utf8view || column1_utf8view,
+ column1_utf8 || column1_utf8,
+ column1_large_utf8 || column1_large_utf8
+ -- Dictionary/Dictionary coercion doesn't work
+ -- https://github.com/apache/datafusion/issues/12101
+ --column1_dict || column1_dict
+FROM test;
+----
+AndrewAndrew AndrewAndrew AndrewAndrew
+XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
+RaphaelRaphael RaphaelRaphael RaphaelRaphael
+NULL NULL NULL
+
statement ok
drop table test;
@@ -1167,18 +1224,25 @@ select t.dt from dates t where arrow_cast('2024-01-01',
'Utf8View') < t.dt;
statement ok
drop table dates;
+### Tests for `||` with Utf8View specifically
+
statement ok
create table temp as values
('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')),
('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool',
'Utf8View'));
+query TTT
+select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3)
from temp;
+----
+Utf8 Utf8View Utf8View
+Utf8 Utf8View Utf8View
+
query T
select column2||' is fast' from temp;
----
rust is fast
datafusion is fast
-
query T
select column2 || ' is ' || column3 from temp;
----
@@ -1189,15 +1253,15 @@ query TT
explain select column2 || 'is' || column3 from temp;
----
logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3
AS Utf8)
+01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2
|| Utf8("is") || temp.column3
02)--TableScan: temp projection=[column2, column3]
-
+# should not cast the column2 to utf8
query TT
explain select column2||' is fast' from temp;
----
logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast")
+01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8("
is fast")
02)--TableScan: temp projection=[column2]
@@ -1211,7 +1275,7 @@ query TT
explain select column2||column3 from temp;
----
logical_plan
-01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8)
+01)Projection: temp.column2 || temp.column3
02)--TableScan: temp projection=[column2, column3]
query T
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]