This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 88f654819 Use code points instead of grapheme clusters for string functions (#3054)
88f654819 is described below

commit 88f65481906a99fba4f5a325756827ba62ff6a2e
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Aug 8 15:06:00 2022 +0200

    Use code points instead of grapheme clusters for string functions (#3054)
---
 datafusion/core/tests/sql/unicode.rs               |   2 +-
 datafusion/physical-expr/src/functions.rs          |   4 +-
 .../physical-expr/src/unicode_expressions.rs       | 136 ++++++---------------
 integration-tests/sqls/character_length.sql        |  17 +++
 integration-tests/test_psql_parity.py              |   2 +-
 5 files changed, 59 insertions(+), 102 deletions(-)

diff --git a/datafusion/core/tests/sql/unicode.rs b/datafusion/core/tests/sql/unicode.rs
index b9c9cbd1c..fbba4a328 100644
--- a/datafusion/core/tests/sql/unicode.rs
+++ b/datafusion/core/tests/sql/unicode.rs
@@ -61,7 +61,7 @@ async fn test_unicode_expressions() -> Result<()> {
     test_expression!("lpad(NULL, 0)", "NULL");
     test_expression!("lpad(NULL, 5, 'xy')", "NULL");
     test_expression!("reverse('abcde')", "edcba");
-    test_expression!("reverse('loẅks')", "skẅol");
+    test_expression!("reverse('loẅks')", "sk̈wol"); // Compatible with 
PostgreSQL
     test_expression!("reverse(NULL)", "NULL");
     test_expression!("right('abcde', -2)", "cde");
     test_expression!("right('abcde', -200)", "");
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index 72fcce4f9..913a2c384 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -1645,7 +1645,7 @@ mod tests {
         test_function!(
             Reverse,
             &[lit("loẅks")],
-            Ok(Some("skẅol")),
+            Ok(Some("sk̈wol")),
             &str,
             Utf8,
             StringArray
@@ -1654,7 +1654,7 @@ mod tests {
         test_function!(
             Reverse,
             &[lit("loẅks")],
-            Ok(Some("skẅol")),
+            Ok(Some("sk̈wol")),
             &str,
             Utf8,
             StringArray
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs 
b/datafusion/physical-expr/src/unicode_expressions.rs
index 0730d24f5..5ef7029e7 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -27,9 +27,9 @@ use arrow::{
 };
 use datafusion_common::{DataFusionError, Result};
 use hashbrown::HashMap;
-use std::any::type_name;
 use std::cmp::Ordering;
 use std::sync::Arc;
+use std::{any::type_name, cmp::max};
 use unicode_segmentation::UnicodeSegmentation;
 
 macro_rules! downcast_string_arg {
@@ -60,6 +60,7 @@ macro_rules! downcast_arg {
 
 /// Returns number of characters in the string.
 /// character_length('josé') = 4
+/// The implementation counts UTF-8 code points to count the number of characters
 pub fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> 
Result<ArrayRef>
 where
     T::Native: OffsetSizeTrait,
@@ -75,9 +76,8 @@ where
         .iter()
         .map(|string| {
             string.map(|string: &str| {
-                T::Native::from_usize(string.graphemes(true).count()).expect(
-                    "should not fail as graphemes.count will always return 
integer",
-                )
+                T::Native::from_usize(string.chars().count())
+                    .expect("should not fail as string.chars will always 
return integer")
             })
         })
         .collect::<PrimitiveArray<T>>();
@@ -87,6 +87,7 @@ where
 
 /// Returns first n characters in the string, or when n is negative, returns 
all but last |n| characters.
 /// left('abcde', 2) = 'ab'
+/// The implementation uses UTF-8 code points as characters
 pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
     let n_array = downcast_arg!(args[1], "n", Int64Array);
@@ -96,19 +97,16 @@ pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
         .map(|(string, n)| match (string, n) {
             (Some(string), Some(n)) => match n.cmp(&0) {
                 Ordering::Less => {
-                    let graphemes = string.graphemes(true);
-                    let len = graphemes.clone().count() as i64;
-                    match n.abs().cmp(&len) {
-                        Ordering::Less => {
-                            Some(graphemes.take((len + n) as 
usize).collect::<String>())
-                        }
-                        Ordering::Equal => Some("".to_string()),
-                        Ordering::Greater => Some("".to_string()),
-                    }
+                    let len = string.chars().count() as i64;
+                    Some(if n.abs() < len {
+                        string.chars().take((len + n) as 
usize).collect::<String>()
+                    } else {
+                        "".to_string()
+                    })
                 }
                 Ordering::Equal => Some("".to_string()),
                 Ordering::Greater => {
-                    Some(string.graphemes(true).take(n as 
usize).collect::<String>())
+                    Some(string.chars().take(n as usize).collect::<String>())
                 }
             },
             _ => None,
@@ -139,11 +137,8 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
                             if length < graphemes.len() {
                                 Some(graphemes[..length].concat())
                             } else {
-                                let mut s = string.to_string();
-                                s.insert_str(
-                                    0,
-                                    " ".repeat(length - 
graphemes.len()).as_str(),
-                                );
+                                let mut s: String = " ".repeat(length - 
graphemes.len());
+                                s.push_str(string);
                                 Some(s)
                             }
                         }
@@ -209,14 +204,13 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
 /// Reverses the order of the characters in the string.
 /// reverse('abcde') = 'edcba'
+/// The implementation uses UTF-8 code points as characters
 pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
 
     let result = string_array
         .iter()
-        .map(|string| {
-            string.map(|string: &str| 
string.graphemes(true).rev().collect::<String>())
-        })
+        .map(|string| string.map(|string: &str| 
string.chars().rev().collect::<String>()))
         .collect::<GenericStringArray<T>>();
 
     Ok(Arc::new(result) as ArrayRef)
@@ -224,6 +218,7 @@ pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
 /// Returns last n characters in the string, or when n is negative, returns 
all but first |n| characters.
 /// right('abcde', 2) = 'de'
+/// The implementation uses UTF-8 code points as characters
 pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     let string_array = downcast_string_arg!(args[0], "string", T);
     let n_array = downcast_arg!(args[1], "n", Int64Array);
@@ -233,33 +228,17 @@ pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
         .zip(n_array.iter())
         .map(|(string, n)| match (string, n) {
             (Some(string), Some(n)) => match n.cmp(&0) {
-                Ordering::Less => {
-                    let graphemes = string.graphemes(true).rev();
-                    let len = graphemes.clone().count() as i64;
-                    match n.abs().cmp(&len) {
-                        Ordering::Less => Some(
-                            graphemes
-                                .take((len + n) as usize)
-                                .collect::<Vec<&str>>()
-                                .iter()
-                                .rev()
-                                .copied()
-                                .collect::<String>(),
-                        ),
-                        Ordering::Equal => Some("".to_string()),
-                        Ordering::Greater => Some("".to_string()),
-                    }
-                }
+                Ordering::Less => Some(
+                    string
+                        .chars()
+                        .skip(n.unsigned_abs() as usize)
+                        .collect::<String>(),
+                ),
                 Ordering::Equal => Some("".to_string()),
                 Ordering::Greater => Some(
                     string
-                        .graphemes(true)
-                        .rev()
-                        .take(n as usize)
-                        .collect::<Vec<&str>>()
-                        .iter()
-                        .rev()
-                        .copied()
+                        .chars()
+                        .skip(max(string.chars().count() as i64 - n, 0) as 
usize)
                         .collect::<String>(),
                 ),
             },
@@ -349,6 +328,7 @@ pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
 
 /// Returns starting index of specified substring within string, or zero if 
it's not present. (Same as position(substring in string), but note the reversed 
argument order.)
 /// strpos('high', 'ig') = 2
+/// The implementation uses UTF-8 code points as characters
 pub fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
 where
     T::Native: OffsetSizeTrait,
@@ -374,28 +354,13 @@ where
         .zip(substring_array.iter())
         .map(|(string, substring)| match (string, substring) {
             (Some(string), Some(substring)) => {
-                // the rfind method returns the byte index of the substring 
which may or may not be the same as the character index due to UTF8 encoding
-                // this method first finds the matching byte using rfind
-                // then maps that to the character index by matching on the 
grapheme_index of the byte_index
-                Some(
-                    
T::Native::from_usize(string.to_string().rfind(substring).map_or(
-                        0,
-                        |byte_offset| {
-                            string
-                                .grapheme_indices(true)
-                                .collect::<Vec<(usize, &str)>>()
-                                .iter()
-                                .enumerate()
-                                .filter(|(_, (offset, _))| *offset == 
byte_offset)
-                                .map(|(index, _)| index)
-                                .collect::<Vec<usize>>()
-                                .first()
-                                .expect("should not fail as grapheme_indices 
and byte offsets are tightly coupled")
-                                .to_owned()
-                                + 1
-                        },
-                    ))
-                    .expect("should not fail due to map_or default value")
+                // the find method returns the byte index of the substring
+                // Next, we count the number of the chars until that byte
+                T::Native::from_usize(
+                    string
+                        .find(substring)
+                        .map(|x| string[..x].chars().count() + 1)
+                        .unwrap_or(0),
                 )
             }
             _ => None,
@@ -408,6 +373,7 @@ where
 /// Extracts the substring of string starting at the start'th character, and 
extending for count characters if that is specified. (Same as substring(string 
from start for count).)
 /// substr('alphabet', 3) = 'phabet'
 /// substr('alphabet', 3, 2) = 'ph'
+/// The implementation uses UTF-8 code points as characters
 pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args.len() {
         2 => {
@@ -422,13 +388,7 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
                         if start <= 0 {
                             Some(string.to_string())
                         } else {
-                            let graphemes = 
string.graphemes(true).collect::<Vec<&str>>();
-                            let start_pos = start as usize - 1;
-                            if graphemes.len() < start_pos {
-                                Some("".to_string())
-                            } else {
-                                Some(graphemes[start_pos..].concat())
-                            }
+                            Some(string.chars().skip(start as usize - 
1).collect())
                         }
                     }
                     _ => None,
@@ -455,29 +415,9 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> 
Result<ArrayRef> {
                                 count
                             )))
                         } else {
-                            let graphemes = 
string.graphemes(true).collect::<Vec<&str>>();
-                            let (start_pos, end_pos) = if start <= 0 {
-                                let end_pos = start + count - 1;
-                                (
-                                    0_usize,
-                                    if end_pos < 0 {
-                                        // we use 0 as workaround for usize to 
return empty string
-                                        0
-                                    } else {
-                                        end_pos as usize
-                                    },
-                                )
-                            } else {
-                                ((start - 1) as usize, (start + count - 1) as 
usize)
-                            };
-
-                            if end_pos == 0 || graphemes.len() < start_pos {
-                                Ok(Some("".to_string()))
-                            } else if graphemes.len() < end_pos {
-                                Ok(Some(graphemes[start_pos..].concat()))
-                            } else {
-                                
Ok(Some(graphemes[start_pos..end_pos].concat()))
-                            }
+                            let skip = max(0, start - 1);
+                            let count = max(0, count + (if start < 1 {start - 
1} else {0}));
+                            Ok(Some(string.chars().skip(skip as 
usize).take(count as usize).collect::<String>()))
                         }
                     }
                     _ => Ok(None),
diff --git a/integration-tests/sqls/character_length.sql 
b/integration-tests/sqls/character_length.sql
new file mode 100644
index 000000000..18ef487f1
--- /dev/null
+++ b/integration-tests/sqls/character_length.sql
@@ -0,0 +1,17 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements.  See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership.  The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License.  You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+select length('ä');
diff --git a/integration-tests/test_psql_parity.py 
b/integration-tests/test_psql_parity.py
index 506100bbc..3629645a6 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -82,7 +82,7 @@ test_files = set(root.glob("*.sql"))
 
 class TestPsqlParity:
     def test_tests_count(self):
-        assert len(test_files) == 21, "tests are missed"
+        assert len(test_files) == 22, "tests are missed"
 
     @pytest.mark.parametrize("fname", test_files)
     def test_sql_file(self, fname):

Reply via email to