This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 88f654819 Use code points instead of grapheme clusters for string
functions (#3054)
88f654819 is described below
commit 88f65481906a99fba4f5a325756827ba62ff6a2e
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Aug 8 15:06:00 2022 +0200
Use code points instead of grapheme clusters for string functions (#3054)
---
datafusion/core/tests/sql/unicode.rs | 2 +-
datafusion/physical-expr/src/functions.rs | 4 +-
.../physical-expr/src/unicode_expressions.rs | 136 ++++++---------------
integration-tests/sqls/character_length.sql | 17 +++
integration-tests/test_psql_parity.py | 2 +-
5 files changed, 59 insertions(+), 102 deletions(-)
diff --git a/datafusion/core/tests/sql/unicode.rs
b/datafusion/core/tests/sql/unicode.rs
index b9c9cbd1c..fbba4a328 100644
--- a/datafusion/core/tests/sql/unicode.rs
+++ b/datafusion/core/tests/sql/unicode.rs
@@ -61,7 +61,7 @@ async fn test_unicode_expressions() -> Result<()> {
test_expression!("lpad(NULL, 0)", "NULL");
test_expression!("lpad(NULL, 5, 'xy')", "NULL");
test_expression!("reverse('abcde')", "edcba");
- test_expression!("reverse('loẅks')", "skẅol");
+ test_expression!("reverse('loẅks')", "sk̈wol"); // Compatible with
PostgreSQL
test_expression!("reverse(NULL)", "NULL");
test_expression!("right('abcde', -2)", "cde");
test_expression!("right('abcde', -200)", "");
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index 72fcce4f9..913a2c384 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -1645,7 +1645,7 @@ mod tests {
test_function!(
Reverse,
&[lit("loẅks")],
- Ok(Some("skẅol")),
+ Ok(Some("sk̈wol")),
&str,
Utf8,
StringArray
@@ -1654,7 +1654,7 @@ mod tests {
test_function!(
Reverse,
&[lit("loẅks")],
- Ok(Some("skẅol")),
+ Ok(Some("sk̈wol")),
&str,
Utf8,
StringArray
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs
b/datafusion/physical-expr/src/unicode_expressions.rs
index 0730d24f5..5ef7029e7 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -27,9 +27,9 @@ use arrow::{
};
use datafusion_common::{DataFusionError, Result};
use hashbrown::HashMap;
-use std::any::type_name;
use std::cmp::Ordering;
use std::sync::Arc;
+use std::{any::type_name, cmp::max};
use unicode_segmentation::UnicodeSegmentation;
macro_rules! downcast_string_arg {
@@ -60,6 +60,7 @@ macro_rules! downcast_arg {
/// Returns number of characters in the string.
/// character_length('josé') = 4
+/// The implementation counts UTF-8 code points to count the number of
characters
pub fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) ->
Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
@@ -75,9 +76,8 @@ where
.iter()
.map(|string| {
string.map(|string: &str| {
- T::Native::from_usize(string.graphemes(true).count()).expect(
- "should not fail as graphemes.count will always return
integer",
- )
+ T::Native::from_usize(string.chars().count())
+ .expect("should not fail as string.chars will always
return integer")
})
})
.collect::<PrimitiveArray<T>>();
@@ -87,6 +87,7 @@ where
/// Returns first n characters in the string, or when n is negative, returns
all but last |n| characters.
/// left('abcde', 2) = 'ab'
+/// The implementation uses UTF-8 code points as characters
pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let n_array = downcast_arg!(args[1], "n", Int64Array);
@@ -96,19 +97,16 @@ pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
.map(|(string, n)| match (string, n) {
(Some(string), Some(n)) => match n.cmp(&0) {
Ordering::Less => {
- let graphemes = string.graphemes(true);
- let len = graphemes.clone().count() as i64;
- match n.abs().cmp(&len) {
- Ordering::Less => {
- Some(graphemes.take((len + n) as
usize).collect::<String>())
- }
- Ordering::Equal => Some("".to_string()),
- Ordering::Greater => Some("".to_string()),
- }
+ let len = string.chars().count() as i64;
+ Some(if n.abs() < len {
+ string.chars().take((len + n) as
usize).collect::<String>()
+ } else {
+ "".to_string()
+ })
}
Ordering::Equal => Some("".to_string()),
Ordering::Greater => {
- Some(string.graphemes(true).take(n as
usize).collect::<String>())
+ Some(string.chars().take(n as usize).collect::<String>())
}
},
_ => None,
@@ -139,11 +137,8 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
if length < graphemes.len() {
Some(graphemes[..length].concat())
} else {
- let mut s = string.to_string();
- s.insert_str(
- 0,
- " ".repeat(length -
graphemes.len()).as_str(),
- );
+ let mut s: String = " ".repeat(length -
graphemes.len());
+ s.push_str(string);
Some(s)
}
}
@@ -209,14 +204,13 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
/// Reverses the order of the characters in the string.
/// reverse('abcde') = 'edcba'
+/// The implementation uses UTF-8 code points as characters
pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let result = string_array
.iter()
- .map(|string| {
- string.map(|string: &str|
string.graphemes(true).rev().collect::<String>())
- })
+ .map(|string| string.map(|string: &str|
string.chars().rev().collect::<String>()))
.collect::<GenericStringArray<T>>();
Ok(Arc::new(result) as ArrayRef)
@@ -224,6 +218,7 @@ pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
/// Returns last n characters in the string, or when n is negative, returns
all but first |n| characters.
/// right('abcde', 2) = 'de'
+/// The implementation uses UTF-8 code points as characters
pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let n_array = downcast_arg!(args[1], "n", Int64Array);
@@ -233,33 +228,17 @@ pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
.zip(n_array.iter())
.map(|(string, n)| match (string, n) {
(Some(string), Some(n)) => match n.cmp(&0) {
- Ordering::Less => {
- let graphemes = string.graphemes(true).rev();
- let len = graphemes.clone().count() as i64;
- match n.abs().cmp(&len) {
- Ordering::Less => Some(
- graphemes
- .take((len + n) as usize)
- .collect::<Vec<&str>>()
- .iter()
- .rev()
- .copied()
- .collect::<String>(),
- ),
- Ordering::Equal => Some("".to_string()),
- Ordering::Greater => Some("".to_string()),
- }
- }
+ Ordering::Less => Some(
+ string
+ .chars()
+ .skip(n.unsigned_abs() as usize)
+ .collect::<String>(),
+ ),
Ordering::Equal => Some("".to_string()),
Ordering::Greater => Some(
string
- .graphemes(true)
- .rev()
- .take(n as usize)
- .collect::<Vec<&str>>()
- .iter()
- .rev()
- .copied()
+ .chars()
+ .skip(max(string.chars().count() as i64 - n, 0) as
usize)
.collect::<String>(),
),
},
@@ -349,6 +328,7 @@ pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
/// Returns starting index of specified substring within string, or zero if
it's not present. (Same as position(substring in string), but note the reversed
argument order.)
/// strpos('high', 'ig') = 2
+/// The implementation uses UTF-8 code points as characters
pub fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
@@ -374,28 +354,13 @@ where
.zip(substring_array.iter())
.map(|(string, substring)| match (string, substring) {
(Some(string), Some(substring)) => {
- // the rfind method returns the byte index of the substring
which may or may not be the same as the character index due to UTF8 encoding
- // this method first finds the matching byte using rfind
- // then maps that to the character index by matching on the
grapheme_index of the byte_index
- Some(
-
T::Native::from_usize(string.to_string().rfind(substring).map_or(
- 0,
- |byte_offset| {
- string
- .grapheme_indices(true)
- .collect::<Vec<(usize, &str)>>()
- .iter()
- .enumerate()
- .filter(|(_, (offset, _))| *offset ==
byte_offset)
- .map(|(index, _)| index)
- .collect::<Vec<usize>>()
- .first()
- .expect("should not fail as grapheme_indices
and byte offsets are tightly coupled")
- .to_owned()
- + 1
- },
- ))
- .expect("should not fail due to map_or default value")
+ // the find method returns the byte index of the substring
+ // Next, we count the number of the chars until that byte
+ T::Native::from_usize(
+ string
+ .find(substring)
+ .map(|x| string[..x].chars().count() + 1)
+ .unwrap_or(0),
)
}
_ => None,
@@ -408,6 +373,7 @@ where
/// Extracts the substring of string starting at the start'th character, and
extending for count characters if that is specified. (Same as substring(string
from start for count).)
/// substr('alphabet', 3) = 'phabet'
/// substr('alphabet', 3, 2) = 'ph'
+/// The implementation uses UTF-8 code points as characters
pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
@@ -422,13 +388,7 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
if start <= 0 {
Some(string.to_string())
} else {
- let graphemes =
string.graphemes(true).collect::<Vec<&str>>();
- let start_pos = start as usize - 1;
- if graphemes.len() < start_pos {
- Some("".to_string())
- } else {
- Some(graphemes[start_pos..].concat())
- }
+ Some(string.chars().skip(start as usize -
1).collect())
}
}
_ => None,
@@ -455,29 +415,9 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) ->
Result<ArrayRef> {
count
)))
} else {
- let graphemes =
string.graphemes(true).collect::<Vec<&str>>();
- let (start_pos, end_pos) = if start <= 0 {
- let end_pos = start + count - 1;
- (
- 0_usize,
- if end_pos < 0 {
- // we use 0 as workaround for usize to
return empty string
- 0
- } else {
- end_pos as usize
- },
- )
- } else {
- ((start - 1) as usize, (start + count - 1) as
usize)
- };
-
- if end_pos == 0 || graphemes.len() < start_pos {
- Ok(Some("".to_string()))
- } else if graphemes.len() < end_pos {
- Ok(Some(graphemes[start_pos..].concat()))
- } else {
-
Ok(Some(graphemes[start_pos..end_pos].concat()))
- }
+ let skip = max(0, start - 1);
+ let count = max(0, count + (if start < 1 {start -
1} else {0}));
+ Ok(Some(string.chars().skip(skip as
usize).take(count as usize).collect::<String>()))
}
}
_ => Ok(None),
diff --git a/integration-tests/sqls/character_length.sql
b/integration-tests/sqls/character_length.sql
new file mode 100644
index 000000000..18ef487f1
--- /dev/null
+++ b/integration-tests/sqls/character_length.sql
@@ -0,0 +1,17 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+select length('ä');
diff --git a/integration-tests/test_psql_parity.py
b/integration-tests/test_psql_parity.py
index 506100bbc..3629645a6 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -82,7 +82,7 @@ test_files = set(root.glob("*.sql"))
class TestPsqlParity:
def test_tests_count(self):
- assert len(test_files) == 21, "tests are missed"
+ assert len(test_files) == 22, "tests are missed"
@pytest.mark.parametrize("fname", test_files)
def test_sql_file(self, fname):