This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 75ef138c2 Support UnionArray in ffi (#3305)
75ef138c2 is described below

commit 75ef138c2397a221311c626eada51fd96a7515c9
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Dec 12 08:21:48 2022 -0800

    Support UnionArray in ffi (#3305)
    
    * Make ffi support UnionArray
    
    * Move union to supported types
---
 .../tests/test_sql.py                              |  14 +-
 arrow/src/datatypes/ffi.rs                         |  56 +++++++
 arrow/src/ffi.rs                                   | 166 ++++++++++++++++++++-
 3 files changed, 226 insertions(+), 10 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 5a8bec792..196dc7990 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -87,10 +87,6 @@ _supported_pyarrow_types = [
     ),
     pa.dictionary(pa.int8(), pa.string()),
     pa.map_(pa.string(), pa.int32()),
-]
-
-_unsupported_pyarrow_types = [
-    pa.decimal256(76, 38),
     pa.union(
         [pa.field("a", pa.binary(10)), pa.field("b", pa.string())],
         mode=pa.lib.UnionMode_DENSE,
@@ -113,6 +109,10 @@ _unsupported_pyarrow_types = [
     ),
 ]
 
+_unsupported_pyarrow_types = [
+    pa.decimal256(76, 38),
+]
+
 
 @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
 def test_type_roundtrip(pyarrow_type):
@@ -202,6 +202,9 @@ def test_empty_array_python(datatype):
     if datatype == pa.float16():
         pytest.skip("Float 16 is not implemented in Rust")
 
+    if type(datatype) is pa.lib.DenseUnionType or type(datatype) is pa.lib.SparseUnionType:
+        pytest.skip("Union is not implemented in Python")
+
     a = pa.array([], datatype)
     b = rust.round_trip_array(a)
     b.validate(full=True)
@@ -216,6 +219,9 @@ def test_empty_array_rust(datatype):
     """
     Rust -> Python
     """
+    if type(datatype) is pa.lib.DenseUnionType or type(datatype) is pa.lib.SparseUnionType:
+        pytest.skip("Union is not implemented in Python")
+
     a = pa.array([], type=datatype)
     b = rust.make_empty_array(datatype)
     b.validate(full=True)
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index 41addf24f..58fc8858a 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow_schema::UnionMode;
 use std::convert::TryFrom;
 
 use crate::datatypes::DataType::Map;
@@ -134,6 +135,50 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
                             }
                         }
                     }
+                    // DenseUnion
+                    ["+ud", extra] => {
+                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
+                            ArrowError::CDataInterface(
+                                "The Union type requires an integer type id".to_string(),
+                            )
+                        })).collect::<Result<Vec<_>>>()?;
+                        let mut fields = Vec::with_capacity(type_ids.len());
+                        for idx in 0..c_schema.n_children {
+                            let c_child = c_schema.child(idx as usize);
+                            let field = Field::try_from(c_child)?;
+                            fields.push(field);
+                        }
+
+                        if fields.len() != type_ids.len() {
+                            return Err(ArrowError::CDataInterface(
+                                "The Union type requires same number of fields and type ids".to_string(),
+                            ));
+                        }
+
+                        DataType::Union(fields, type_ids, UnionMode::Dense)
+                    }
+                    // SparseUnion
+                    ["+us", extra] => {
+                        let type_ids = extra.split(',').map(|t| t.parse::<i8>().map_err(|_| {
+                            ArrowError::CDataInterface(
+                                "The Union type requires an integer type id".to_string(),
+                            )
+                        })).collect::<Result<Vec<_>>>()?;
+                        let mut fields = Vec::with_capacity(type_ids.len());
+                        for idx in 0..c_schema.n_children {
+                            let c_child = c_schema.child(idx as usize);
+                            let field = Field::try_from(c_child)?;
+                            fields.push(field);
+                        }
+
+                        if fields.len() != type_ids.len() {
+                            return Err(ArrowError::CDataInterface(
+                                "The Union type requires same number of fields and type ids".to_string(),
+                            ));
+                        }
+
+                        DataType::Union(fields, type_ids, UnionMode::Sparse)
+                    }
 
                     // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp.
                     ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
@@ -211,6 +256,10 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
             | DataType::Map(child, _) => {
                 vec![FFI_ArrowSchema::try_from(child.as_ref())?]
             }
+            DataType::Union(fields, _, _) => fields
+                .iter()
+                .map(FFI_ArrowSchema::try_from)
+                .collect::<Result<Vec<_>>>()?,
             DataType::Struct(fields) => fields
                 .iter()
                 .map(FFI_ArrowSchema::try_from)
@@ -279,6 +328,13 @@ fn get_format_string(dtype: &DataType) -> Result<String> {
         DataType::Struct(_) => Ok("+s".to_string()),
         DataType::Map(_, _) => Ok("+m".to_string()),
         DataType::Dictionary(key_data_type, _) => get_format_string(key_data_type),
+        DataType::Union(_, type_ids, mode) => {
+            let formats = type_ids.iter().map(|t| t.to_string()).collect::<Vec<_>>();
+            match mode {
+                UnionMode::Dense => Ok(format!("{}:{}", "+ud", formats.join(","))),
+                UnionMode::Sparse => Ok(format!("{}:{}", "+us", formats.join(","))),
+            }
+        }
         other => Err(ArrowError::CDataInterface(format!(
             "The datatype \"{:?}\" is still not supported in Rust implementation",
             other
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index abb53dff6..5e9b01b5c 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -120,6 +120,7 @@ use std::{
     sync::Arc,
 };
 
+use arrow_schema::UnionMode;
 use bitflags::bitflags;
 
 use crate::array::{layout, ArrayData};
@@ -310,8 +311,6 @@ impl Drop for FFI_ArrowSchema {
 #[allow(clippy::manual_bits)]
 fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
     Ok(match (data_type, i) {
-        // the null buffer is bit sized
-        (_, 0) => 1,
         // primitive types first buffer's size is given by the native types
         (DataType::Boolean, 1) => 1,
         (DataType::UInt8, 1) => size_of::<u8>() * 8,
@@ -385,6 +384,30 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
                 data_type, i
             )))
         }
+        // type ids. UnionArray doesn't have null bitmap so buffer index begins with 0.
+        (DataType::Union(_, _, _), 0) => size_of::<i8>() * 8,
+        // Only DenseUnion has 2nd buffer
+        (DataType::Union(_, _, UnionMode::Dense), 1) => size_of::<i32>() * 8,
+        (DataType::Union(_, _, UnionMode::Sparse), _) => {
+            return Err(ArrowError::CDataInterface(format!(
+                "The datatype \"{:?}\" expects 1 buffer, but requested {}. Please verify that the C data interface is correctly implemented.",
+                data_type, i
+            )))
+        }
+        (DataType::Union(_, _, UnionMode::Dense), _) => {
+            return Err(ArrowError::CDataInterface(format!(
+                "The datatype \"{:?}\" expects 2 buffer, but requested {}. Please verify that the C data interface is correctly implemented.",
+                data_type, i
+            )))
+        }
+        (_, 0) => {
+            // We don't call this `bit_width` to compute buffer length for null buffer. If any types that don't have null buffer like
+            // UnionArray, they should be handled above.
+            return Err(ArrowError::CDataInterface(format!(
+                "The datatype \"{:?}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented.",
+                data_type
+            )))
+        }
         _ => {
             return Err(ArrowError::CDataInterface(format!(
                 "The datatype \"{:?}\" is still not supported in Rust implementation",
@@ -661,7 +684,8 @@ pub trait ArrowArrayRef {
         })
     }
 
-    /// returns all buffers, as organized by Rust (i.e. null buffer is skipped)
+    /// returns all buffers, as organized by Rust (i.e. null buffer is skipped if it's present
+    /// in the spec of the type)
     fn buffers(&self, can_contain_null_mask: bool) -> Result<Vec<Buffer>> {
         // + 1: skip null buffer
         let buffer_begin = can_contain_null_mask as i64;
@@ -690,9 +714,9 @@ pub trait ArrowArrayRef {
     }
 
     /// Returns the length, in bytes, of the buffer `i` (indexed according to the C data interface)
-    // Rust implementation uses fixed-sized buffers, which require knowledge of their `len`.
-    // for variable-sized buffers, such as the second buffer of a stringArray, we need
-    // to fetch offset buffer's len to build the second buffer.
+    /// Rust implementation uses fixed-sized buffers, which require knowledge of their `len`.
+    /// for variable-sized buffers, such as the second buffer of a stringArray, we need
+    /// to fetch offset buffer's len to build the second buffer.
     fn buffer_len(&self, i: usize) -> Result<usize> {
         // Special handling for dictionary type as we only care about the key type in the case.
         let t = self.data_type()?;
@@ -937,6 +961,9 @@ mod tests {
     };
     use crate::compute::kernels;
     use crate::datatypes::{Field, Int8Type};
+    use arrow_array::builder::UnionBuilder;
+    use arrow_array::types::{Float64Type, Int32Type};
+    use arrow_array::{Float64Array, UnionArray};
     use std::convert::TryFrom;
 
     #[test]
@@ -1500,4 +1527,131 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_union_sparse_array() -> Result<()> {
+        let mut builder = UnionBuilder::new_sparse();
+        builder.append::<Int32Type>("a", 1).unwrap();
+        builder.append_null::<Int32Type>("a").unwrap();
+        builder.append::<Float64Type>("c", 3.0).unwrap();
+        builder.append::<Int32Type>("a", 4).unwrap();
+        let union = builder.build().unwrap();
+
+        // export it
+        let array = ArrowArray::try_from(union.data().clone())?;
+
+        // (simulate consumer) import it
+        let data = ArrayData::try_from(array)?;
+        let array = make_array(data);
+
+        let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+        let expected_type_ids = vec![0_i8, 0, 1, 0];
+
+        // Check type ids
+        assert_eq!(
+            Buffer::from_slice_ref(&expected_type_ids),
+            array.data().buffers()[0]
+        );
+        for (i, id) in expected_type_ids.iter().enumerate() {
+            assert_eq!(id, &array.type_id(i));
+        }
+
+        // Check offsets, sparse union should only have a single buffer, i.e. no offsets
+        assert_eq!(array.data().buffers().len(), 1);
+
+        for i in 0..array.len() {
+            let slot = array.value(i);
+            match i {
+                0 => {
+                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(1_i32, value);
+                }
+                1 => assert!(slot.is_null(0)),
+                2 => {
+                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(value, 3_f64);
+                }
+                3 => {
+                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(4_i32, value);
+                }
+                _ => unreachable!(),
+            }
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_union_dense_array() -> Result<()> {
+        let mut builder = UnionBuilder::new_dense();
+        builder.append::<Int32Type>("a", 1).unwrap();
+        builder.append_null::<Int32Type>("a").unwrap();
+        builder.append::<Float64Type>("c", 3.0).unwrap();
+        builder.append::<Int32Type>("a", 4).unwrap();
+        let union = builder.build().unwrap();
+
+        // export it
+        let array = ArrowArray::try_from(union.data().clone())?;
+
+        // (simulate consumer) import it
+        let data = ArrayData::try_from(array)?;
+        let array = make_array(data);
+
+        let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+        let expected_type_ids = vec![0_i8, 0, 1, 0];
+
+        // Check type ids
+        assert_eq!(
+            Buffer::from_slice_ref(&expected_type_ids),
+            array.data().buffers()[0]
+        );
+        for (i, id) in expected_type_ids.iter().enumerate() {
+            assert_eq!(id, &array.type_id(i));
+        }
+
+        assert_eq!(array.data().buffers().len(), 2);
+
+        for i in 0..array.len() {
+            let slot = array.value(i);
+            match i {
+                0 => {
+                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(1_i32, value);
+                }
+                1 => assert!(slot.is_null(0)),
+                2 => {
+                    let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(value, 3_f64);
+                }
+                3 => {
+                    let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
+                    assert!(!slot.is_null(0));
+                    assert_eq!(slot.len(), 1);
+                    let value = slot.value(0);
+                    assert_eq!(4_i32, value);
+                }
+                _ => unreachable!(),
+            }
+        }
+
+        Ok(())
+    }
 }

Reply via email to