This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f1eba3c Add `length` kernel support for List Array (#1488)
f1eba3c is described below
commit f1eba3c588f0ea63cce73c84188729a1d34cdc8d
Author: Remzi Yang <[email protected]>
AuthorDate: Tue Mar 29 04:47:45 2022 +0800
Add `length` kernel support for List Array (#1488)
* add fn for list length
code format
Signed-off-by: remzi <[email protected]>
* add list support into length function
Signed-off-by: remzi <[email protected]>
* add tests
Signed-off-by: remzi <[email protected]>
* update doc
Signed-off-by: remzi <[email protected]>
---
arrow/src/compute/kernels/length.rs | 136 ++++++++++++++++++++++++++++--------
1 file changed, 107 insertions(+), 29 deletions(-)
diff --git a/arrow/src/compute/kernels/length.rs b/arrow/src/compute/kernels/length.rs
index 19be336..60a48ef 100644
--- a/arrow/src/compute/kernels/length.rs
+++ b/arrow/src/compute/kernels/length.rs
@@ -56,10 +56,23 @@ macro_rules! unary_offsets {
}};
}
-fn octet_length_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
- array: &dyn Array,
-) -> ArrayRef
+fn length_list<O, T>(array: &dyn Array) -> ArrayRef
where
+ O: OffsetSizeTrait,
+ T: ArrowPrimitiveType,
+ T::Native: OffsetSizeTrait,
+{
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericListArray<O>>()
+ .unwrap();
+ unary_offsets!(array, T::DATA_TYPE, |x| x)
+}
+
+fn length_binary<O, T>(array: &dyn Array) -> ArrayRef
+where
+ O: BinaryOffsetSizeTrait,
+ T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
@@ -69,10 +82,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}
-fn octet_length<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
- array: &dyn Array,
-) -> ArrayRef
+fn length_string<O, T>(array: &dyn Array) -> ArrayRef
where
+ O: StringOffsetSizeTrait,
+ T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
@@ -82,10 +95,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x)
}
-fn bit_length_impl_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
- array: &dyn Array,
-) -> ArrayRef
+fn bit_length_binary<O, T>(array: &dyn Array) -> ArrayRef
where
+ O: BinaryOffsetSizeTrait,
+ T: ArrowPrimitiveType,
T::Native: BinaryOffsetSizeTrait,
{
let array = array
@@ -96,10 +109,10 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}
-fn bit_length_impl<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
- array: &dyn Array,
-) -> ArrayRef
+fn bit_length_string<O, T>(array: &dyn Array) -> ArrayRef
where
+ O: StringOffsetSizeTrait,
+ T: ArrowPrimitiveType,
T::Native: StringOffsetSizeTrait,
{
let array = array
@@ -110,20 +123,23 @@ where
unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
}
-/// Returns an array of Int32/Int64 denoting the number of bytes in each value in the array.
+/// Returns an array of Int32/Int64 denoting the length of each value in the array.
+/// For list array, length is the number of elements in each list.
+/// For string array and binary array, length is the number of bytes of each value.
///
-/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray
+/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray
/// * length of null is null.
-/// * length is in number of bytes
pub fn length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
- DataType::Utf8 => Ok(octet_length::<i32, Int32Type>(array)),
- DataType::LargeUtf8 => Ok(octet_length::<i64, Int64Type>(array)),
- DataType::Binary => Ok(octet_length_binary::<i32, Int32Type>(array)),
- DataType::LargeBinary => Ok(octet_length_binary::<i64, Int64Type>(array)),
- _ => Err(ArrowError::ComputeError(format!(
+ DataType::List(_) => Ok(length_list::<i32, Int32Type>(array)),
+ DataType::LargeList(_) => Ok(length_list::<i64, Int64Type>(array)),
+ DataType::Utf8 => Ok(length_string::<i32, Int32Type>(array)),
+ DataType::LargeUtf8 => Ok(length_string::<i64, Int64Type>(array)),
+ DataType::Binary => Ok(length_binary::<i32, Int32Type>(array)),
+ DataType::LargeBinary => Ok(length_binary::<i64, Int64Type>(array)),
+ other => Err(ArrowError::ComputeError(format!(
"length not supported for {:?}",
- array.data_type()
+ other
))),
}
}
@@ -135,19 +151,21 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef> {
/// * bit_length is in number of bits
pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {
match array.data_type() {
- DataType::Utf8 => Ok(bit_length_impl::<i32, Int32Type>(array)),
- DataType::LargeUtf8 => Ok(bit_length_impl::<i64, Int64Type>(array)),
- DataType::Binary => Ok(bit_length_impl_binary::<i32, Int32Type>(array)),
- DataType::LargeBinary => Ok(bit_length_impl_binary::<i64, Int64Type>(array)),
- _ => Err(ArrowError::ComputeError(format!(
+ DataType::Utf8 => Ok(bit_length_string::<i32, Int32Type>(array)),
+ DataType::LargeUtf8 => Ok(bit_length_string::<i64, Int64Type>(array)),
+ DataType::Binary => Ok(bit_length_binary::<i32, Int32Type>(array)),
+ DataType::LargeBinary => Ok(bit_length_binary::<i64, Int64Type>(array)),
+ other => Err(ArrowError::ComputeError(format!(
"bit_length not supported for {:?}",
- array.data_type()
+ other
))),
}
}
#[cfg(test)]
mod tests {
+ use crate::datatypes::{Float32Type, Int8Type};
+
use super::*;
fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
@@ -182,6 +200,20 @@ mod tests {
}};
}
+ macro_rules! length_list_helper {
+ ($offset_ty: ty, $result_ty: ty, $element_ty: ty, $value: expr, $expected: expr) => {{
+ let array =
+ GenericListArray::<$offset_ty>::from_iter_primitive::<$element_ty, _, _>(
+ $value,
+ );
+ let result = length(&array)?;
+ let result = result.as_any().downcast_ref::<$result_ty>().unwrap();
+ let expected: $result_ty = $expected.into();
+ assert_eq!(expected.data(), result.data());
+ Ok(())
+ }};
+ }
+
#[test]
#[cfg_attr(miri, ignore)] // running forever
fn length_test_string() -> Result<()> {
@@ -230,6 +262,28 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}
+ #[test]
+ fn length_test_list() -> Result<()> {
+ let value = vec![
+ Some(vec![]),
+ Some(vec![Some(1), Some(2), Some(4)]),
+ Some(vec![Some(0)]),
+ ];
+ let result: Vec<i32> = vec![0, 3, 1];
+ length_list_helper!(i32, Int32Array, Int32Type, value, result)
+ }
+
+ #[test]
+ fn length_test_large_list() -> Result<()> {
+ let value = vec![
+ Some(vec![]),
+ Some(vec![Some(1.1), Some(2.2), Some(3.3)]),
+ Some(vec![None]),
+ ];
+ let result: Vec<i64> = vec![0, 3, 1];
+ length_list_helper!(i64, Int64Array, Float32Type, value, result)
+ }
+
type OptionStr = Option<&'static str>;
fn length_null_cases_string() -> Vec<(Vec<OptionStr>, usize, Vec<Option<i32>>)> {
@@ -293,6 +347,30 @@ mod tests {
length_binary_helper!(i64, Int64Array, length, value, result)
}
+ #[test]
+ fn length_null_list() -> Result<()> {
+ let value = vec![
+ Some(vec![]),
+ None,
+ Some(vec![Some(1), None, Some(2), Some(4)]),
+ Some(vec![Some(0)]),
+ ];
+ let result: Vec<Option<i32>> = vec![Some(0), None, Some(4), Some(1)];
+ length_list_helper!(i32, Int32Array, Int8Type, value, result)
+ }
+
+ #[test]
+ fn length_null_large_list() -> Result<()> {
+ let value = vec![
+ Some(vec![]),
+ None,
+ Some(vec![Some(1.1), None, Some(4.0)]),
+ Some(vec![Some(0.1)]),
+ ];
+ let result: Vec<Option<i64>> = vec![Some(0), None, Some(3), Some(1)];
+ length_list_helper!(i64, Int64Array, Float32Type, value, result)
+ }
+
/// Tests that length is not valid for u64.
#[test]
fn length_wrong_type() {
@@ -303,7 +381,7 @@ mod tests {
/// Tests with an offset
#[test]
- fn length_offsets() -> Result<()> {
+ fn length_offsets_string() -> Result<()> {
let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]);
let b = a.slice(1, 3);
let result = length(b.as_ref())?;
@@ -316,7 +394,7 @@ mod tests {
}
#[test]
- fn binary_length_offsets() -> Result<()> {
+ fn length_offsets_binary() -> Result<()> {
let value: Vec<Option<&[u8]>> =
vec![Some(b"hello"), Some(b" "), Some(&[0xff, 0xf8]), None];
let a = BinaryArray::from(value);