This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e4a94243b5 Update the CONCAT scalar function to support Utf8View
(#12224)
e4a94243b5 is described below
commit e4a94243b502da2ad07a358b4401052651952eea
Author: WeblWabl <[email protected]>
AuthorDate: Tue Sep 3 15:12:47 2024 -0500
Update the CONCAT scalar function to support Utf8View (#12224)
* wip
* feat: Update the CONCAT scalar function to support Utf8View
* fmt
* fmt and add default return type for concat
* fix clippy lint
Signed-off-by: Devan <[email protected]>
* fmt
Signed-off-by: Devan <[email protected]>
* add more tests for sqllogic
Signed-off-by: Devan <[email protected]>
* make sure no casting with LargeUtf8
* fixing utf8large
* fix large utf8
Signed-off-by: Devan <[email protected]>
* fix large utf8
Signed-off-by: Devan <[email protected]>
* add test
Signed-off-by: Devan <[email protected]>
* fmt
Signed-off-by: Devan <[email protected]>
* make it so Utf8View just returns Utf8
Signed-off-by: Devan <[email protected]>
* wip -- trying to build a stringview with columnar refs
Signed-off-by: Devan <[email protected]>
* built stringview builder but it does allocate a new String each iter :(
Signed-off-by: Devan <[email protected]>
* add some testing
Signed-off-by: Devan <[email protected]>
* clippy
Signed-off-by: Devan <[email protected]>
---------
Signed-off-by: Devan <[email protected]>
---
datafusion/functions/src/string/common.rs | 195 ++++++++++++++++++++-
datafusion/functions/src/string/concat.rs | 184 ++++++++++++++++---
datafusion/sqllogictest/test_files/string_view.slt | 71 +++++++-
3 files changed, 416 insertions(+), 34 deletions(-)
diff --git a/datafusion/functions/src/string/common.rs
b/datafusion/functions/src/string/common.rs
index 9738cb812f..6ebcc4ee6c 100644
--- a/datafusion/functions/src/string/common.rs
+++ b/datafusion/functions/src/string/common.rs
@@ -22,12 +22,11 @@ use std::sync::Arc;
use arrow::array::{
new_null_array, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter,
ArrayRef,
- GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
- StringBuilder, StringViewArray,
+ GenericStringArray, GenericStringBuilder, LargeStringArray,
OffsetSizeTrait,
+ StringArray, StringBuilder, StringViewArray, StringViewBuilder,
};
use arrow::buffer::{Buffer, MutableBuffer, NullBuffer};
use arrow::datatypes::DataType;
-
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::Result;
use datafusion_common::{exec_err, ScalarValue};
@@ -249,26 +248,41 @@ where
}
}
+/// Borrowed reference to one argument of a string function such as `concat`,
+/// read row-by-row by the string array builders.
+///
+/// `Scalar` holds the bytes of a scalar argument; builders write the same
+/// bytes for every row. The `Nullable*` / `NonNullable*` variants wrap the
+/// three supported string array encodings (`StringArray`, `LargeStringArray`,
+/// `StringViewArray`); the `NonNullable*` variants are always reported as
+/// valid, so builders can skip per-row validity checks for them.
+#[derive(Debug)]
pub(crate) enum ColumnarValueRef<'a> {
Scalar(&'a [u8]),
NullableArray(&'a StringArray),
NonNullableArray(&'a StringArray),
+ NullableLargeStringArray(&'a LargeStringArray),
+ NonNullableLargeStringArray(&'a LargeStringArray),
+ NullableStringViewArray(&'a StringViewArray),
+ NonNullableStringViewArray(&'a StringViewArray),
}
impl<'a> ColumnarValueRef<'a> {
+ /// Returns whether row `i` holds a value: always `true` for scalars and
+ /// non-nullable arrays, otherwise the array's own validity for row `i`.
#[inline]
pub fn is_valid(&self, i: usize) -> bool {
match &self {
- Self::Scalar(_) | Self::NonNullableArray(_) => true,
+ Self::Scalar(_)
+ | Self::NonNullableArray(_)
+ | Self::NonNullableLargeStringArray(_)
+ | Self::NonNullableStringViewArray(_) => true,
Self::NullableArray(array) => array.is_valid(i),
+ Self::NullableStringViewArray(array) => array.is_valid(i),
+ Self::NullableLargeStringArray(array) => array.is_valid(i),
}
}
+ /// Returns a clone of the wrapped array's null buffer for nullable
+ /// variants, and `None` for scalars and non-nullable arrays.
#[inline]
pub fn nulls(&self) -> Option<NullBuffer> {
match &self {
- Self::Scalar(_) | Self::NonNullableArray(_) => None,
+ Self::Scalar(_)
+ | Self::NonNullableArray(_)
+ | Self::NonNullableStringViewArray(_)
+ | Self::NonNullableLargeStringArray(_) => None,
Self::NullableArray(array) => array.nulls().cloned(),
+ Self::NullableStringViewArray(array) => array.nulls().cloned(),
+ Self::NullableLargeStringArray(array) => array.nulls().cloned(),
}
}
}
@@ -387,10 +401,30 @@ impl StringArrayBuilder {
.extend_from_slice(array.value(i).as_bytes());
}
}
+ ColumnarValueRef::NullableLargeStringArray(array) => {
+ if !CHECK_VALID || array.is_valid(i) {
+ self.value_buffer
+ .extend_from_slice(array.value(i).as_bytes());
+ }
+ }
+ ColumnarValueRef::NullableStringViewArray(array) => {
+ if !CHECK_VALID || array.is_valid(i) {
+ self.value_buffer
+ .extend_from_slice(array.value(i).as_bytes());
+ }
+ }
ColumnarValueRef::NonNullableArray(array) => {
self.value_buffer
.extend_from_slice(array.value(i).as_bytes());
}
+ ColumnarValueRef::NonNullableLargeStringArray(array) => {
+ self.value_buffer
+ .extend_from_slice(array.value(i).as_bytes());
+ }
+ ColumnarValueRef::NonNullableStringViewArray(array) => {
+ self.value_buffer
+ .extend_from_slice(array.value(i).as_bytes());
+ }
}
}
@@ -416,6 +450,157 @@ impl StringArrayBuilder {
}
}
+/// Incrementally builds a [`StringViewArray`]: the pieces of one output row
+/// are accumulated with [`Self::write`], and [`Self::append_offset`] emits
+/// them as the next row.
+pub(crate) struct StringViewArrayBuilder {
+    builder: StringViewBuilder,
+    /// Scratch buffer for the row currently being assembled. It is cleared,
+    /// not reallocated, after each row so its capacity is reused.
+    block: String,
+}
+
+impl StringViewArrayBuilder {
+    /// Creates a builder with room for `item_capacity` rows.
+    ///
+    /// `StringViewBuilder::with_capacity` is sized by the number of views
+    /// (rows), not by bytes, so `data_capacity` is accepted only to keep the
+    /// same signature as the other string builders.
+    pub fn with_capacity(item_capacity: usize, _data_capacity: usize) -> Self {
+        Self {
+            builder: StringViewBuilder::with_capacity(item_capacity),
+            block: String::new(),
+        }
+    }
+
+    /// Appends the value of `column` at row `i` to the row being built.
+    ///
+    /// When `CHECK_VALID` is `true`, null slots of nullable arrays contribute
+    /// nothing; when it is `false`, the slot's value is read unconditionally.
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                // Callers in this patch build `Scalar` from `str::as_bytes`,
+                // so the bytes are expected to be valid UTF-8.
+                self.block
+                    .push_str(std::str::from_utf8(s).expect("scalar is valid UTF-8"));
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(array.value(i));
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(array.value(i));
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.block.push_str(array.value(i));
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.block.push_str(array.value(i));
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.block.push_str(array.value(i));
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.block.push_str(array.value(i));
+            }
+        }
+    }
+
+    /// Finishes the current row: appends the accumulated text as the next
+    /// value and resets the scratch buffer while keeping its allocation.
+    pub fn append_offset(&mut self) {
+        self.builder.append_value(&self.block);
+        self.block.clear();
+    }
+
+    /// Consumes the builder and returns the completed array.
+    pub fn finish(mut self) -> StringViewArray {
+        self.builder.finish()
+    }
+}
+
+/// Incrementally builds a [`LargeStringArray`] (`i64` offsets):
+/// [`Self::write`] appends bytes to the current row and
+/// [`Self::append_offset`] closes the row by recording its end offset.
+pub(crate) struct LargeStringArrayBuilder {
+    offsets_buffer: MutableBuffer,
+    value_buffer: MutableBuffer,
+}
+
+impl LargeStringArrayBuilder {
+    /// Creates a builder pre-sized for `item_capacity` rows and
+    /// `data_capacity` bytes of string data.
+    pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self {
+        let mut offsets_buffer = MutableBuffer::with_capacity(
+            (item_capacity + 1) * std::mem::size_of::<i64>(),
+        );
+        // The offsets buffer always starts with a leading 0.
+        offsets_buffer.push(0_i64);
+        Self {
+            offsets_buffer,
+            value_buffer: MutableBuffer::with_capacity(data_capacity),
+        }
+    }
+
+    /// Appends the bytes of `column` at row `i` to the current row.
+    ///
+    /// When `CHECK_VALID` is `true`, null slots of nullable arrays contribute
+    /// nothing; when it is `false`, the slot's value is read unconditionally.
+    pub fn write<const CHECK_VALID: bool>(
+        &mut self,
+        column: &ColumnarValueRef,
+        i: usize,
+    ) {
+        match column {
+            ColumnarValueRef::Scalar(s) => {
+                self.value_buffer.extend_from_slice(s);
+            }
+            ColumnarValueRef::NullableArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableLargeStringArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NullableStringViewArray(array) => {
+                if !CHECK_VALID || array.is_valid(i) {
+                    self.value_buffer
+                        .extend_from_slice(array.value(i).as_bytes());
+                }
+            }
+            ColumnarValueRef::NonNullableArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableLargeStringArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+            ColumnarValueRef::NonNullableStringViewArray(array) => {
+                self.value_buffer
+                    .extend_from_slice(array.value(i).as_bytes());
+            }
+        }
+    }
+
+    /// Closes the current row by recording the current end of the value
+    /// buffer as the next offset.
+    ///
+    /// # Panics
+    /// Panics if the total value length does not fit in an `i64`.
+    pub fn append_offset(&mut self) {
+        let next_offset: i64 = self
+            .value_buffer
+            .len()
+            .try_into()
+            .expect("byte array offset overflow");
+        // Use the checked `push` rather than `push_unchecked`: this is a safe
+        // method, and nothing stops a caller from appending more rows than
+        // `item_capacity` reserved; `push` grows the buffer in that case.
+        self.offsets_buffer.push(next_offset);
+    }
+
+    /// Consumes the builder and returns the array, using `null_buffer` as
+    /// its validity bitmap.
+    pub fn finish(self, null_buffer: Option<NullBuffer>) -> LargeStringArray {
+        let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8)
+            .len(self.offsets_buffer.len() / std::mem::size_of::<i64>() - 1)
+            .add_buffer(self.offsets_buffer.into())
+            .add_buffer(self.value_buffer.into())
+            .nulls(null_buffer);
+        // SAFETY: every value written came from a `&str` (or scalar bytes
+        // taken from one), so the data is valid UTF-8. The offsets start at 0
+        // and each is the current value-buffer length, so they are
+        // non-decreasing and within the value buffer.
+        let array_data = unsafe { array_builder.build_unchecked() };
+        LargeStringArray::from(array_data)
+    }
+}
+
fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) ->
Result<ArrayRef>
where
O: OffsetSizeTrait,
diff --git a/datafusion/functions/src/string/concat.rs
b/datafusion/functions/src/string/concat.rs
index 6d15e22067..00fe69b0bd 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -15,14 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::array::{as_largestring_array, Array};
+use arrow::datatypes::DataType;
use std::any::Any;
use std::sync::Arc;
-use arrow::datatypes::DataType;
-use arrow::datatypes::DataType::Utf8;
-
-use datafusion_common::cast::as_string_array;
-use datafusion_common::{internal_err, Result, ScalarValue};
+use datafusion_common::cast::{as_string_array, as_string_view_array};
+use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
@@ -46,7 +45,10 @@ impl ConcatFunc {
+    /// Creates the `concat` UDF with an immutable, variadic signature over
+    /// `Utf8`, `Utf8View` and `LargeUtf8` arguments. (How mixed argument types
+    /// are coerced is decided by `Signature::variadic`; confirm there.)
pub fn new() -> Self {
use DataType::*;
Self {
- signature: Signature::variadic(vec![Utf8], Volatility::Immutable),
+ signature: Signature::variadic(
+ vec![Utf8, Utf8View, LargeUtf8],
+ Volatility::Immutable,
+ ),
}
}
}
@@ -64,13 +66,36 @@ impl ScalarUDFImpl for ConcatFunc {
&self.signature
}
- fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
- Ok(Utf8)
+    /// Picks the output string type from the argument types: `Utf8View` if
+    /// any argument is `Utf8View`, else `LargeUtf8` if any argument is
+    /// `LargeUtf8`, else `Utf8`.
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        use DataType::*;
+        let any_arg_is = |wanted: &DataType| arg_types.iter().any(|t| t == wanted);
+        let result = if any_arg_is(&Utf8View) {
+            Utf8View
+        } else if any_arg_is(&LargeUtf8) {
+            LargeUtf8
+        } else {
+            Utf8
+        };
+        Ok(result)
}
/// Concatenates the text representations of all the arguments. NULL
arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ let mut return_datatype = DataType::Utf8;
+ args.iter().for_each(|col| {
+ if col.data_type() == DataType::Utf8View {
+ return_datatype = col.data_type();
+ }
+ if col.data_type() == DataType::LargeUtf8
+ && return_datatype != DataType::Utf8View
+ {
+ return_datatype = col.data_type();
+ }
+ });
+
let array_len = args
.iter()
.filter_map(|x| match x {
@@ -87,7 +112,21 @@ impl ScalarUDFImpl for ConcatFunc {
result.push_str(v);
}
}
- return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))));
+
+ return match return_datatype {
+ DataType::Utf8View => {
+
Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
+ }
+ DataType::Utf8 => {
+ Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
+ }
+ DataType::LargeUtf8 => {
+
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
+ }
+ other => {
+ plan_err!("Concat function does not support datatype of
{other}")
+ }
+ };
}
// Array
@@ -103,28 +142,95 @@ impl ScalarUDFImpl for ConcatFunc {
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
+ ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
+ if let Some(s) = maybe_value {
+ data_size += s.len() * len;
+ columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
+ }
+ }
ColumnarValue::Array(array) => {
- let string_array = as_string_array(array)?;
- data_size += string_array.values().len();
- let column = if array.is_nullable() {
- ColumnarValueRef::NullableArray(string_array)
- } else {
- ColumnarValueRef::NonNullableArray(string_array)
+ match array.data_type() {
+ DataType::Utf8 => {
+ let string_array = as_string_array(array)?;
+
+ data_size += string_array.values().len();
+ let column = if array.is_nullable() {
+ ColumnarValueRef::NullableArray(string_array)
+ } else {
+
ColumnarValueRef::NonNullableArray(string_array)
+ };
+ columns.push(column);
+ },
+ DataType::LargeUtf8 => {
+ let string_array = as_largestring_array(array);
+
+ data_size += string_array.values().len();
+ let column = if array.is_nullable() {
+
ColumnarValueRef::NullableLargeStringArray(string_array)
+ } else {
+
ColumnarValueRef::NonNullableLargeStringArray(string_array)
+ };
+ columns.push(column);
+ },
+ DataType::Utf8View => {
+ let string_array = as_string_view_array(array)?;
+
+ data_size += string_array.len();
+ let column = if array.is_nullable() {
+
ColumnarValueRef::NullableStringViewArray(string_array)
+ } else {
+
ColumnarValueRef::NonNullableStringViewArray(string_array)
+ };
+ columns.push(column);
+ },
+ other => {
+ return plan_err!("Input was {other} which is not a
supported datatype for concat function")
+ }
};
- columns.push(column);
}
_ => unreachable!(),
}
}
- let mut builder = StringArrayBuilder::with_capacity(len, data_size);
- for i in 0..len {
- columns
- .iter()
- .for_each(|column| builder.write::<true>(column, i));
- builder.append_offset();
+ match return_datatype {
+ DataType::Utf8 => {
+ let mut builder = StringArrayBuilder::with_capacity(len,
data_size);
+ for i in 0..len {
+ columns
+ .iter()
+ .for_each(|column| builder.write::<true>(column, i));
+ builder.append_offset();
+ }
+
+ let string_array = builder.finish(None);
+ Ok(ColumnarValue::Array(Arc::new(string_array)))
+ }
+ DataType::Utf8View => {
+ let mut builder = StringViewArrayBuilder::with_capacity(len,
data_size);
+ for i in 0..len {
+ columns
+ .iter()
+ .for_each(|column| builder.write::<true>(column, i));
+ builder.append_offset();
+ }
+
+ let string_array = builder.finish();
+ Ok(ColumnarValue::Array(Arc::new(string_array)))
+ }
+ DataType::LargeUtf8 => {
+ let mut builder = LargeStringArrayBuilder::with_capacity(len,
data_size);
+ for i in 0..len {
+ columns
+ .iter()
+ .for_each(|column| builder.write::<true>(column, i));
+ builder.append_offset();
+ }
+
+ let string_array = builder.finish(None);
+ Ok(ColumnarValue::Array(Arc::new(string_array)))
+ }
+ _ => unreachable!(),
}
- Ok(ColumnarValue::Array(Arc::new(builder.finish(None))))
}
/// Simplify the `concat` function by
@@ -151,11 +257,11 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
for arg in args.clone() {
match arg {
// filter out `null` args
- Expr::Literal(ScalarValue::Utf8(None) |
ScalarValue::LargeUtf8(None)) => {}
+ Expr::Literal(ScalarValue::Utf8(None) |
ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
// All literals have been converted to Utf8 or LargeUtf8 in
type_coercion.
// Concatenate it with the `contiguous_scalar`.
Expr::Literal(
- ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)),
+ ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) |
ScalarValue::Utf8View(Some(v)),
) => contiguous_scalar += &v,
Expr::Literal(x) => {
return internal_err!(
@@ -195,8 +301,9 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
mod tests {
use super::*;
use crate::utils::test::test_function;
- use arrow::array::Array;
+ use arrow::array::{Array, LargeStringArray, StringViewArray};
use arrow::array::{ArrayRef, StringArray};
+ use DataType::*;
#[test]
fn test_functions() -> Result<()> {
@@ -232,6 +339,31 @@ mod tests {
Utf8,
StringArray
);
+ test_function!(
+ ConcatFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("aa")),
+ ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
+ ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
+ ColumnarValue::Scalar(ScalarValue::from("cc")),
+ ],
+ Ok(Some("aacc")),
+ &str,
+ Utf8View,
+ StringViewArray
+ );
+ test_function!(
+ ConcatFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("aa")),
+ ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
+ ColumnarValue::Scalar(ScalarValue::from("cc")),
+ ],
+ Ok(Some("aacc")),
+ &str,
+ LargeUtf8,
+ LargeStringArray
+ );
Ok(())
}
diff --git a/datafusion/sqllogictest/test_files/string_view.slt
b/datafusion/sqllogictest/test_files/string_view.slt
index 83c75b8df3..eb625e530b 100644
--- a/datafusion/sqllogictest/test_files/string_view.slt
+++ b/datafusion/sqllogictest/test_files/string_view.slt
@@ -768,17 +768,26 @@ logical_plan
01)Projection: character_length(test.column1_utf8view) AS l
02)--TableScan: test projection=[column1_utf8view]
-## Ensure no casts for CONCAT
-## TODO https://github.com/apache/datafusion/issues/11836
+## Ensure no casts for CONCAT Utf8View
query TT
EXPLAIN SELECT
concat(column1_utf8view, column2_utf8view) as c
FROM test;
----
logical_plan
-01)Projection: concat(CAST(test.column1_utf8view AS Utf8),
CAST(test.column2_utf8view AS Utf8)) AS c
+01)Projection: concat(test.column1_utf8view, test.column2_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
+## Ensure no casts for CONCAT LargeUtf8
+query TT
+EXPLAIN SELECT
+ concat(column1_large_utf8, column2_large_utf8) as c
+FROM test;
+----
+logical_plan
+01)Projection: concat(test.column1_large_utf8, test.column2_large_utf8) AS c
+02)--TableScan: test projection=[column1_large_utf8, column2_large_utf8]
+
## Ensure no casts for CONCAT_WS
## TODO https://github.com/apache/datafusion/issues/11837
query TT
@@ -863,6 +872,61 @@ XIANGPENG
RAPHAEL
NULL
+## Should run CONCAT successfully with utf8view
+query T
+SELECT
+ concat(column1_utf8view, column2_utf8view) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8
+query T
+SELECT
+ concat(column1_utf8, column2_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8 and utf8view
+query T
+SELECT
+ concat(column1_utf8view, column2_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
+## Should run CONCAT successfully with utf8 utf8view and largeutf8
+query T
+SELECT
+ concat(column1_utf8view, column2_utf8, column2_large_utf8) as c
+FROM test;
+----
+AndrewXX
+XiangpengXiangpengXiangpeng
+RaphaelRR
+RR
+
+## Should run CONCAT successfully with utf8large
+query T
+SELECT
+ concat(column1_large_utf8, column2_large_utf8) as c
+FROM test;
+----
+AndrewX
+XiangpengXiangpeng
+RaphaelR
+R
+
## Ensure no casts for LPAD
query TT
EXPLAIN SELECT
@@ -1307,3 +1371,4 @@ select column2|| ' ' ||column3 from temp;
----
rust fast
datafusion cool
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]