This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new f1eba3c  Add `length` kernel support for List Array (#1488)
f1eba3c is described below

commit f1eba3c588f0ea63cce73c84188729a1d34cdc8d
Author: Remzi Yang <[email protected]>
AuthorDate: Tue Mar 29 04:47:45 2022 +0800

    Add `length` kernel support for List Array (#1488)
    
    * add fn for list length
    code format
    
    Signed-off-by: remzi <[email protected]>
    
    * add list support into length function
    
    Signed-off-by: remzi <[email protected]>
    
    * add tests
    
    Signed-off-by: remzi <[email protected]>
    
    * update doc
    
    Signed-off-by: remzi <[email protected]>
---
 arrow/src/compute/kernels/length.rs | 136 ++++++++++++++++++++++++++++--------
 1 file changed, 107 insertions(+), 29 deletions(-)

diff --git a/arrow/src/compute/kernels/length.rs b/arrow/src/compute/kernels/length.rs
index 19be336..60a48ef 100644
--- a/arrow/src/compute/kernels/length.rs
+++ b/arrow/src/compute/kernels/length.rs
@@ -56,10 +56,23 @@ macro_rules! unary_offsets {
     }};
 }
 
-fn octet_length_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
-    array: &dyn Array,
-) -> ArrayRef
+fn length_list<O, T>(array: &dyn Array) -> ArrayRef
 where
+    O: OffsetSizeTrait,
+    T: ArrowPrimitiveType,
+    T::Native: OffsetSizeTrait,
+{
+    let array = array
+        .as_any()
+        .downcast_ref::<GenericListArray<O>>()
+        .unwrap();
+    unary_offsets!(array, T::DATA_TYPE, |x| x)
+}
+
+fn length_binary<O, T>(array: &dyn Array) -> ArrayRef
+where
+    O: BinaryOffsetSizeTrait,
+    T: ArrowPrimitiveType,
     T::Native: BinaryOffsetSizeTrait,
 {
     let array = array
@@ -69,10 +82,10 @@ where
     unary_offsets!(array, T::DATA_TYPE, |x| x)
 }
 
-fn octet_length<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
-    array: &dyn Array,
-) -> ArrayRef
+fn length_string<O, T>(array: &dyn Array) -> ArrayRef
 where
+    O: StringOffsetSizeTrait,
+    T: ArrowPrimitiveType,
     T::Native: StringOffsetSizeTrait,
 {
     let array = array
@@ -82,10 +95,10 @@ where
     unary_offsets!(array, T::DATA_TYPE, |x| x)
 }
 
-fn bit_length_impl_binary<O: BinaryOffsetSizeTrait, T: ArrowPrimitiveType>(
-    array: &dyn Array,
-) -> ArrayRef
+fn bit_length_binary<O, T>(array: &dyn Array) -> ArrayRef
 where
+    O: BinaryOffsetSizeTrait,
+    T: ArrowPrimitiveType,
     T::Native: BinaryOffsetSizeTrait,
 {
     let array = array
@@ -96,10 +109,10 @@ where
     unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
 }
 
-fn bit_length_impl<O: StringOffsetSizeTrait, T: ArrowPrimitiveType>(
-    array: &dyn Array,
-) -> ArrayRef
+fn bit_length_string<O, T>(array: &dyn Array) -> ArrayRef
 where
+    O: StringOffsetSizeTrait,
+    T: ArrowPrimitiveType,
     T::Native: StringOffsetSizeTrait,
 {
     let array = array
@@ -110,20 +123,23 @@ where
     unary_offsets!(array, T::DATA_TYPE, |x| x * bits_in_bytes)
 }
 
-/// Returns an array of Int32/Int64 denoting the number of bytes in each value in the array.
+/// Returns an array of Int32/Int64 denoting the length of each value in the array.
+/// For list array, length is the number of elements in each list.
+/// For string array and binary array, length is the number of bytes of each value.
 ///
-/// * this only accepts StringArray/Utf8, LargeString/LargeUtf8, BinaryArray and LargeBinaryArray
+/// * this only accepts ListArray/LargeListArray, StringArray/LargeStringArray and BinaryArray/LargeBinaryArray
 /// * length of null is null.
-/// * length is in number of bytes
 pub fn length(array: &dyn Array) -> Result<ArrayRef> {
     match array.data_type() {
-        DataType::Utf8 => Ok(octet_length::<i32, Int32Type>(array)),
-        DataType::LargeUtf8 => Ok(octet_length::<i64, Int64Type>(array)),
-        DataType::Binary => Ok(octet_length_binary::<i32, Int32Type>(array)),
-        DataType::LargeBinary => Ok(octet_length_binary::<i64, Int64Type>(array)),
-        _ => Err(ArrowError::ComputeError(format!(
+        DataType::List(_) => Ok(length_list::<i32, Int32Type>(array)),
+        DataType::LargeList(_) => Ok(length_list::<i64, Int64Type>(array)),
+        DataType::Utf8 => Ok(length_string::<i32, Int32Type>(array)),
+        DataType::LargeUtf8 => Ok(length_string::<i64, Int64Type>(array)),
+        DataType::Binary => Ok(length_binary::<i32, Int32Type>(array)),
+        DataType::LargeBinary => Ok(length_binary::<i64, Int64Type>(array)),
+        other => Err(ArrowError::ComputeError(format!(
             "length not supported for {:?}",
-            array.data_type()
+            other
         ))),
     }
 }
@@ -135,19 +151,21 @@ pub fn length(array: &dyn Array) -> Result<ArrayRef> {
 /// * bit_length is in number of bits
 pub fn bit_length(array: &dyn Array) -> Result<ArrayRef> {
     match array.data_type() {
-        DataType::Utf8 => Ok(bit_length_impl::<i32, Int32Type>(array)),
-        DataType::LargeUtf8 => Ok(bit_length_impl::<i64, Int64Type>(array)),
-        DataType::Binary => Ok(bit_length_impl_binary::<i32, Int32Type>(array)),
-        DataType::LargeBinary => Ok(bit_length_impl_binary::<i64, Int64Type>(array)),
-        _ => Err(ArrowError::ComputeError(format!(
+        DataType::Utf8 => Ok(bit_length_string::<i32, Int32Type>(array)),
+        DataType::LargeUtf8 => Ok(bit_length_string::<i64, Int64Type>(array)),
+        DataType::Binary => Ok(bit_length_binary::<i32, Int32Type>(array)),
+        DataType::LargeBinary => Ok(bit_length_binary::<i64, Int64Type>(array)),
+        other => Err(ArrowError::ComputeError(format!(
             "bit_length not supported for {:?}",
-            array.data_type()
+            other
         ))),
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use crate::datatypes::{Float32Type, Int8Type};
+
     use super::*;
 
     fn double_vec<T: Clone>(v: Vec<T>) -> Vec<T> {
@@ -182,6 +200,20 @@ mod tests {
         }};
     }
 
+    macro_rules! length_list_helper {
+        ($offset_ty: ty, $result_ty: ty, $element_ty: ty, $value: expr, $expected: expr) => {{
+            let array =
+                GenericListArray::<$offset_ty>::from_iter_primitive::<$element_ty, _, _>(
+                    $value,
+                );
+            let result = length(&array)?;
+            let result = result.as_any().downcast_ref::<$result_ty>().unwrap();
+            let expected: $result_ty = $expected.into();
+            assert_eq!(expected.data(), result.data());
+            Ok(())
+        }};
+    }
+
     #[test]
     #[cfg_attr(miri, ignore)] // running forever
     fn length_test_string() -> Result<()> {
@@ -230,6 +262,28 @@ mod tests {
         length_binary_helper!(i64, Int64Array, length, value, result)
     }
 
+    #[test]
+    fn length_test_list() -> Result<()> {
+        let value = vec![
+            Some(vec![]),
+            Some(vec![Some(1), Some(2), Some(4)]),
+            Some(vec![Some(0)]),
+        ];
+        let result: Vec<i32> = vec![0, 3, 1];
+        length_list_helper!(i32, Int32Array, Int32Type, value, result)
+    }
+
+    #[test]
+    fn length_test_large_list() -> Result<()> {
+        let value = vec![
+            Some(vec![]),
+            Some(vec![Some(1.1), Some(2.2), Some(3.3)]),
+            Some(vec![None]),
+        ];
+        let result: Vec<i64> = vec![0, 3, 1];
+        length_list_helper!(i64, Int64Array, Float32Type, value, result)
+    }
+
     type OptionStr = Option<&'static str>;
 
     fn length_null_cases_string() -> Vec<(Vec<OptionStr>, usize, Vec<Option<i32>>)> {
@@ -293,6 +347,30 @@ mod tests {
         length_binary_helper!(i64, Int64Array, length, value, result)
     }
 
+    #[test]
+    fn length_null_list() -> Result<()> {
+        let value = vec![
+            Some(vec![]),
+            None,
+            Some(vec![Some(1), None, Some(2), Some(4)]),
+            Some(vec![Some(0)]),
+        ];
+        let result: Vec<Option<i32>> = vec![Some(0), None, Some(4), Some(1)];
+        length_list_helper!(i32, Int32Array, Int8Type, value, result)
+    }
+
+    #[test]
+    fn length_null_large_list() -> Result<()> {
+        let value = vec![
+            Some(vec![]),
+            None,
+            Some(vec![Some(1.1), None, Some(4.0)]),
+            Some(vec![Some(0.1)]),
+        ];
+        let result: Vec<Option<i64>> = vec![Some(0), None, Some(3), Some(1)];
+        length_list_helper!(i64, Int64Array, Float32Type, value, result)
+    }
+
     /// Tests that length is not valid for u64.
     #[test]
     fn length_wrong_type() {
@@ -303,7 +381,7 @@ mod tests {
 
     /// Tests with an offset
     #[test]
-    fn length_offsets() -> Result<()> {
+    fn length_offsets_string() -> Result<()> {
         let a = StringArray::from(vec![Some("hello"), Some(" "), Some("world"), None]);
         let b = a.slice(1, 3);
         let result = length(b.as_ref())?;
@@ -316,7 +394,7 @@ mod tests {
     }
 
     #[test]
-    fn binary_length_offsets() -> Result<()> {
+    fn length_offsets_binary() -> Result<()> {
         let value: Vec<Option<&[u8]>> =
             vec![Some(b"hello"), Some(b" "), Some(&[0xff, 0xf8]), None];
         let a = BinaryArray::from(value);

Reply via email to