This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new b7248497a4 Improve C Data Interface and Add Integration Testing Entrypoints (#5080)
b7248497a4 is described below

commit b7248497a43992a6f8da41b25829766b0867891c
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Nov 20 17:20:10 2023 +0100

    Improve C Data Interface and Add Integration Testing Entrypoints (#5080)
    
    * Add C Data Interface integration testing entrypoints
    
    * Allow importing FFI_ArrowArray with existing datatype
    
    * Clippy
    
    * Use ptr::write
    
    * Fix null_count for Null type
    
    * Use new from_raw() APIs
    
    * Address some review comments.
    
    * Add unsafe markers
    
    * Try to fix CI
    
    * Revamp ArrowFile
---
 arrow-data/src/ffi.rs                              |   8 +-
 arrow-integration-testing/Cargo.toml               |   5 +-
 arrow-integration-testing/README.md                |   2 +-
 .../src/bin/arrow-json-integration-test.rs         |  49 +----
 .../flight_client_scenarios/integration_test.rs    |  17 +-
 arrow-integration-testing/src/lib.rs               | 228 +++++++++++++++++++--
 arrow-schema/src/error.rs                          |   6 +
 arrow-schema/src/ffi.rs                            |  10 +-
 arrow/src/array/ffi.rs                             |   2 +-
 arrow/src/ffi.rs                                   | 161 +++++++++------
 arrow/src/ffi_stream.rs                            |   6 +-
 arrow/src/pyarrow.rs                               |   6 +-
 12 files changed, 363 insertions(+), 137 deletions(-)

diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs
index 2b4d526012..589f7dac6d 100644
--- a/arrow-data/src/ffi.rs
+++ b/arrow-data/src/ffi.rs
@@ -168,6 +168,12 @@ impl FFI_ArrowArray {
             .collect::<Box<_>>();
         let n_children = children.len() as i64;
 
+        // As in the IPC format, emit null_count = length for Null type
+        let null_count = match data.data_type() {
+            DataType::Null => data.len(),
+            _ => data.null_count(),
+        };
+
         // create the private data owning everything.
         // any other data must be added here, e.g. via a struct, to track 
lifetime.
         let mut private_data = Box::new(ArrayPrivateData {
@@ -179,7 +185,7 @@ impl FFI_ArrowArray {
 
         Self {
             length: data.len() as i64,
-            null_count: data.null_count() as i64,
+            null_count: null_count as i64,
             offset: data.offset() as i64,
             n_buffers,
             n_children,
diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml
index 86c2cb27d2..c29860f09d 100644
--- a/arrow-integration-testing/Cargo.toml
+++ b/arrow-integration-testing/Cargo.toml
@@ -27,11 +27,14 @@ edition = { workspace = true }
 publish = false
 rust-version = { workspace = true }
 
+[lib]
+crate-type = ["lib", "cdylib"]
+
 [features]
 logging = ["tracing-subscriber"]
 
 [dependencies]
-arrow = { path = "../arrow", default-features = false, features = 
["test_utils", "ipc", "ipc_compression", "json"] }
+arrow = { path = "../arrow", default-features = false, features = 
["test_utils", "ipc", "ipc_compression", "json", "ffi"] }
 arrow-flight = { path = "../arrow-flight", default-features = false }
 arrow-buffer = { path = "../arrow-buffer", default-features = false }
 arrow-integration-test = { path = "../arrow-integration-test", 
default-features = false }
diff --git a/arrow-integration-testing/README.md b/arrow-integration-testing/README.md
index e82591e6b1..dcf39c27fb 100644
--- a/arrow-integration-testing/README.md
+++ b/arrow-integration-testing/README.md
@@ -48,7 +48,7 @@ ln -s <path_to_arrow_rs> arrow/rust
 
 ```shell
 cd arrow
-pip install -e dev/archery[docker]
+pip install -e dev/archery[integration]
 ```
 
 ### Build the C++ binaries:
diff --git a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs
index 187d987a5a..9f1abb16a6 100644
--- a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,16 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Field};
-use arrow::datatypes::{Fields, Schema};
 use arrow::error::{ArrowError, Result};
 use arrow::ipc::reader::FileReader;
 use arrow::ipc::writer::FileWriter;
 use arrow_integration_test::*;
-use arrow_integration_testing::read_json_file;
+use arrow_integration_testing::{canonicalize_schema, open_json_file};
 use clap::Parser;
 use std::fs::File;
-use std::sync::Arc;
 
 #[derive(clap::ValueEnum, Debug, Clone)]
 #[clap(rename_all = "SCREAMING_SNAKE_CASE")]
@@ -66,12 +63,12 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()>
         eprintln!("Converting {json_name} to {arrow_name}");
     }
 
-    let json_file = read_json_file(json_name)?;
+    let json_file = open_json_file(json_name)?;
 
     let arrow_file = File::create(arrow_name)?;
     let mut writer = FileWriter::try_new(arrow_file, &json_file.schema)?;
 
-    for b in json_file.batches {
+    for b in json_file.read_batches()? {
         writer.write(&b)?;
     }
 
@@ -113,49 +110,13 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()>
     Ok(())
 }
 
-fn canonicalize_schema(schema: &Schema) -> Schema {
-    let fields = schema
-        .fields()
-        .iter()
-        .map(|field| match field.data_type() {
-            DataType::Map(child_field, sorted) => match 
child_field.data_type() {
-                DataType::Struct(fields) if fields.len() == 2 => {
-                    let first_field = fields.get(0).unwrap();
-                    let key_field =
-                        Arc::new(Field::new("key", 
first_field.data_type().clone(), false));
-                    let second_field = fields.get(1).unwrap();
-                    let value_field = Arc::new(Field::new(
-                        "value",
-                        second_field.data_type().clone(),
-                        second_field.is_nullable(),
-                    ));
-
-                    let fields = Fields::from([key_field, value_field]);
-                    let struct_type = DataType::Struct(fields);
-                    let child_field = Field::new("entries", struct_type, 
false);
-
-                    Arc::new(Field::new(
-                        field.name().as_str(),
-                        DataType::Map(Arc::new(child_field), *sorted),
-                        field.is_nullable(),
-                    ))
-                }
-                _ => panic!("The child field of Map type should be Struct type 
with 2 fields."),
-            },
-            _ => field.clone(),
-        })
-        .collect::<Fields>();
-
-    Schema::new(fields).with_metadata(schema.metadata().clone())
-}
-
 fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
     if verbose {
         eprintln!("Validating {arrow_name} and {json_name}");
     }
 
     // open JSON file
-    let json_file = read_json_file(json_name)?;
+    let json_file = open_json_file(json_name)?;
 
     // open Arrow file
     let arrow_file = File::open(arrow_name)?;
@@ -170,7 +131,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
         )));
     }
 
-    let json_batches = &json_file.batches;
+    let json_batches = json_file.read_batches()?;
 
     // compare number of batches
     assert!(
diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
index 81cc4bbe8e..c6b5a72ca6 100644
--- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
+++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{read_json_file, ArrowFile};
+use crate::open_json_file;
 use std::collections::HashMap;
 
 use arrow::{
@@ -45,23 +45,16 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result {
 
     let client = FlightServiceClient::connect(url).await?;
 
-    let ArrowFile {
-        schema, batches, ..
-    } = read_json_file(path)?;
+    let json_file = open_json_file(path)?;
 
-    let schema = Arc::new(schema);
+    let batches = json_file.read_batches()?;
+    let schema = Arc::new(json_file.schema);
 
     let mut descriptor = FlightDescriptor::default();
     descriptor.set_type(DescriptorType::Path);
     descriptor.path = vec![path.to_string()];
 
-    upload_data(
-        client.clone(),
-        schema.clone(),
-        descriptor.clone(),
-        batches.clone(),
-    )
-    .await?;
+    upload_data(client.clone(), schema, descriptor.clone(), 
batches.clone()).await?;
     verify_data(client, descriptor, &batches).await?;
 
     Ok(())
diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs
index 2d76be3495..553e69b0a1 100644
--- a/arrow-integration-testing/src/lib.rs
+++ b/arrow-integration-testing/src/lib.rs
@@ -19,14 +19,20 @@
 
 use serde_json::Value;
 
-use arrow::datatypes::Schema;
-use arrow::error::Result;
+use arrow::array::{Array, StructArray};
+use arrow::datatypes::{DataType, Field, Fields, Schema};
+use arrow::error::{ArrowError, Result};
+use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema};
 use arrow::record_batch::RecordBatch;
 use arrow::util::test_util::arrow_test_data;
 use arrow_integration_test::*;
 use std::collections::HashMap;
+use std::ffi::{c_int, CStr, CString};
 use std::fs::File;
 use std::io::BufReader;
+use std::iter::zip;
+use std::ptr;
+use std::sync::Arc;
 
 /// The expected username for the basic auth integration test.
 pub const AUTH_USERNAME: &str = "arrow";
@@ -40,11 +46,68 @@ pub struct ArrowFile {
     pub schema: Schema,
     // we can evolve this into a concrete Arrow type
     // this is temporarily not being read from
-    pub _dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
-    pub batches: Vec<RecordBatch>,
+    dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
+    arrow_json: Value,
 }
 
-pub fn read_json_file(json_name: &str) -> Result<ArrowFile> {
+impl ArrowFile {
+    pub fn read_batch(&self, batch_num: usize) -> Result<RecordBatch> {
+        let b = self.arrow_json["batches"].get(batch_num).unwrap();
+        let json_batch: ArrowJsonBatch = 
serde_json::from_value(b.clone()).unwrap();
+        record_batch_from_json(&self.schema, json_batch, 
Some(&self.dictionaries))
+    }
+
+    pub fn read_batches(&self) -> Result<Vec<RecordBatch>> {
+        self.arrow_json["batches"]
+            .as_array()
+            .unwrap()
+            .iter()
+            .map(|b| {
+                let json_batch: ArrowJsonBatch = 
serde_json::from_value(b.clone()).unwrap();
+                record_batch_from_json(&self.schema, json_batch, 
Some(&self.dictionaries))
+            })
+            .collect()
+    }
+}
+
+// Canonicalize the names of map fields in a schema
+pub fn canonicalize_schema(schema: &Schema) -> Schema {
+    let fields = schema
+        .fields()
+        .iter()
+        .map(|field| match field.data_type() {
+            DataType::Map(child_field, sorted) => match 
child_field.data_type() {
+                DataType::Struct(fields) if fields.len() == 2 => {
+                    let first_field = fields.get(0).unwrap();
+                    let key_field =
+                        Arc::new(Field::new("key", 
first_field.data_type().clone(), false));
+                    let second_field = fields.get(1).unwrap();
+                    let value_field = Arc::new(Field::new(
+                        "value",
+                        second_field.data_type().clone(),
+                        second_field.is_nullable(),
+                    ));
+
+                    let fields = Fields::from([key_field, value_field]);
+                    let struct_type = DataType::Struct(fields);
+                    let child_field = Field::new("entries", struct_type, 
false);
+
+                    Arc::new(Field::new(
+                        field.name().as_str(),
+                        DataType::Map(Arc::new(child_field), *sorted),
+                        field.is_nullable(),
+                    ))
+                }
+                _ => panic!("The child field of Map type should be Struct type 
with 2 fields."),
+            },
+            _ => field.clone(),
+        })
+        .collect::<Fields>();
+
+    Schema::new(fields).with_metadata(schema.metadata().clone())
+}
+
+pub fn open_json_file(json_name: &str) -> Result<ArrowFile> {
     let json_file = File::open(json_name)?;
     let reader = BufReader::new(json_file);
     let arrow_json: Value = serde_json::from_reader(reader).unwrap();
@@ -62,17 +125,10 @@ pub fn read_json_file(json_name: &str) -> Result<ArrowFile> {
             dictionaries.insert(json_dict.id, json_dict);
         }
     }
-
-    let mut batches = vec![];
-    for b in arrow_json["batches"].as_array().unwrap() {
-        let json_batch: ArrowJsonBatch = 
serde_json::from_value(b.clone()).unwrap();
-        let batch = record_batch_from_json(&schema, json_batch, 
Some(&dictionaries))?;
-        batches.push(batch);
-    }
     Ok(ArrowFile {
         schema,
-        _dictionaries: dictionaries,
-        batches,
+        dictionaries,
+        arrow_json,
     })
 }
 
@@ -100,3 +156,147 @@ pub fn read_gzip_json(version: &str, path: &str) -> ArrowJson {
     let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap();
     arrow_json
 }
+
+//
+// C Data Integration entrypoints
+//
+
+fn cdata_integration_export_schema_from_json(
+    c_json_name: *const i8,
+    out: *mut FFI_ArrowSchema,
+) -> Result<()> {
+    let json_name = unsafe { CStr::from_ptr(c_json_name) };
+    let f = open_json_file(json_name.to_str()?)?;
+    let c_schema = FFI_ArrowSchema::try_from(&f.schema)?;
+    // Move exported schema into output struct
+    unsafe { ptr::write(out, c_schema) };
+    Ok(())
+}
+
+fn cdata_integration_export_batch_from_json(
+    c_json_name: *const i8,
+    batch_num: c_int,
+    out: *mut FFI_ArrowArray,
+) -> Result<()> {
+    let json_name = unsafe { CStr::from_ptr(c_json_name) };
+    let b = 
open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?;
+    let a = StructArray::from(b).into_data();
+    let c_array = FFI_ArrowArray::new(&a);
+    // Move exported array into output struct
+    unsafe { ptr::write(out, c_array) };
+    Ok(())
+}
+
+fn cdata_integration_import_schema_and_compare_to_json(
+    c_json_name: *const i8,
+    c_schema: *mut FFI_ArrowSchema,
+) -> Result<()> {
+    let json_name = unsafe { CStr::from_ptr(c_json_name) };
+    let json_schema = open_json_file(json_name.to_str()?)?.schema;
+
+    // The source ArrowSchema will be released when this is dropped
+    let imported_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema) };
+    let imported_schema = Schema::try_from(&imported_schema)?;
+
+    // compare schemas
+    if canonicalize_schema(&json_schema) != 
canonicalize_schema(&imported_schema) {
+        return Err(ArrowError::ComputeError(format!(
+            "Schemas do not match.\n- JSON: {:?}\n- Imported: {:?}",
+            json_schema, imported_schema
+        )));
+    }
+    Ok(())
+}
+
+fn compare_batches(a: &RecordBatch, b: &RecordBatch) -> Result<()> {
+    if a.num_columns() != b.num_columns() {
+        return Err(ArrowError::InvalidArgumentError(
+            "batches do not have the same number of columns".to_string(),
+        ));
+    }
+    for (a_column, b_column) in zip(a.columns(), b.columns()) {
+        if a_column != b_column {
+            return Err(ArrowError::InvalidArgumentError(
+                "batch columns are not the same".to_string(),
+            ));
+        }
+    }
+    Ok(())
+}
+
+fn cdata_integration_import_batch_and_compare_to_json(
+    c_json_name: *const i8,
+    batch_num: c_int,
+    c_array: *mut FFI_ArrowArray,
+) -> Result<()> {
+    let json_name = unsafe { CStr::from_ptr(c_json_name) };
+    let json_batch =
+        
open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?;
+    let schema = json_batch.schema();
+
+    let data_type_for_import = DataType::Struct(schema.fields.clone());
+    let imported_array = unsafe { FFI_ArrowArray::from_raw(c_array) };
+    let imported_array = unsafe { from_ffi_and_data_type(imported_array, 
data_type_for_import) }?;
+    imported_array.validate_full()?;
+    let imported_batch = RecordBatch::from(StructArray::from(imported_array));
+
+    compare_batches(&json_batch, &imported_batch)
+}
+
+// If Result is an error, then export a const char* to its string display, 
otherwise NULL
+fn result_to_c_error<T, E: std::fmt::Display>(result: &std::result::Result<T, 
E>) -> *mut i8 {
+    match result {
+        Ok(_) => ptr::null_mut(),
+        Err(e) => CString::new(format!("{}", e)).unwrap().into_raw(),
+    }
+}
+
+/// Release a const char* exported by result_to_c_error()
+///
+/// # Safety
+///
+/// The pointer is assumed to have been obtained using CString::into_raw.
+#[no_mangle]
+pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut i8) {
+    if !c_error.is_null() {
+        drop(unsafe { CString::from_raw(c_error) });
+    }
+}
+
+#[no_mangle]
+pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json(
+    c_json_name: *const i8,
+    out: *mut FFI_ArrowSchema,
+) -> *mut i8 {
+    let r = cdata_integration_export_schema_from_json(c_json_name, out);
+    result_to_c_error(&r)
+}
+
+#[no_mangle]
+pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json(
+    c_json_name: *const i8,
+    c_schema: *mut FFI_ArrowSchema,
+) -> *mut i8 {
+    let r = cdata_integration_import_schema_and_compare_to_json(c_json_name, 
c_schema);
+    result_to_c_error(&r)
+}
+
+#[no_mangle]
+pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json(
+    c_json_name: *const i8,
+    batch_num: c_int,
+    out: *mut FFI_ArrowArray,
+) -> *mut i8 {
+    let r = cdata_integration_export_batch_from_json(c_json_name, batch_num, 
out);
+    result_to_c_error(&r)
+}
+
+#[no_mangle]
+pub extern "C" fn arrow_rs_cdata_integration_import_batch_and_compare_to_json(
+    c_json_name: *const i8,
+    batch_num: c_int,
+    c_array: *mut FFI_ArrowArray,
+) -> *mut i8 {
+    let r = cdata_integration_import_batch_and_compare_to_json(c_json_name, 
batch_num, c_array);
+    result_to_c_error(&r)
+}
diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs
index 8ea533db89..b7bf8d6e12 100644
--- a/arrow-schema/src/error.rs
+++ b/arrow-schema/src/error.rs
@@ -58,6 +58,12 @@ impl From<std::io::Error> for ArrowError {
     }
 }
 
+impl From<std::str::Utf8Error> for ArrowError {
+    fn from(error: std::str::Utf8Error) -> Self {
+        ArrowError::ParseError(error.to_string())
+    }
+}
+
 impl From<std::string::FromUtf8Error> for ArrowError {
     fn from(error: std::string::FromUtf8Error) -> Self {
         ArrowError::ParseError(error.to_string())
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index b4d10b814a..8a18c77ea2 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -34,7 +34,9 @@
 //! assert_eq!(schema, back);
 //! ```
 
-use crate::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit, 
UnionFields, UnionMode};
+use crate::{
+    ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, 
UnionFields, UnionMode,
+};
 use std::sync::Arc;
 use std::{
     collections::HashMap,
@@ -402,6 +404,9 @@ impl TryFrom<&FFI_ArrowSchema> for DataType {
             "tDm" => DataType::Duration(TimeUnit::Millisecond),
             "tDu" => DataType::Duration(TimeUnit::Microsecond),
             "tDn" => DataType::Duration(TimeUnit::Nanosecond),
+            "tiM" => DataType::Interval(IntervalUnit::YearMonth),
+            "tiD" => DataType::Interval(IntervalUnit::DayTime),
+            "tin" => DataType::Interval(IntervalUnit::MonthDayNano),
             "+l" => {
                 let c_child = c_schema.child(0);
                 DataType::List(Arc::new(Field::try_from(c_child)?))
@@ -669,6 +674,9 @@ fn get_format_string(dtype: &DataType) -> Result<String, ArrowError> {
         DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()),
         DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()),
         DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()),
+        DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".to_string()),
+        DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".to_string()),
+        DataType::Interval(IntervalUnit::MonthDayNano) => 
Ok("tin".to_string()),
         DataType::List(_) => Ok("+l".to_string()),
         DataType::LargeList(_) => Ok("+L".to_string()),
         DataType::Struct(_) => Ok("+s".to_string()),
diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs
index e05c256d01..d4d95a6e17 100644
--- a/arrow/src/array/ffi.rs
+++ b/arrow/src/array/ffi.rs
@@ -70,7 +70,7 @@ mod tests {
         let schema = FFI_ArrowSchema::try_from(expected.data_type())?;
 
         // simulate an external consumer by being the consumer
-        let result = &from_ffi(array, &schema)?;
+        let result = &unsafe { from_ffi(array, &schema) }?;
 
         assert_eq!(result, expected);
         Ok(())
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index c13d4c6e5d..31388bf993 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -43,7 +43,7 @@
 //! let (out_array, out_schema) = to_ffi(&data)?;
 //!
 //! // import it
-//! let data = from_ffi(out_array, &out_schema)?;
+//! let data = unsafe { from_ffi(out_array, &out_schema) }?;
 //! let array = Int32Array::from(data);
 //!
 //! // perform some operation
@@ -80,7 +80,7 @@
 //!     let mut schema = FFI_ArrowSchema::empty();
 //!     let mut array = FFI_ArrowArray::empty();
 //!     foreign.export_to_c(addr_of_mut!(array), addr_of_mut!(schema));
-//!     Ok(make_array(from_ffi(array, &schema)?))
+//!     Ok(make_array(unsafe { from_ffi(array, &schema) }?))
 //! }
 //! ```
 
@@ -108,6 +108,7 @@ use std::{mem::size_of, ptr::NonNull, sync::Arc};
 
 pub use arrow_data::ffi::FFI_ArrowArray;
 pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags};
+
 use arrow_schema::UnionMode;
 
 use crate::array::{layout, ArrayData};
@@ -233,32 +234,53 @@ pub fn to_ffi(data: &ArrayData) -> Result<(FFI_ArrowArray, FFI_ArrowSchema)> {
 /// # Safety
 ///
 /// This struct assumes that the incoming data agrees with the C data 
interface.
-pub fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> 
Result<ArrayData> {
+pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> 
Result<ArrayData> {
+    let dt = DataType::try_from(schema)?;
     let array = Arc::new(array);
-    let tmp = ArrowArray {
+    let tmp = ImportedArrowArray {
         array: &array,
-        schema,
+        data_type: dt,
+        owner: &array,
+    };
+    tmp.consume()
+}
+
+/// Import [ArrayData] from the C Data Interface
+///
+/// # Safety
+///
+/// This struct assumes that the incoming data agrees with the C data 
interface.
+pub unsafe fn from_ffi_and_data_type(
+    array: FFI_ArrowArray,
+    data_type: DataType,
+) -> Result<ArrayData> {
+    let array = Arc::new(array);
+    let tmp = ImportedArrowArray {
+        array: &array,
+        data_type,
         owner: &array,
     };
     tmp.consume()
 }
 
 #[derive(Debug)]
-struct ArrowArray<'a> {
+struct ImportedArrowArray<'a> {
     array: &'a FFI_ArrowArray,
-    schema: &'a FFI_ArrowSchema,
+    data_type: DataType,
     owner: &'a Arc<FFI_ArrowArray>,
 }
 
-impl<'a> ArrowArray<'a> {
+impl<'a> ImportedArrowArray<'a> {
     fn consume(self) -> Result<ArrayData> {
-        let dt = DataType::try_from(self.schema)?;
         let len = self.array.len();
         let offset = self.array.offset();
-        let null_count = self.array.null_count();
+        let null_count = match &self.data_type {
+            DataType::Null => 0,
+            _ => self.array.null_count(),
+        };
 
-        let data_layout = layout(&dt);
-        let buffers = self.buffers(data_layout.can_contain_null_mask, &dt)?;
+        let data_layout = layout(&self.data_type);
+        let buffers = self.buffers(data_layout.can_contain_null_mask)?;
 
         let null_bit_buffer = if data_layout.can_contain_null_mask {
             self.null_bit_buffer()
@@ -266,14 +288,9 @@ impl<'a> ArrowArray<'a> {
             None
         };
 
-        let mut child_data = (0..self.array.num_children())
-            .map(|i| {
-                let child = self.child(i);
-                child.consume()
-            })
-            .collect::<Result<Vec<_>>>()?;
+        let mut child_data = self.consume_children()?;
 
-        if let Some(d) = self.dictionary() {
+        if let Some(d) = self.dictionary()? {
             // For dictionary type there should only be a single child, so we 
don't need to worry if
             // there are other children added above.
             assert!(child_data.is_empty());
@@ -283,7 +300,7 @@ impl<'a> ArrowArray<'a> {
         // Should FFI be checking validity?
         Ok(unsafe {
             ArrayData::new_unchecked(
-                dt,
+                self.data_type,
                 len,
                 Some(null_count),
                 null_bit_buffer,
@@ -294,14 +311,49 @@ impl<'a> ArrowArray<'a> {
         })
     }
 
+    fn consume_children(&self) -> Result<Vec<ArrayData>> {
+        match &self.data_type {
+            DataType::List(field)
+            | DataType::FixedSizeList(field, _)
+            | DataType::LargeList(field)
+            | DataType::Map(field, _) => Ok([self.consume_child(0, 
field.data_type())?].to_vec()),
+            DataType::Struct(fields) => {
+                assert!(fields.len() == self.array.num_children());
+                fields
+                    .iter()
+                    .enumerate()
+                    .map(|(i, field)| self.consume_child(i, field.data_type()))
+                    .collect::<Result<Vec<_>>>()
+            }
+            DataType::Union(union_fields, _) => {
+                assert!(union_fields.len() == self.array.num_children());
+                union_fields
+                    .iter()
+                    .enumerate()
+                    .map(|(i, (_, field))| self.consume_child(i, 
field.data_type()))
+                    .collect::<Result<Vec<_>>>()
+            }
+            _ => Ok(Vec::new()),
+        }
+    }
+
+    fn consume_child(&self, index: usize, child_type: &DataType) -> 
Result<ArrayData> {
+        ImportedArrowArray {
+            array: self.array.child(index),
+            data_type: child_type.clone(),
+            owner: self.owner,
+        }
+        .consume()
+    }
+
     /// returns all buffers, as organized by Rust (i.e. null buffer is skipped 
if it's present
     /// in the spec of the type)
-    fn buffers(&self, can_contain_null_mask: bool, dt: &DataType) -> 
Result<Vec<Buffer>> {
+    fn buffers(&self, can_contain_null_mask: bool) -> Result<Vec<Buffer>> {
         // + 1: skip null buffer
         let buffer_begin = can_contain_null_mask as usize;
         (buffer_begin..self.array.num_buffers())
             .map(|index| {
-                let len = self.buffer_len(index, dt)?;
+                let len = self.buffer_len(index, &self.data_type)?;
 
                 match unsafe { create_buffer(self.owner.clone(), self.array, 
index, len) } {
                     Some(buf) => Ok(buf),
@@ -388,25 +440,20 @@ impl<'a> ArrowArray<'a> {
         unsafe { create_buffer(self.owner.clone(), self.array, 0, buffer_len) }
     }
 
-    fn child(&self, index: usize) -> ArrowArray {
-        ArrowArray {
-            array: self.array.child(index),
-            schema: self.schema.child(index),
-            owner: self.owner,
-        }
-    }
-
-    fn dictionary(&self) -> Option<ArrowArray> {
-        match (self.array.dictionary(), self.schema.dictionary()) {
-            (Some(array), Some(schema)) => Some(ArrowArray {
+    fn dictionary(&self) -> Result<Option<ImportedArrowArray>> {
+        match (self.array.dictionary(), &self.data_type) {
+            (Some(array), DataType::Dictionary(_, value_type)) => 
Ok(Some(ImportedArrowArray {
                 array,
-                schema,
+                data_type: value_type.as_ref().clone(),
                 owner: self.owner,
-            }),
-            (None, None) => None,
-            _ => panic!(
-                "Dictionary should both be set or not set in FFI_ArrowArray 
and FFI_ArrowSchema"
-            ),
+            })),
+            (Some(_), _) => Err(ArrowError::CDataInterface(
+                "Got dictionary in FFI_ArrowArray for non-dictionary data 
type".to_string(),
+            )),
+            (None, DataType::Dictionary(_, _)) => 
Err(ArrowError::CDataInterface(
+                "Missing dictionary in FFI_ArrowArray for dictionary data 
type".to_string(),
+            )),
+            (_, _) => Ok(None),
         }
     }
 }
@@ -443,7 +490,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.into_data()).unwrap();
 
         // (simulate consumer) import it
-        let array = Int32Array::from(from_ffi(array, &schema).unwrap());
+        let array = Int32Array::from(unsafe { from_ffi(array, &schema) 
}.unwrap());
         let array = kernels::numeric::add(&array, &array).unwrap();
 
         // verify
@@ -487,7 +534,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -517,7 +564,7 @@ mod tests {
         let (array, schema) = to_ffi(&original_array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -539,7 +586,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -608,7 +655,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // downcast
@@ -648,7 +695,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -693,7 +740,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -719,7 +766,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -748,7 +795,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -784,7 +831,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -845,7 +892,7 @@ mod tests {
         let (array, schema) = to_ffi(&list_data)?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -890,7 +937,7 @@ mod tests {
         let (array, schema) = to_ffi(&dict_array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -928,7 +975,7 @@ mod tests {
         }
 
         // (simulate consumer) import it
-        let data = from_ffi(out_array, &out_schema)?;
+        let data = unsafe { from_ffi(out_array, &out_schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -949,7 +996,7 @@ mod tests {
         let (array, schema) = to_ffi(&array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -986,7 +1033,7 @@ mod tests {
         let (array, schema) = to_ffi(&map_array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -1009,7 +1056,7 @@ mod tests {
         let (array, schema) = to_ffi(&struct_array.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         // perform some operation
@@ -1033,7 +1080,7 @@ mod tests {
         let (array, schema) = to_ffi(&union.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = make_array(data);
 
         let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
@@ -1094,7 +1141,7 @@ mod tests {
         let (array, schema) = to_ffi(&union.to_data())?;
 
         // (simulate consumer) import it
-        let data = from_ffi(array, &schema)?;
+        let data = unsafe { from_ffi(array, &schema) }?;
         let array = UnionArray::from(data);
 
         let expected_type_ids = vec![0_i8, 0, 1, 0];
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index 123669aa61..bbec71e883 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -357,9 +357,11 @@ impl Iterator for ArrowArrayStreamReader {
             }
 
             let schema_ref = self.schema();
+            // NOTE: this parses the FFI_ArrowSchema again on each iterator call;
+            // should probably use from_ffi_and_data_type() instead.
             let schema = FFI_ArrowSchema::try_from(schema_ref.as_ref()).ok()?;
 
-            let data = from_ffi(array, &schema).ok()?;
+            let data = unsafe { from_ffi(array, &schema) }.ok()?;
 
             let record_batch = RecordBatch::from(StructArray::from(data));
 
@@ -464,7 +466,7 @@ mod tests {
                 break;
             }
 
-            let array = from_ffi(ffi_array, &ffi_schema).unwrap();
+            let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
 
             let record_batch = RecordBatch::from(StructArray::from(array));
             produced_batches.push(record_batch);
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 2ac550ad04..8302f8741b 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -267,7 +267,7 @@ impl FromPyArrow for ArrayData {
 
             let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
             let array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer() as _) };
-            return ffi::from_ffi(array, schema_ptr).map_err(to_py_err);
+            return unsafe { ffi::from_ffi(array, schema_ptr) }.map_err(to_py_err);
         }
 
         validate_class("Array", value)?;
@@ -287,7 +287,7 @@ impl FromPyArrow for ArrayData {
             ),
         )?;
 
-        ffi::from_ffi(array, &schema).map_err(to_py_err)
+        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
     }
 }
 
@@ -348,7 +348,7 @@ impl FromPyArrow for RecordBatch {
 
             let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
             let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer() as _) };
-            let array_data = ffi::from_ffi(ffi_array, schema_ptr).map_err(to_py_err)?;
+            let array_data = unsafe { ffi::from_ffi(ffi_array, schema_ptr) }.map_err(to_py_err)?;
             if !matches!(array_data.data_type(), DataType::Struct(_)) {
                 return Err(PyTypeError::new_err(
                     "Expected Struct type from __arrow_c_array.",


Reply via email to