pitrou commented on code in PR #37797:
URL: https://github.com/apache/arrow/pull/37797#discussion_r1345676663
##########
python/pyarrow/tests/test_cffi.py:
##########
@@ -411,3 +415,120 @@ def test_imported_batch_reader_error():
match="Expected to be able to read 16 bytes "
"for message body, got 8"):
reader_new.read_all()
+
+
[email protected]('obj', [pa.int32(), pa.field('foo', pa.int32()),
+ pa.schema({'foo': pa.int32()})],
+ ids=['type', 'field', 'schema'])
+def test_roundtrip_schema_capsule(obj):
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ capsule = obj.__arrow_c_schema__()
+ assert PyCapsule_IsValid(capsule, b"arrow_schema") == 1
+ obj_out = type(obj)._import_from_c_capsule(capsule)
+ assert obj_out == obj
+
+ assert pa.total_allocated_bytes() == old_allocated
+
+ capsule = obj.__arrow_c_schema__()
+
+ assert pa.total_allocated_bytes() > old_allocated
+ del capsule
+ assert pa.total_allocated_bytes() == old_allocated
+
+
[email protected]('arr,schema_accessor,bad_type,good_type', [
+ (pa.array(['a', 'b', 'c']), lambda x: x.type, pa.int32(), pa.string()),
+ (
+ pa.record_batch([pa.array(['a', 'b', 'c'])], names=['x']),
+ lambda x: x.schema,
+ pa.schema({'x': pa.int32()}),
+ pa.schema({'x': pa.string()})
+ ),
+], ids=['array', 'record_batch'])
+def test_roundtrip_array_capsule(arr, schema_accessor, bad_type, good_type):
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ import_array = type(arr)._import_from_c_capsule
+
+ schema_capsule, capsule = arr.__arrow_c_array__()
+ assert PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
+ assert PyCapsule_IsValid(capsule, b"arrow_array") == 1
+ arr_out = import_array(schema_capsule, capsule)
+ assert arr_out.equals(arr)
+
+ assert pa.total_allocated_bytes() > old_allocated
+ del arr_out
+
+ assert pa.total_allocated_bytes() == old_allocated
+
+ capsule = arr.__arrow_c_array__()
+
+ assert pa.total_allocated_bytes() > old_allocated
+ del capsule
+ assert pa.total_allocated_bytes() == old_allocated
+
+ with pytest.raises(ValueError,
+ match=r"Could not cast.* string to requested .* int32"):
+ arr.__arrow_c_array__(bad_type.__arrow_c_schema__())
+
+ schema_capsule, array_capsule = arr.__arrow_c_array__(
+ good_type.__arrow_c_schema__())
+ arr_out = import_array(schema_capsule, array_capsule)
+ assert schema_accessor(arr_out) == good_type
+
+
+# TODO: implement requested_schema for stream
[email protected]('constructor', [
+ pa.RecordBatchReader.from_batches,
+ # Use a lambda because we need to re-order the parameters
+ lambda schema, batches: pa.Table.from_batches(batches, schema),
+], ids=['recordbatchreader', 'table'])
+def test_roundtrip_reader_capsule(constructor):
+ batches = make_batches()
+ schema = batches[0].schema
+
+ gc.collect() # Make sure no Arrow data dangles in a ref cycle
+ old_allocated = pa.total_allocated_bytes()
+
+ obj = constructor(schema, batches)
+
+ capsule = obj.__arrow_c_stream__()
+ assert PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
+ imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
+ assert imported_reader.schema == schema
+ for batch, expected in zip(imported_reader, batches):
Review Comment:
This won't check that the number of batches is the same on either side,
since `zip` stops when the shortest iterator is exhausted. Consider also
asserting that `imported_reader` yields exactly `len(batches)` batches
(e.g. collect them into a list and compare counts, or pair them with
`itertools.zip_longest` so a length mismatch fails the comparison).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]