This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ec7c9ab0af chore: Add `substr()` benchmarks, refactor (#20803)
ec7c9ab0af is described below
commit ec7c9ab0af56917470c83694e9fb0efbd76c7d91
Author: Neil Conway <[email protected]>
AuthorDate: Mon Mar 23 12:04:14 2026 -0400
chore: Add `substr()` benchmarks, refactor (#20803)
## Which issue does this PR close?
N/A
## Rationale for this change
I'd like to optimize `substr` for scalar `start`/`count` inputs, but the
code would benefit from some refactoring and cleanup first. I also added
benchmarks for `substr` with scalar args.
## What changes are included in this PR?
- Refactor `string_view_substr` and `string_substr` to use a single loop
- Change `get_true_start_end` to validate its own inputs, and clean up its UTF-8
(non-ASCII) path
- Add benchmark cases for scalar `start` and `count` arguments
- Improve docs
## Are these changes tested?
Yes.
## Are there any user-facing changes?
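No, other than error message wording changes. One edge case also changes
behavior: `substr(s, start)` with `start = i64::MIN` (no count) now returns a
"start position overflow" error instead of the whole string.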
## AI usage
Multiple AI tools were used to iterate on this PR. I have reviewed and
understand the resulting code.
---
datafusion/functions/benches/substr.rs | 193 ++++++++-----
datafusion/functions/src/unicode/substr.rs | 309 +++++++--------------
datafusion/spark/src/function/string/substring.rs | 8 +-
.../test_files/string/string_literal.slt | 4 +-
docs/source/user-guide/sql/scalar_functions.md | 2 +-
5 files changed, 240 insertions(+), 276 deletions(-)
diff --git a/datafusion/functions/benches/substr.rs
b/datafusion/functions/benches/substr.rs
index 37a1e178f5..4ea90c3708 100644
--- a/datafusion/functions/benches/substr.rs
+++ b/datafusion/functions/benches/substr.rs
@@ -21,47 +21,42 @@ use arrow::util::bench_util::{
create_string_array_with_len, create_string_view_array_with_len,
};
use criterion::{Criterion, SamplingMode, criterion_group, criterion_main};
-use datafusion_common::DataFusionError;
use datafusion_common::config::ConfigOptions;
+use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use datafusion_functions::unicode;
use std::hint::black_box;
use std::sync::Arc;
+fn make_i64_arg(value: i64, size: usize, as_scalar: bool) -> ColumnarValue {
+ if as_scalar {
+ ColumnarValue::Scalar(ScalarValue::from(value))
+ } else {
+ ColumnarValue::Array(Arc::new(Int64Array::from(vec![value; size])))
+ }
+}
+
fn create_args_without_count<O: OffsetSizeTrait>(
size: usize,
str_len: usize,
start_half_way: bool,
force_view_types: bool,
+ scalar_start: bool,
) -> Vec<ColumnarValue> {
- let start_array = Arc::new(Int64Array::from(
- (0..size)
- .map(|_| {
- if start_half_way {
- (str_len / 2) as i64
- } else {
- 1i64
- }
- })
- .collect::<Vec<_>>(),
- ));
-
- if force_view_types {
- let string_array =
- Arc::new(create_string_view_array_with_len(size, 0.1, str_len,
false));
- vec![
- ColumnarValue::Array(string_array),
- ColumnarValue::Array(start_array),
- ]
+ let start_val = if start_half_way {
+ (str_len / 2) as i64
} else {
- let string_array =
- Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
+ 1i64
+ };
+ let start = make_i64_arg(start_val, size, scalar_start);
- vec![
- ColumnarValue::Array(string_array),
- ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef),
- ]
- }
+ let string_array: ArrayRef = if force_view_types {
+ Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false))
+ } else {
+ Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len))
+ };
+
+ vec![ColumnarValue::Array(string_array), start]
}
fn create_args_with_count<O: OffsetSizeTrait>(
@@ -69,32 +64,19 @@ fn create_args_with_count<O: OffsetSizeTrait>(
str_len: usize,
count_max: usize,
force_view_types: bool,
+ scalar_args: bool,
) -> Vec<ColumnarValue> {
- let start_array =
- Arc::new(Int64Array::from((0..size).map(|_| 1).collect::<Vec<_>>()));
let count = count_max.min(str_len) as i64;
- let count_array = Arc::new(Int64Array::from(
- (0..size).map(|_| count).collect::<Vec<_>>(),
- ));
-
- if force_view_types {
- let string_array =
- Arc::new(create_string_view_array_with_len(size, 0.1, str_len,
false));
- vec![
- ColumnarValue::Array(string_array),
- ColumnarValue::Array(start_array),
- ColumnarValue::Array(count_array),
- ]
+ let start = make_i64_arg(1i64, size, scalar_args);
+ let count = make_i64_arg(count, size, scalar_args);
+
+ let string_array: ArrayRef = if force_view_types {
+ Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false))
} else {
- let string_array =
- Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len));
-
- vec![
- ColumnarValue::Array(string_array),
- ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef),
- ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef),
- ]
- }
+ Arc::new(create_string_array_with_len::<O>(size, 0.1, str_len))
+ };
+
+ vec![ColumnarValue::Array(string_array), start, count]
}
#[expect(clippy::needless_pass_by_value)]
@@ -122,22 +104,22 @@ fn criterion_benchmark(c: &mut Criterion) {
for size in [1024, 4096] {
// string_len = 12, substring_len=6 (see `create_args_without_count`)
let len = 12;
- let mut group = c.benchmark_group("SHORTER THAN 12");
+ let mut group = c.benchmark_group("substr, no count, short strings");
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
- let args = create_args_without_count::<i32>(size, len, true, true);
+ let args = create_args_without_count::<i32>(size, len, true, true,
false);
group.bench_function(
format!("substr_string_view [size={size}, strlen={len}]"),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
- let args = create_args_without_count::<i32>(size, len, false, false);
+ let args = create_args_without_count::<i32>(size, len, false, false,
false);
group.bench_function(format!("substr_string [size={size},
strlen={len}]"), |b| {
b.iter(|| black_box(invoke_substr_with_args(args.clone(), size)))
});
- let args = create_args_without_count::<i64>(size, len, true, false);
+ let args = create_args_without_count::<i64>(size, len, true, false,
false);
group.bench_function(
format!("substr_large_string [size={size}, strlen={len}]"),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
@@ -148,23 +130,23 @@ fn criterion_benchmark(c: &mut Criterion) {
// string_len = 128, start=1, count=64, substring_len=64
let len = 128;
let count = 64;
- let mut group = c.benchmark_group("LONGER THAN 12");
+ let mut group = c.benchmark_group("substr, with count, long strings");
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
- let args = create_args_with_count::<i32>(size, len, count, true);
+ let args = create_args_with_count::<i32>(size, len, count, true,
false);
group.bench_function(
format!("substr_string_view [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
- let args = create_args_with_count::<i32>(size, len, count, false);
+ let args = create_args_with_count::<i32>(size, len, count, false,
false);
group.bench_function(
format!("substr_string [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
- let args = create_args_with_count::<i64>(size, len, count, false);
+ let args = create_args_with_count::<i64>(size, len, count, false,
false);
group.bench_function(
format!("substr_large_string [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
@@ -175,29 +157,116 @@ fn criterion_benchmark(c: &mut Criterion) {
// string_len = 128, start=1, count=6, substring_len=6
let len = 128;
let count = 6;
- let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12");
+ let mut group = c.benchmark_group("substr, short count, long strings");
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);
- let args = create_args_with_count::<i32>(size, len, count, true);
+ let args = create_args_with_count::<i32>(size, len, count, true,
false);
group.bench_function(
format!("substr_string_view [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
- let args = create_args_with_count::<i32>(size, len, count, false);
+ let args = create_args_with_count::<i32>(size, len, count, false,
false);
group.bench_function(
format!("substr_string [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
- let args = create_args_with_count::<i64>(size, len, count, false);
+ let args = create_args_with_count::<i64>(size, len, count, false,
false);
group.bench_function(
format!("substr_large_string [size={size}, count={count},
strlen={len}]",),
|b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
);
group.finish();
+
+ // Scalar start, no count, short strings
+ let len = 12;
+ let mut group =
+ c.benchmark_group("substr, scalar start, no count, short strings");
+ group.sampling_mode(SamplingMode::Flat);
+ group.sample_size(10);
+
+ let args = create_args_without_count::<i32>(size, len, true, true,
true);
+ group.bench_function(
+ format!("substr_string_view [size={size}, strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ let args = create_args_without_count::<i32>(size, len, false, false,
true);
+ group.bench_function(format!("substr_string [size={size},
strlen={len}]"), |b| {
+ b.iter(|| black_box(invoke_substr_with_args(args.clone(), size)))
+ });
+
+ group.finish();
+
+ // Scalar start, no count, long strings
+ let len = 128;
+ let mut group = c.benchmark_group("substr, scalar start, no count,
long strings");
+ group.sampling_mode(SamplingMode::Flat);
+ group.sample_size(10);
+
+ let args = create_args_without_count::<i32>(size, len, true, true,
true);
+ group.bench_function(
+ format!("substr_string_view [size={size}, strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ let args = create_args_without_count::<i32>(size, len, false, false,
true);
+ group.bench_function(format!("substr_string [size={size},
strlen={len}]"), |b| {
+ b.iter(|| black_box(invoke_substr_with_args(args.clone(), size)))
+ });
+
+ group.finish();
+
+ // Scalar start and count, short strings
+ let len = 12;
+ let count = 6;
+ let mut group = c.benchmark_group("substr, scalar args, short
strings");
+ group.sampling_mode(SamplingMode::Flat);
+ group.sample_size(10);
+
+ let args = create_args_with_count::<i32>(size, len, count, true, true);
+ group.bench_function(
+ format!("substr_string_view [size={size}, count={count},
strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ let args = create_args_with_count::<i32>(size, len, count, false,
true);
+ group.bench_function(
+ format!("substr_string [size={size}, count={count},
strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ group.finish();
+
+ // Scalar start and count, long strings
+ let len = 128;
+ let count = 64;
+ let mut group = c.benchmark_group("substr, scalar args, long strings");
+ group.sampling_mode(SamplingMode::Flat);
+ group.sample_size(10);
+
+ let args = create_args_with_count::<i32>(size, len, count, true, true);
+ group.bench_function(
+ format!("substr_string_view [size={size}, count={count},
strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ let args = create_args_with_count::<i32>(size, len, count, false,
true);
+ group.bench_function(
+ format!("substr_string [size={size}, count={count},
strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ let args = create_args_with_count::<i64>(size, len, count, false,
true);
+ group.bench_function(
+ format!("substr_large_string [size={size}, count={count},
strlen={len}]"),
+ |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(),
size))),
+ );
+
+ group.finish();
}
}
diff --git a/datafusion/functions/src/unicode/substr.rs
b/datafusion/functions/src/unicode/substr.rs
index 737730cf88..53bcdbf2d9 100644
--- a/datafusion/functions/src/unicode/substr.rs
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use crate::strings::make_and_append_view;
use crate::utils::make_scalar_function;
use arrow::array::{
- Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder,
StringArrayType,
+ Array, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
StringViewArray, StringViewBuilder,
};
use arrow::buffer::ScalarBuffer;
@@ -53,7 +53,7 @@ use datafusion_macros::user_doc;
standard_argument(name = "str", prefix = "String"),
argument(
name = "start_pos",
- description = "Character position to start the substring at. The first
character in the string has a position of 1."
+ description = "Character position to start the substring at. The first
character in the string has a position of 1. If the start position is less than
1, it is treated as if it is before the start of the string and the (absolute)
number of characters before position 1 is subtracted from `length` (if given).
For example, `substr('abc', -3, 6)` returns `'ab'`."
),
argument(
name = "length",
@@ -134,10 +134,7 @@ impl ScalarUDFImpl for SubstrFunc {
}
}
-/// Extracts the substring of string starting at the start'th character, and
extending for count characters if that is specified. (Same as substring(string
from start for count).)
-/// substr('alphabet', 3) = 'phabet'
-/// substr('alphabet', 3, 2) = 'ph'
-/// The implementation uses UTF-8 code points as characters
+/// Dispatches `substr` to the appropriate string array implementation.
fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => {
@@ -159,70 +156,74 @@ fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
-// Convert the given `start` and `count` to valid byte indices within `input`
string
-//
-// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start,
count)`
-// `start` is 1-based, if `count` is not provided count to the end of the
string
-// Input indices are character-based, and return values are byte indices
-// The input bounds can be outside string bounds, this function will return
-// the intersection between input bounds and valid string bounds
-// `input_ascii_only` is used to optimize this function if `input` is
ASCII-only
-//
-// * Example
-// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
-// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
-// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
-// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
+/// Convert the given `start` and `count` to valid byte indices within `input`
string.
+///
+/// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start,
count)`.
+/// `start` is 1-based; if `count` is not provided, returns indices to the end
of the string.
+/// Input indices are character-based, and return values are byte indices.
+/// The input bounds can be outside string bounds; this function will return
+/// the intersection between input bounds and valid string bounds.
+/// `is_input_ascii_only` is used to optimize this function if `input` is
ASCII-only.
+///
+/// # Example
+/// ```text
+/// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
+/// get_true_start_end('Hi🌏', 1, None) -> Ok((0, 6))
+/// get_true_start_end('Hi🌏', 1, Some(1)) -> Ok((0, 1))
+/// get_true_start_end('Hi🌏', -10, Some(2)) -> Ok((0, 0))
+/// ```
pub fn get_true_start_end(
input: &str,
start: i64,
- count: Option<u64>,
+ count: Option<i64>,
is_input_ascii_only: bool,
-) -> (usize, usize) {
- let start = start.checked_sub(1).unwrap_or(start);
+) -> Result<(usize, usize)> {
+ if let Some(count) = count
+ && count < 0
+ {
+ return exec_err!("negative count not allowed: {count}");
+ }
+
+ // The caller-provided `start` is 1-indexed.
+ let Some(start) = start.checked_sub(1) else {
+ return exec_err!("start position overflow: {start}");
+ };
let end = match count {
- Some(count) => {
- let count_i64 = i64::try_from(count).unwrap_or(i64::MAX);
- start.saturating_add(count_i64)
- }
+ Some(count) => start.saturating_add(count),
None => input.len() as i64,
};
- let count_to_end = count.is_some();
let start = start.clamp(0, input.len() as i64) as usize;
let end = end.clamp(0, input.len() as i64) as usize;
- let count = end - start;
- // If input is ASCII-only, byte-based indices equals to char-based indices
+ // If input is ASCII-only, byte-based indices equal char-based indices
if is_input_ascii_only {
- return (start, end);
+ return Ok((start, end));
}
- // Otherwise, calculate byte indices from char indices
- // Note this decoding is relatively expensive for this simple `substr`
function,,
- // so the implementation attempts to decode in one pass (and caused the
complexity)
- let (mut st, mut ed) = (input.len(), input.len());
- let mut start_counting = false;
- let mut cnt = 0;
- for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() {
- if char_cnt == start {
- st = byte_cnt;
- if count_to_end {
- start_counting = true;
- } else {
+ // Otherwise, calculate byte indices from char indices. We initialize both
+ // `byte_start` and `byte_end` to the string length to handle cases where
+ // the requested 'start' or 'end' positions are at or beyond the end of the
+ // string (resulting in an empty substring).
+ let mut byte_start = input.len();
+ let mut byte_end = input.len();
+
+ for (char_idx, (byte_idx, _)) in input.char_indices().enumerate() {
+ if char_idx == start {
+ byte_start = byte_idx;
+ // If no length is specified, we only need the start offset.
+ if count.is_none() {
break;
}
}
- if start_counting {
- if cnt == count {
- ed = byte_cnt;
- break;
- }
- cnt += 1;
+ if char_idx == end {
+ byte_end = byte_idx;
+ break;
}
}
- (st, ed)
+
+ Ok((byte_start, byte_end))
}
// String characters are variable length encoded in UTF-8, `substr()`
function's
@@ -272,100 +273,45 @@ pub fn enable_ascii_fast_path<'a, V:
StringArrayType<'a>>(
}
}
-// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
-// From<u128> for ByteView
fn string_view_substr(
string_view_array: &StringViewArray,
args: &[ArrayRef],
) -> Result<ArrayRef> {
- let mut views_buf = Vec::with_capacity(string_view_array.len());
- let mut null_builder = NullBufferBuilder::new(string_view_array.len());
-
let start_array = as_int64_array(&args[0])?;
- let count_array_opt = if args.len() == 2 {
- Some(as_int64_array(&args[1])?)
- } else {
- None
- };
+ let count_array_opt = args.get(1).map(|a| as_int64_array(a)).transpose()?;
let enable_ascii_fast_path =
enable_ascii_fast_path(&string_view_array, start_array,
count_array_opt);
- // In either case of `substr(s, i)` or `substr(s, i, cnt)`
- // If any of input argument is `NULL`, the result is `NULL`
- match args.len() {
- 1 => {
- for ((str_opt, raw_view), start_opt) in string_view_array
- .iter()
- .zip(string_view_array.views().iter())
- .zip(start_array.iter())
- {
- if let (Some(str), Some(start)) = (str_opt, start_opt) {
- let (start, end) =
- get_true_start_end(str, start, None,
enable_ascii_fast_path);
- let substr = &str[start..end];
-
- make_and_append_view(
- &mut views_buf,
- &mut null_builder,
- raw_view,
- substr,
- start as u32,
- );
- } else {
- null_builder.append_null();
- views_buf.push(0);
- }
- }
- }
- 2 => {
- let count_array = count_array_opt.unwrap();
- for (((str_opt, raw_view), start_opt), count_opt) in
string_view_array
- .iter()
- .zip(string_view_array.views().iter())
- .zip(start_array.iter())
- .zip(count_array.iter())
- {
- if let (Some(str), Some(start), Some(count)) =
- (str_opt, start_opt, count_opt)
- {
- if count < 0 {
- return exec_err!(
- "negative substring length not allowed:
substr(<str>, {start}, {count})"
- );
- } else {
- if start == i64::MIN {
- return exec_err!(
- "negative overflow when calculating skip value"
- );
- }
- let (start, end) = get_true_start_end(
- str,
- start,
- Some(count as u64),
- enable_ascii_fast_path,
- );
- let substr = &str[start..end];
-
- make_and_append_view(
- &mut views_buf,
- &mut null_builder,
- raw_view,
- substr,
- start as u32,
- );
- }
- } else {
- null_builder.append_null();
- views_buf.push(0);
- }
- }
- }
- other => {
- return exec_err!(
- "substr was called with {other} arguments. It requires 2 or 3."
- );
+ let mut views_buf = Vec::with_capacity(string_view_array.len());
+ let mut null_builder = NullBufferBuilder::new(string_view_array.len());
+
+ for i in 0..string_view_array.len() {
+ if string_view_array.is_null(i)
+ || start_array.is_null(i)
+ || count_array_opt.map(|a| a.is_null(i)).unwrap_or(false)
+ {
+ null_builder.append_null();
+ views_buf.push(0);
+ continue;
}
+
+ let string = string_view_array.value(i);
+ let start = start_array.value(i);
+ let count = count_array_opt.map(|a| a.value(i));
+ let raw_view = string_view_array.views()[i];
+
+ let (start, end) =
+ get_true_start_end(string, start, count, enable_ascii_fast_path)?;
+ let substr = &string[start..end];
+
+ make_and_append_view(
+ &mut views_buf,
+ &mut null_builder,
+ &raw_view,
+ substr,
+ start as u32,
+ );
}
let views_buf = ScalarBuffer::from(views_buf);
@@ -387,82 +333,35 @@ fn string_view_substr(
fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
where
- V: StringArrayType<'a>,
+ V: StringArrayType<'a> + Copy,
{
let start_array = as_int64_array(&args[0])?;
- let count_array_opt = if args.len() == 2 {
- Some(as_int64_array(&args[1])?)
- } else {
- None
- };
+ let count_array_opt = args.get(1).map(|a| as_int64_array(a)).transpose()?;
let enable_ascii_fast_path =
enable_ascii_fast_path(&string_array, start_array, count_array_opt);
- match args.len() {
- 1 => {
- let iter = ArrayIter::new(string_array);
- let mut result_builder = StringViewBuilder::new();
- for (string, start) in iter.zip(start_array.iter()) {
- match (string, start) {
- (Some(string), Some(start)) => {
- let (start, end) = get_true_start_end(
- string,
- start,
- None,
- enable_ascii_fast_path,
- ); // start, end is byte-based
- let substr = &string[start..end];
- result_builder.append_value(substr);
- }
- _ => {
- result_builder.append_null();
- }
- }
- }
- Ok(Arc::new(result_builder.finish()) as ArrayRef)
- }
- 2 => {
- let iter = ArrayIter::new(string_array);
- let count_array = count_array_opt.unwrap();
- let mut result_builder = StringViewBuilder::new();
-
- for ((string, start), count) in
- iter.zip(start_array.iter()).zip(count_array.iter())
- {
- match (string, start, count) {
- (Some(string), Some(start), Some(count)) => {
- if count < 0 {
- return exec_err!(
- "negative substring length not allowed:
substr(<str>, {start}, {count})"
- );
- } else {
- if start == i64::MIN {
- return exec_err!(
- "negative overflow when calculating skip
value"
- );
- }
- let (start, end) = get_true_start_end(
- string,
- start,
- Some(count as u64),
- enable_ascii_fast_path,
- ); // start, end is byte-based
- let substr = &string[start..end];
- result_builder.append_value(substr);
- }
- }
- _ => {
- result_builder.append_null();
- }
- }
- }
- Ok(Arc::new(result_builder.finish()) as ArrayRef)
- }
- other => {
- exec_err!("substr was called with {other} arguments. It requires 2
or 3.")
+ let mut result_builder = StringViewBuilder::new();
+
+ for i in 0..string_array.len() {
+ if string_array.is_null(i)
+ || start_array.is_null(i)
+ || count_array_opt.map(|a| a.is_null(i)).unwrap_or(false)
+ {
+ result_builder.append_null();
+ continue;
}
+
+ let string = string_array.value(i);
+ let start = start_array.value(i);
+ let count = count_array_opt.map(|a| a.value(i));
+
+ let (start, end) =
+ get_true_start_end(string, start, count, enable_ascii_fast_path)?;
+ result_builder.append_value(&string[start..end]);
}
+
+ Ok(Arc::new(result_builder.finish()) as ArrayRef)
}
#[cfg(test)]
@@ -775,7 +674,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::from(1i64)),
ColumnarValue::Scalar(ScalarValue::from(-1i64)),
],
- exec_err!("negative substring length not allowed: substr(<str>, 1,
-1)"),
+ exec_err!("negative count not allowed: -1"),
&str,
Utf8View,
StringViewArray
@@ -812,7 +711,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::from("abc")),
ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
],
- Ok(Some("abc")),
+ exec_err!("start position overflow: -9223372036854775808"),
&str,
Utf8View,
StringViewArray
@@ -824,7 +723,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::from(i64::MIN)),
ColumnarValue::Scalar(ScalarValue::from(1i64)),
],
- exec_err!("negative overflow when calculating skip value"),
+ exec_err!("start position overflow: -9223372036854775808"),
&str,
Utf8View,
StringViewArray
diff --git a/datafusion/spark/src/function/string/substring.rs
b/datafusion/spark/src/function/string/substring.rs
index 524262b12f..76abba68cd 100644
--- a/datafusion/spark/src/function/string/substring.rs
+++ b/datafusion/spark/src/function/string/substring.rs
@@ -244,12 +244,8 @@ where
let adjusted_start = spark_start_to_datafusion_start(start,
string_len);
- let (byte_start, byte_end) = get_true_start_end(
- string,
- adjusted_start,
- len_opt.map(|l| l as u64),
- is_ascii,
- );
+ let (byte_start, byte_end) =
+ get_true_start_end(string, adjusted_start, len_opt, is_ascii)?;
let substr = &string[byte_start..byte_end];
builder.append_value(substr);
}
diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt
b/datafusion/sqllogictest/test_files/string/string_literal.slt
index a07eab3357..569dfe0336 100644
--- a/datafusion/sqllogictest/test_files/string/string_literal.slt
+++ b/datafusion/sqllogictest/test_files/string/string_literal.slt
@@ -138,10 +138,10 @@ SELECT substr(1, 3)
statement error Function 'substr' failed to match any signature
SELECT substr(1, 3, 4)
-statement error Execution error: negative substring length not allowed
+statement error Execution error: negative count not allowed
select substr(arrow_cast('foo', 'Utf8View'), 1, -1);
-statement error Execution error: negative substring length not allowed
+statement error Execution error: negative count not allowed
select substr('', 1, -1);
# StringView scalar to StringView scalar
diff --git a/docs/source/user-guide/sql/scalar_functions.md
b/docs/source/user-guide/sql/scalar_functions.md
index 98a6d63425..918bae0f7d 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -1974,7 +1974,7 @@ substr(str, start_pos[, length])
#### Arguments
- **str**: String expression to operate on. Can be a constant, column, or
function, and any combination of operators.
-- **start_pos**: Character position to start the substring at. The first
character in the string has a position of 1.
+- **start_pos**: Character position to start the substring at. The first
character in the string has a position of 1. If the start position is less than
1, it is treated as if it is before the start of the string and the (absolute)
number of characters before position 1 is subtracted from `length` (if given).
For example, `substr('abc', -3, 6)` returns `'ab'`.
- **length**: Number of characters to extract. If not specified, returns the
rest of the string after the start position.
#### Example
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]