This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 75ef138c2 Support UnionArray in ffi (#3305)
75ef138c2 is described below
commit 75ef138c2397a221311c626eada51fd96a7515c9
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Dec 12 08:21:48 2022 -0800
Support UnionArray in ffi (#3305)
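* Make the FFI (Arrow C Data Interface) import/export support UnionArray, in both dense and sparse modes
* Move the union types from the unsupported to the supported pyarrow types in the integration tests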
---
.../tests/test_sql.py | 14 +-
arrow/src/datatypes/ffi.rs | 56 +++++++
arrow/src/ffi.rs | 166 ++++++++++++++++++++-
3 files changed, 226 insertions(+), 10 deletions(-)
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py
b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 5a8bec792..196dc7990 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -87,10 +87,6 @@ _supported_pyarrow_types = [
),
pa.dictionary(pa.int8(), pa.string()),
pa.map_(pa.string(), pa.int32()),
-]
-
-_unsupported_pyarrow_types = [
- pa.decimal256(76, 38),
pa.union(
[pa.field("a", pa.binary(10)), pa.field("b", pa.string())],
mode=pa.lib.UnionMode_DENSE,
@@ -113,6 +109,10 @@ _unsupported_pyarrow_types = [
),
]
+_unsupported_pyarrow_types = [
+ pa.decimal256(76, 38),
+]
+
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip(pyarrow_type):
@@ -202,6 +202,9 @@ def test_empty_array_python(datatype):
if datatype == pa.float16():
pytest.skip("Float 16 is not implemented in Rust")
+ if type(datatype) is pa.lib.DenseUnionType or type(datatype) is
pa.lib.SparseUnionType:
+ pytest.skip("Union is not implemented in Python")
+
a = pa.array([], datatype)
b = rust.round_trip_array(a)
b.validate(full=True)
@@ -216,6 +219,9 @@ def test_empty_array_rust(datatype):
"""
Rust -> Python
"""
+ if type(datatype) is pa.lib.DenseUnionType or type(datatype) is
pa.lib.SparseUnionType:
+ pytest.skip("Union is not implemented in Python")
+
a = pa.array([], type=datatype)
b = rust.make_empty_array(datatype)
b.validate(full=True)
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
index 41addf24f..58fc8858a 100644
--- a/arrow/src/datatypes/ffi.rs
+++ b/arrow/src/datatypes/ffi.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use arrow_schema::UnionMode;
use std::convert::TryFrom;
use crate::datatypes::DataType::Map;
@@ -134,6 +135,50 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
}
}
}
+ // DenseUnion
+ ["+ud", extra] => {
+ let type_ids = extra.split(',').map(|t|
t.parse::<i8>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The Union type requires an integer type
id".to_string(),
+ )
+ })).collect::<Result<Vec<_>>>()?;
+ let mut fields = Vec::with_capacity(type_ids.len());
+ for idx in 0..c_schema.n_children {
+ let c_child = c_schema.child(idx as usize);
+ let field = Field::try_from(c_child)?;
+ fields.push(field);
+ }
+
+ if fields.len() != type_ids.len() {
+ return Err(ArrowError::CDataInterface(
+ "The Union type requires same number of fields
and type ids".to_string(),
+ ));
+ }
+
+ DataType::Union(fields, type_ids, UnionMode::Dense)
+ }
+ // SparseUnion
+ ["+us", extra] => {
+ let type_ids = extra.split(',').map(|t|
t.parse::<i8>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The Union type requires an integer type
id".to_string(),
+ )
+ })).collect::<Result<Vec<_>>>()?;
+ let mut fields = Vec::with_capacity(type_ids.len());
+ for idx in 0..c_schema.n_children {
+ let c_child = c_schema.child(idx as usize);
+ let field = Field::try_from(c_child)?;
+ fields.push(field);
+ }
+
+ if fields.len() != type_ids.len() {
+ return Err(ArrowError::CDataInterface(
+ "The Union type requires same number of fields
and type ids".to_string(),
+ ));
+ }
+
+ DataType::Union(fields, type_ids, UnionMode::Sparse)
+ }
// Timestamps in format "tts:" and "tts:America/New_York"
for no timezones and timezones resp.
["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
@@ -211,6 +256,10 @@ impl TryFrom<&DataType> for FFI_ArrowSchema {
| DataType::Map(child, _) => {
vec![FFI_ArrowSchema::try_from(child.as_ref())?]
}
+ DataType::Union(fields, _, _) => fields
+ .iter()
+ .map(FFI_ArrowSchema::try_from)
+ .collect::<Result<Vec<_>>>()?,
DataType::Struct(fields) => fields
.iter()
.map(FFI_ArrowSchema::try_from)
@@ -279,6 +328,13 @@ fn get_format_string(dtype: &DataType) -> Result<String> {
DataType::Struct(_) => Ok("+s".to_string()),
DataType::Map(_, _) => Ok("+m".to_string()),
DataType::Dictionary(key_data_type, _) =>
get_format_string(key_data_type),
+ DataType::Union(_, type_ids, mode) => {
+ let formats = type_ids.iter().map(|t|
t.to_string()).collect::<Vec<_>>();
+ match mode {
+ UnionMode::Dense => Ok(format!("{}:{}", "+ud",
formats.join(","))),
+ UnionMode::Sparse => Ok(format!("{}:{}", "+us",
formats.join(","))),
+ }
+ }
other => Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust
implementation",
other
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index abb53dff6..5e9b01b5c 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -120,6 +120,7 @@ use std::{
sync::Arc,
};
+use arrow_schema::UnionMode;
use bitflags::bitflags;
use crate::array::{layout, ArrayData};
@@ -310,8 +311,6 @@ impl Drop for FFI_ArrowSchema {
#[allow(clippy::manual_bits)]
fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
Ok(match (data_type, i) {
- // the null buffer is bit sized
- (_, 0) => 1,
// primitive types first buffer's size is given by the native types
(DataType::Boolean, 1) => 1,
(DataType::UInt8, 1) => size_of::<u8>() * 8,
@@ -385,6 +384,30 @@ fn bit_width(data_type: &DataType, i: usize) ->
Result<usize> {
data_type, i
)))
}
+ // type ids. UnionArray doesn't have null bitmap so buffer index
begins with 0.
+ (DataType::Union(_, _, _), 0) => size_of::<i8>() * 8,
+ // Only DenseUnion has 2nd buffer
+ (DataType::Union(_, _, UnionMode::Dense), 1) => size_of::<i32>() * 8,
+ (DataType::Union(_, _, UnionMode::Sparse), _) => {
+ return Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" expects 1 buffer, but requested {}.
Please verify that the C data interface is correctly implemented.",
+ data_type, i
+ )))
+ }
+ (DataType::Union(_, _, UnionMode::Dense), _) => {
+ return Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" expects 2 buffer, but requested {}.
Please verify that the C data interface is correctly implemented.",
+ data_type, i
+ )))
+ }
+ (_, 0) => {
+ // We don't call this `bit_width` to compute buffer length for
null buffer. If any types that don't have null buffer like
+ // UnionArray, they should be handled above.
+ return Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" doesn't expect buffer at index 0.
Please verify that the C data interface is correctly implemented.",
+ data_type
+ )))
+ }
_ => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust
implementation",
@@ -661,7 +684,8 @@ pub trait ArrowArrayRef {
})
}
- /// returns all buffers, as organized by Rust (i.e. null buffer is skipped)
+ /// returns all buffers, as organized by Rust (i.e. null buffer is skipped
if it's present
+ /// in the spec of the type)
fn buffers(&self, can_contain_null_mask: bool) -> Result<Vec<Buffer>> {
// + 1: skip null buffer
let buffer_begin = can_contain_null_mask as i64;
@@ -690,9 +714,9 @@ pub trait ArrowArrayRef {
}
/// Returns the length, in bytes, of the buffer `i` (indexed according to
the C data interface)
- // Rust implementation uses fixed-sized buffers, which require knowledge
of their `len`.
- // for variable-sized buffers, such as the second buffer of a stringArray,
we need
- // to fetch offset buffer's len to build the second buffer.
+ /// Rust implementation uses fixed-sized buffers, which require knowledge
of their `len`.
+ /// for variable-sized buffers, such as the second buffer of a
stringArray, we need
+ /// to fetch offset buffer's len to build the second buffer.
fn buffer_len(&self, i: usize) -> Result<usize> {
// Special handling for dictionary type as we only care about the key
type in the case.
let t = self.data_type()?;
@@ -937,6 +961,9 @@ mod tests {
};
use crate::compute::kernels;
use crate::datatypes::{Field, Int8Type};
+ use arrow_array::builder::UnionBuilder;
+ use arrow_array::types::{Float64Type, Int32Type};
+ use arrow_array::{Float64Array, UnionArray};
use std::convert::TryFrom;
#[test]
@@ -1500,4 +1527,131 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn test_union_sparse_array() -> Result<()> {
+ let mut builder = UnionBuilder::new_sparse();
+ builder.append::<Int32Type>("a", 1).unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
+ builder.append::<Float64Type>("c", 3.0).unwrap();
+ builder.append::<Int32Type>("a", 4).unwrap();
+ let union = builder.build().unwrap();
+
+ // export it
+ let array = ArrowArray::try_from(union.data().clone())?;
+
+ // (simulate consumer) import it
+ let data = ArrayData::try_from(array)?;
+ let array = make_array(data);
+
+ let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ let expected_type_ids = vec![0_i8, 0, 1, 0];
+
+ // Check type ids
+ assert_eq!(
+ Buffer::from_slice_ref(&expected_type_ids),
+ array.data().buffers()[0]
+ );
+ for (i, id) in expected_type_ids.iter().enumerate() {
+ assert_eq!(id, &array.type_id(i));
+ }
+
+ // Check offsets, sparse union should only have a single buffer, i.e.
no offsets
+ assert_eq!(array.data().buffers().len(), 1);
+
+ for i in 0..array.len() {
+ let slot = array.value(i);
+ match i {
+ 0 => {
+ let slot =
slot.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(1_i32, value);
+ }
+ 1 => assert!(slot.is_null(0)),
+ 2 => {
+ let slot =
slot.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(value, 3_f64);
+ }
+ 3 => {
+ let slot =
slot.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(4_i32, value);
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_union_dense_array() -> Result<()> {
+ let mut builder = UnionBuilder::new_dense();
+ builder.append::<Int32Type>("a", 1).unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
+ builder.append::<Float64Type>("c", 3.0).unwrap();
+ builder.append::<Int32Type>("a", 4).unwrap();
+ let union = builder.build().unwrap();
+
+ // export it
+ let array = ArrowArray::try_from(union.data().clone())?;
+
+ // (simulate consumer) import it
+ let data = ArrayData::try_from(array)?;
+ let array = make_array(data);
+
+ let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ let expected_type_ids = vec![0_i8, 0, 1, 0];
+
+ // Check type ids
+ assert_eq!(
+ Buffer::from_slice_ref(&expected_type_ids),
+ array.data().buffers()[0]
+ );
+ for (i, id) in expected_type_ids.iter().enumerate() {
+ assert_eq!(id, &array.type_id(i));
+ }
+
+ assert_eq!(array.data().buffers().len(), 2);
+
+ for i in 0..array.len() {
+ let slot = array.value(i);
+ match i {
+ 0 => {
+ let slot =
slot.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(1_i32, value);
+ }
+ 1 => assert!(slot.is_null(0)),
+ 2 => {
+ let slot =
slot.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(value, 3_f64);
+ }
+ 3 => {
+ let slot =
slot.as_any().downcast_ref::<Int32Array>().unwrap();
+ assert!(!slot.is_null(0));
+ assert_eq!(slot.len(), 1);
+ let value = slot.value(0);
+ assert_eq!(4_i32, value);
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
}