This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new dff6402457 Let `ArrowArrayStreamReader` handle schema with attached
metadata + do schema checking (#8944)
dff6402457 is described below
commit dff6402457910389af785d768e0f9e274d9526a7
Author: Jonas Dedden <[email protected]>
AuthorDate: Tue Dec 9 23:37:54 2025 +0100
Let `ArrowArrayStreamReader` handle schema with attached metadata + do
schema checking (#8944)
# Which issue does this PR close? / Rationale for this change
Solves an issue discovered during
https://github.com/apache/arrow-rs/pull/8790, namely that
`ArrowArrayStreamReader` does not correctly expose schema-level metadata
and does not check whether the `StructArray` constructed from the FFI
stream actually corresponds to the expected schema.
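To make the failure mode concrete, here is a minimal standalone sketch (not part of this patch; the schema, metadata key/value, and data are invented for illustration) showing that converting through `RecordBatch::from(StructArray)` drops schema-level metadata, while rebuilding with `RecordBatch::try_new` and an explicit `SchemaRef` keeps it and validates the columns:
```rust
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn main() -> Result<(), ArrowError> {
    // Schema carrying top-level metadata (illustrative key/value pair).
    let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
    let schema = Arc::new(Schema::new_with_metadata(
        vec![Field::new("a", DataType::Int32, true)],
        metadata,
    ));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    // RecordBatch -> StructArray -> RecordBatch loses the schema-level
    // metadata, because the schema is rebuilt from the fields alone.
    let struct_array = StructArray::from(batch);
    let lossy = RecordBatch::from(struct_array.clone());
    assert!(lossy.schema().metadata().is_empty());

    // Rebuilding with `try_new` and the original `SchemaRef` keeps the
    // metadata and also validates the columns against that schema.
    let rebuilt = RecordBatch::try_new(schema.clone(), struct_array.into_parts().1)?;
    assert_eq!(rebuilt.schema().metadata(), schema.metadata());
    Ok(())
}
```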
# What changes are included in this PR?
- Change how the `RecordBatch` is constructed inside `ArrowArrayStreamReader`
  so that it retains the schema metadata and schema validity is checked.
- Augment FFI tests with schema- and column-level metadata.
# Are these changes tested?
Yes, both `_test_round_trip_export` and `_test_round_trip_import` now
test for metadata at both the schema and the column level.
# Are there any user-facing changes?
Yes, `ArrowArrayStreamReader` is now able to export `RecordBatch`es with
schema-level metadata, and the interface is more robust since it now
actually validates the data against the expected schema.
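As a rough illustration of the user-facing effect (a standalone sketch, not taken from this patch; the schema, metadata, and data are made up), a round trip through the C stream interface now preserves schema-level metadata on both the imported reader and the batches it yields:
```rust
use std::collections::HashMap;
use std::sync::Arc;

use arrow_array::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
use arrow_array::{ArrayRef, Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, DataType, Field, Schema};

fn main() -> Result<(), ArrowError> {
    // Schema with top-level metadata (illustrative key/value pair).
    let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
    let schema = Arc::new(Schema::new_with_metadata(
        vec![Field::new("a", DataType::Int32, true)],
        metadata,
    ));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    // Export a RecordBatchReader over the C stream interface, then import it.
    let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
    let stream = FFI_ArrowArrayStream::new(Box::new(reader));
    let imported = ArrowArrayStreamReader::try_new(stream)?;

    // The imported reader and every batch it yields carry the metadata.
    assert_eq!(imported.schema().metadata(), schema.metadata());
    for batch in imported {
        assert_eq!(batch?.schema().metadata(), schema.metadata());
    }
    Ok(())
}
```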
---
arrow-array/src/ffi_stream.rs | 45 ++++++++++++++++------
.../tests/test_sql.py | 28 ++++++++------
2 files changed, 49 insertions(+), 24 deletions(-)
diff --git a/arrow-array/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs
index 27c020e5c0..c469436829 100644
--- a/arrow-array/src/ffi_stream.rs
+++ b/arrow-array/src/ffi_stream.rs
@@ -364,7 +364,9 @@ impl Iterator for ArrowArrayStreamReader {
let result = unsafe {
from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
};
- Some(result.map(|data| RecordBatch::from(StructArray::from(data))))
+ Some(result.and_then(|data| {
+ RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
+ }))
} else {
let last_error = self.get_stream_last_error();
let err = ArrowError::CDataInterface(last_error.unwrap());
@@ -382,6 +384,7 @@ impl RecordBatchReader for ArrowArrayStreamReader {
#[cfg(test)]
mod tests {
use super::*;
+ use std::collections::HashMap;
use arrow_schema::Field;
@@ -417,11 +420,18 @@ mod tests {
}
fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
- let schema = Arc::new(Schema::new(vec![
- Field::new("a", arrays[0].data_type().clone(), true),
- Field::new("b", arrays[1].data_type().clone(), true),
- Field::new("c", arrays[2].data_type().clone(), true),
- ]));
+ let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+ let schema = Arc::new(Schema::new_with_metadata(
+ vec![
+ Field::new("a", arrays[0].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ Field::new("b", arrays[1].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ Field::new("c", arrays[2].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ ],
+ metadata,
+ ));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
@@ -452,7 +462,11 @@ mod tests {
let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
- let record_batch = RecordBatch::from(StructArray::from(array));
+ let record_batch = RecordBatch::try_new(
+ SchemaRef::from(exported_schema.clone()),
+ StructArray::from(array).into_parts().1,
+ )
+ .unwrap();
produced_batches.push(record_batch);
}
@@ -462,11 +476,18 @@ mod tests {
}
fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
- let schema = Arc::new(Schema::new(vec![
- Field::new("a", arrays[0].data_type().clone(), true),
- Field::new("b", arrays[1].data_type().clone(), true),
- Field::new("c", arrays[2].data_type().clone(), true),
- ]));
+ let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
+ let schema = Arc::new(Schema::new_with_metadata(
+ vec![
+ Field::new("a", arrays[0].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ Field::new("b", arrays[1].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ Field::new("c", arrays[2].data_type().clone(), true)
+ .with_metadata(metadata.clone()),
+ ],
+ metadata,
+ ));
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 79220fb6a6..b9b04ddee5 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -527,7 +527,7 @@ def test_empty_recordbatch_with_row_count():
"""
# Create an empty schema with no fields
- batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}).select([])
+ batch = pa.RecordBatch.from_pydict({"a": [1, 2, 3, 4]}, metadata={b'key1': b'value1'}).select([])
num_rows = 4
assert batch.num_rows == num_rows
assert batch.num_columns == 0
@@ -545,7 +545,7 @@ def test_record_batch_reader():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
@@ -571,7 +571,7 @@ def test_record_batch_reader_pycapsule():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
@@ -621,7 +621,7 @@ def test_record_batch_pycapsule():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batch = pa.record_batch([[[1], [2, 42]]], schema)
wrapped = StreamWrapper(batch)
b = rust.round_trip_record_batch_reader(wrapped)
@@ -640,7 +640,7 @@ def test_table_pycapsule():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
@@ -650,8 +650,9 @@ def test_table_pycapsule():
b = rust.round_trip_record_batch_reader(wrapped)
new_table = b.read_all()
- assert table.schema == new_table.schema
assert table == new_table
+ assert table.schema == new_table.schema
+ assert table.schema.metadata == new_table.schema.metadata
assert len(table.to_batches()) == len(new_table.to_batches())
@@ -659,12 +660,13 @@ def test_table_empty():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
table = pa.Table.from_batches([], schema=schema)
new_table = rust.build_table([], schema=schema)
- assert table.schema == new_table.schema
assert table == new_table
+ assert table.schema == new_table.schema
+ assert table.schema.metadata == new_table.schema.metadata
assert len(table.to_batches()) == len(new_table.to_batches())
@@ -672,7 +674,7 @@ def test_table_roundtrip():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))])
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
@@ -680,8 +682,9 @@ def test_table_roundtrip():
table = pa.Table.from_batches(batches, schema=schema)
new_table = rust.round_trip_table(table)
- assert table.schema == new_table.schema
assert table == new_table
+ assert table.schema == new_table.schema
+ assert table.schema.metadata == new_table.schema.metadata
assert len(table.to_batches()) == len(new_table.to_batches())
@@ -689,7 +692,7 @@ def test_table_from_batches():
"""
Python -> Rust -> Python
"""
- schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ schema = pa.schema([pa.field(name='ints', type=pa.list_(pa.int32()), metadata={b'key1': b'value1'})], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
@@ -697,8 +700,9 @@ def test_table_from_batches():
table = pa.Table.from_batches(batches)
new_table = rust.build_table(batches, schema)
- assert table.schema == new_table.schema
assert table == new_table
+ assert table.schema == new_table.schema
+ assert table.schema.metadata == new_table.schema.metadata
assert len(table.to_batches()) == len(new_table.to_batches())