gstvg commented on code in PR #8966:
URL: https://github.com/apache/arrow-rs/pull/8966#discussion_r2597133936
##########
arrow-select/src/take.rs:
##########
@@ -164,31 +165,47 @@ pub fn take_arrays(
fn check_bounds<T: ArrowPrimitiveType>(
len: usize,
indices: &PrimitiveArray<T>,
-) -> Result<(), ArrowError> {
+) -> Result<(), ArrowError>
+where
+ T::Native: Display,
+{
+ let len = match T::Native::from_usize(len) {
+ Some(len) => len,
+ None => {
+ if T::DATA_TYPE.is_integer() {
+ // the biggest representable value for T::Native is lower than
len, e.g: u8::MAX < 512, no need to check bounds
+ return Ok(());
+ } else {
+ return Err(ArrowError::ComputeError("Cast to usize
failed".to_string()));
+ }
+ }
+ };
+
if indices.null_count() > 0 {
indices.iter().flatten().try_for_each(|index| {
- let ix = index
- .to_usize()
- .ok_or_else(|| ArrowError::ComputeError("Cast to usize
failed".to_string()))?;
- if ix >= len {
+ if index >= len {
return Err(ArrowError::ComputeError(format!(
- "Array index out of bounds, cannot get item at index {ix}
from {len} entries"
+ "Array index out of bounds, cannot get item at index
{index} from {len} entries"
)));
}
Ok(())
})
} else {
- indices.values().iter().try_for_each(|index| {
- let ix = index
- .to_usize()
- .ok_or_else(|| ArrowError::ComputeError("Cast to usize
failed".to_string()))?;
- if ix >= len {
- return Err(ArrowError::ComputeError(format!(
- "Array index out of bounds, cannot get item at index {ix}
from {len} entries"
- )));
+ let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
+ in_bounds & (i >= T::Native::ZERO) & (i < len)
+ });
+
+ if !in_bounds {
+ for &index in indices.values() {
+ if index < T::Native::ZERO || index >= len {
+ return Err(ArrowError::ComputeError(format!(
+ "Array index out of bounds, cannot get item at index
{index} from {len} entries"
+ )));
+ }
}
Review Comment:
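This heavily optimizes for the happy path and makes the error path slow: a single
out-of-bounds index forces a second full scan of `indices` just to locate it. If
a more balanced approach is desired, we can check fixed-size chunks so that only the
failing chunk is rescanned, using something like this: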
```rust
/// Checks that every non-null value in `indices` is a valid position in an array of `len`
/// entries, i.e. lies in `0..len`.
///
/// The no-null path checks `indices` in fixed chunks of 64 values with a branch-free fold. This
/// keeps the happy path vectorizable, while a failure only rescans the offending chunk to find
/// the index to report.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] if any non-null index is negative or `>= len`. It also
/// returns that error if `len` does not fit in `T::Native` and `T` is not an integer type.
pub fn check_bounds<T: ArrowPrimitiveType>(
    len: usize,
    indices: &PrimitiveArray<T>,
) -> Result<(), ArrowError>
where
    T::Native: Display,
{
    // When `len` exceeds the largest value `T::Native` can hold (e.g. `u8::MAX < 512`), every
    // non-negative index is in bounds. Only the upper-bound check may be skipped: negative
    // indices of a signed type (e.g. `-1i8` into 512 entries) must still be rejected.
    let (limit, saturated) = match T::Native::from_usize(len) {
        Some(limit) => (limit, false),
        None if T::DATA_TYPE.is_integer() => (T::Native::ZERO, true),
        None => return Err(ArrowError::ComputeError("Cast to usize failed".to_string())),
    };
    // Non-short-circuiting `&`/`|` keep this predicate branch-free so the chunked fold below
    // can be auto-vectorized.
    let in_bounds = |i: T::Native| (i >= T::Native::ZERO) & (saturated | (i < limit));
    // Turns the first offending index found by a slow-path scan (if any) into the error.
    let report = |offending: Option<T::Native>| match offending {
        Some(index) => Err(out_of_bounds_index(index, len)),
        None => Ok(()),
    };

    if indices.null_count() > 0 {
        // Values behind null slots are arbitrary, so only valid slots are checked.
        return report(indices.iter().flatten().find(|&i| !in_bounds(i)));
    }

    let chunks = indices.values().chunks_exact(64);
    let remainder = chunks.remainder();
    for chunk in chunks {
        let chunk: &[T::Native; 64] = chunk
            .try_into()
            .expect("chunks_exact(64) only yields 64-element chunks");
        if !chunk.iter().fold(true, |ok, &i| ok & in_bounds(i)) {
            // Slow path: rescan only this chunk to find the index to report.
            report(chunk.iter().copied().find(|&i| !in_bounds(i)))?;
        }
    }
    report(remainder.iter().copied().find(|&i| !in_bounds(i)))
}

/// Builds the error reported for an out-of-bounds `index` into an array of `len` entries.
///
/// Kept out of line and marked cold so the formatting machinery stays off the hot path of
/// `check_bounds`.
#[cold]
#[inline(never)]
fn out_of_bounds_index<T: Display>(index: T, len: usize) -> ArrowError {
    ArrowError::ComputeError(format!(
        "Array index out of bounds, cannot get item at index {index} from {len} entries"
    ))
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]