bkietz commented on code in PR #40807:
URL: https://github.com/apache/arrow/pull/40807#discussion_r1586297491
##########
python/pyarrow/tests/test_cffi.py:
##########
@@ -45,7 +45,7 @@
ValueError, match="Cannot import released ArrowArray")
assert_stream_released = pytest.raises(
- ValueError, match="Cannot import released ArrowArrayStream")
+ ValueError, match="Cannot import released Arrow Stream")
Review Comment:
nit: I suppose this is meant to be generic since it might be an
ArrowDeviceArrayStream instead, but it's inconsistent with the asserters above
##########
cpp/src/arrow/record_batch.h:
##########
@@ -260,6 +270,18 @@ class ARROW_EXPORT RecordBatch {
/// \return Status
virtual Status ValidateFull() const;
+ /// \brief EXPERIMENTAL: Return a top-level sync event object for this
record batch
+ ///
+ /// If all of the data for this record batch is in CPU memory, then this
+ /// will return null. If the data for this batch is
+ /// on a device, then if synchronization is needed before accessing the
+ /// data the returned sync event will allow for it.
+ ///
+ /// \return null or a Device::SyncEvent
+ virtual const std::shared_ptr<Device::SyncEvent>& GetSyncEvent() const;
Review Comment:
Is it desirable to give this a default implementation at all? It seems it could
be pure virtual instead
##########
cpp/src/arrow/c/bridge_test.cc:
##########
@@ -4746,4 +4750,531 @@ TEST_F(TestArrayStreamRoundtrip,
ChunkedArrayRoundtripEmpty) {
});
}
+////////////////////////////////////////////////////////////////////////////
+// Array device stream export tests
+
+class TestArrayDeviceStreamExport : public BaseArrayStreamTest {
+ public:
+ void AssertStreamSchema(struct ArrowDeviceArrayStream* c_stream,
+ const Schema& expected) {
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream->get_schema(c_stream, &c_schema));
+
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
+ }
+
+ void AssertStreamEnd(struct ArrowDeviceArrayStream* c_stream) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_TRUE(ArrowDeviceArrayIsReleased(&c_array));
+ }
+
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream,
+ const RecordBatch& expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto batch,
+ ImportDeviceRecordBatch(&c_array, expected.schema(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream, const Array&
expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto array,
+ ImportDeviceArray(&c_array, expected.type(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertArraysEqual(expected, *array);
+ }
+
+ static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+ const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+ arrow::BufferVector buffers;
+ for (const auto& buf : data.buffers) {
+ if (buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+ buffers.push_back(dest);
+ } else {
+ buffers.push_back(nullptr);
+ }
+ }
+
+ arrow::ArrayDataVector children;
+ for (const auto& child : data.child_data) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+ children.push_back(dest);
+ }
+
+ return ArrayData::Make(data.type, data.length, buffers, children,
data.null_count,
+ data.offset);
+ }
+
+ static Result<std::shared_ptr<Array>> ToDevice(const
std::shared_ptr<MemoryManager>& mm,
+ const ArrayData& data) {
+ ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+ return MakeArray(result);
+ }
+};
+
+TEST_F(TestArrayDeviceStreamExport, Empty) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader,
+ RecordBatchReader::Make(batches, schema,
+
static_cast<DeviceAllocationType>(kMyDeviceType)));
+
+ struct ArrowDeviceArrayStream c_stream;
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamSchema(&c_stream, *schema);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayDeviceStreamExport, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ AssertStreamSchema(&c_stream, *schema);
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamNext(&c_stream, *batches[0]);
+ AssertStreamNext(&c_stream, *batches[1]);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayDeviceStreamExport, ArrayLifetime) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+ struct ArrowSchema c_schema;
+ struct ArrowDeviceArray c_array0, c_array1;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ {
+ DeviceArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+ AssertStreamEnd(&c_stream);
+ }
+
+ DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
+ }
+
+ ASSERT_EQ(kMyDeviceType, c_array0.device_type);
+ ASSERT_EQ(kMyDeviceType, c_array1.device_type);
+
+ ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+ ASSERT_OK_AND_ASSIGN(
+ auto batch,
+ ImportDeviceRecordBatch(&c_array1, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[1], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+ ASSERT_OK_AND_ASSIGN(
+ batch,
+ ImportDeviceRecordBatch(&c_array0, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[0], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+}
+
+TEST_F(TestArrayDeviceStreamExport, Errors) {
+ auto reader =
+ std::make_shared<FailingRecordBatchReader>(Status::Invalid("some example
error"));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
+ }
+
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(EINVAL, c_stream.get_next(&c_stream, &c_array));
+}
+
+TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExportEmpty) {
+ ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({}, int32()));
+
+ struct ArrowDeviceArrayStream c_stream;
+ struct ArrowSchema c_schema;
+
+ ASSERT_OK(ExportDeviceChunkedArray(
+ chunked_array, static_cast<DeviceAllocationType>(kMyDeviceType),
&c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ {
+ DeviceArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ AssertStreamEnd(&c_stream);
+ }
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
+ AssertTypeEqual(*chunked_array->type(), *got_type);
+ }
+}
+
+TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExport) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+
+ ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({arr1, arr2}));
+
+ struct ArrowDeviceArrayStream c_stream;
+ struct ArrowSchema c_schema;
+ struct ArrowDeviceArray c_array0, c_array1;
+
+ ASSERT_OK(ExportDeviceChunkedArray(chunked_array, device->device_type(),
&c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ {
+ DeviceArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+ AssertStreamEnd(&c_stream);
+ }
+
+ DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
+ AssertTypeEqual(*chunked_array->type(), *got_type);
+ }
+
+ ASSERT_EQ(kMyDeviceType, c_array0.device_type);
+ ASSERT_EQ(kMyDeviceType, c_array1.device_type);
+
+ ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+ ASSERT_OK_AND_ASSIGN(auto array,
+ ImportDeviceArray(&c_array0, chunked_array->type(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ ASSERT_EQ(device->device_type(), array->device_type());
+ AssertArraysEqual(*chunked_array->chunk(0), *array);
+ ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array1,
chunked_array->type(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ ASSERT_EQ(device->device_type(), array->device_type());
+ AssertArraysEqual(*chunked_array->chunk(1), *array);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Array device stream roundtrip tests
+
+class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest {
+ public:
+ static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+ const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+ arrow::BufferVector buffers;
+ for (const auto& buf : data.buffers) {
+ if (buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+ buffers.push_back(dest);
+ } else {
+ buffers.push_back(nullptr);
+ }
+ }
+
+ arrow::ArrayDataVector children;
+ for (const auto& child : data.child_data) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+ children.push_back(dest);
+ }
+
+ return ArrayData::Make(data.type, data.length, buffers, children,
data.null_count,
+ data.offset);
+ }
+
+ static Result<std::shared_ptr<Array>> ToDevice(const
std::shared_ptr<MemoryManager>& mm,
+ const ArrayData& data) {
+ ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+ return MakeArray(result);
+ }
+
+ void Roundtrip(std::shared_ptr<RecordBatchReader>* reader,
+ struct ArrowDeviceArrayStream* c_stream) {
+ ASSERT_OK(ExportDeviceRecordBatchReader(*reader, c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(c_stream));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto got_reader,
+ ImportDeviceRecordBatchReader(c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ *reader = std::move(got_reader);
+ }
+
+ void Roundtrip(
+ std::shared_ptr<RecordBatchReader> reader,
+ std::function<void(const std::shared_ptr<RecordBatchReader>&)>
check_func) {
+ ArrowDeviceArrayStream c_stream;
+
+ // NOTE: ReleaseCallback<> is not immediately usable with
ArrowDeviceArrayStream
+ // because get_next and get_schema need the original private_data.
+ std::weak_ptr<RecordBatchReader> weak_reader(reader);
+ ASSERT_EQ(weak_reader.use_count(), 1); // Expiration check will fail
otherwise
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(std::move(reader), &c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto new_reader,
+ ImportDeviceRecordBatchReader(
+ &c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ // stream was moved
+ ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_FALSE(weak_reader.expired());
+
+ check_func(new_reader);
+ }
+ // Stream was released when `new_reader` was destroyed
+ ASSERT_TRUE(weak_reader.expired());
+ }
+
+ void Roundtrip(std::shared_ptr<ChunkedArray> src,
+ std::function<void(const std::shared_ptr<ChunkedArray>&)>
check_func) {
+ ArrowDeviceArrayStream c_stream;
+
+ // One original copy to compare the result, one copy held by the stream
+ std::weak_ptr<ChunkedArray> weak_src(src);
+ int64_t initial_use_count = weak_src.use_count();
+
+ ASSERT_OK(ExportDeviceChunkedArray(
+ std::move(src), static_cast<DeviceAllocationType>(kMyDeviceType),
&c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ {
+ ASSERT_OK_AND_ASSIGN(
+ auto dst,
+ ImportDeviceChunkedArray(&c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ // Stream was moved, consumed, and released
+ ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+ // Stream was released by ImportDeviceChunkedArray but original copy
remains
+ ASSERT_EQ(weak_src.use_count(), initial_use_count - 1);
+
+ check_func(dst);
+ }
+ }
+
+ void AssertReaderNext(const std::shared_ptr<RecordBatchReader>& reader,
+ const RecordBatch& expected) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_NE(batch, nullptr);
+ ASSERT_EQ(static_cast<DeviceAllocationType>(kMyDeviceType),
batch->device_type());
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ void AssertReaderEnd(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_EQ(batch, nullptr);
+ }
+
+ void AssertReaderClosed(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_THAT(reader->Next(),
+ Raises(StatusCode::Invalid, ::testing::HasSubstr("already been
closed")));
+ }
+
+ void AssertReaderClose(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_OK(reader->Close());
+ AssertReaderClosed(reader);
+ }
+};
+
+TEST_F(TestArrayDeviceStreamRoundtrip, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertSchemaEqual(*orig_schema, *reader->schema(),
/*check_metadata=*/true);
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderNext(reader, *batches[1]);
+ AssertReaderEnd(reader);
+ AssertReaderEnd(reader);
+ AssertReaderClose(reader);
+ });
+}
+
+TEST_F(TestArrayDeviceStreamRoundtrip, CloseEarly) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderClose(reader);
+ });
+}
+
+TEST_F(TestArrayDeviceStreamRoundtrip, Errors) {
+ auto reader = std::make_shared<FailingRecordBatchReader>(
+ Status::Invalid("roundtrip error example"));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ auto status = reader->Next().status();
+ ASSERT_RAISES(Invalid, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("roundtrip error
example"));
+ });
+}
+
+TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) {
+ struct StreamState {
+ bool released = false;
+
+ static const char* GetLastError(struct ArrowDeviceArrayStream* stream) {
+ return "Expected error";
+ }
+
+ static int GetSchema(struct ArrowDeviceArrayStream* stream,
+ struct ArrowSchema* schema) {
+ return EIO;
+ }
+
+ static int GetNext(struct ArrowDeviceArrayStream* stream,
+ struct ArrowDeviceArray* array) {
+ return EINVAL;
+ }
+
+ static void Release(struct ArrowDeviceArrayStream* stream) {
+ reinterpret_cast<StreamState*>(stream->private_data)->released = true;
+ std::memset(stream, 0, sizeof(*stream));
+ }
+ } state;
+ struct ArrowDeviceArrayStream stream = {};
+ stream.get_last_error = &StreamState::GetLastError;
+ stream.get_schema = &StreamState::GetSchema;
+ stream.get_next = &StreamState::GetNext;
+ stream.release = &StreamState::Release;
+ stream.private_data = &state;
Review Comment:
Empty lambdas are convertible to function pointers:
```suggestion
struct ArrowDeviceArrayStream stream = {};
stream.get_last_error = [](struct ArrowDeviceArrayStream* stream) {
return "Expected error";
};
stream.get_schema = [](struct ArrowDeviceArrayStream* stream,
struct ArrowSchema* schema) {
return EIO;
};
stream.get_next = [](struct ArrowDeviceArrayStream* stream,
struct ArrowDeviceArray* array) {
return EINVAL;
};
stream.release = [](struct ArrowDeviceArrayStream* stream) {
*static_cast<bool*>(stream->private_data) = true;
std::memset(stream, 0, sizeof(*stream));
};
bool released = false;
stream.private_data = &released;
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2150,53 +2186,102 @@ class ExportedArrayStream {
int64_t next_batch_num() { return private_data()->batch_num_++; }
- struct ArrowArrayStream* stream_;
+ StreamType* stream_;
};
} // namespace
Status ExportRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
struct ArrowArrayStream* out) {
- return ExportedArrayStream<RecordBatchReader>::Make(std::move(reader), out);
+ memset(out, 0, sizeof(struct ArrowArrayStream));
+ return ExportedArrayStream<RecordBatchReader,
false>::Make(std::move(reader), out);
}
Status ExportChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
struct ArrowArrayStream* out) {
- return ExportedArrayStream<ChunkedArray>::Make(std::move(chunked_array),
out);
+ memset(out, 0, sizeof(struct ArrowArrayStream));
+ return ExportedArrayStream<ChunkedArray,
false>::Make(std::move(chunked_array), out);
+}
+
+Status ExportDeviceRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+ struct ArrowDeviceArrayStream* out) {
+ memset(out, 0, sizeof(struct ArrowDeviceArrayStream));
+ out->device_type = static_cast<ArrowDeviceType>(reader->device_type());
+ return ExportedArrayStream<RecordBatchReader, true>::Make(std::move(reader),
out);
+}
+
+Status ExportDeviceChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
+ DeviceAllocationType device_type,
+ struct ArrowDeviceArrayStream* out) {
+ memset(out, 0, sizeof(struct ArrowDeviceArrayStream));
+ out->device_type = static_cast<ArrowDeviceType>(device_type);
+ return ExportedArrayStream<ChunkedArray,
true>::Make(std::move(chunked_array), out);
}
//////////////////////////////////////////////////////////////////////////
// C stream import
namespace {
+template <bool IsDevice>
class ArrayStreamReader {
+ protected:
+ using StreamTraits =
+ std::conditional_t<IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>;
+ using StreamType = typename StreamTraits::CType;
+ using ArrayTraits = std::conditional_t<IsDevice,
internal::ArrayDeviceExportTraits,
+ internal::ArrayExportTraits>;
+ using ArrayType = typename ArrayTraits::CType;
+
public:
- explicit ArrayStreamReader(struct ArrowArrayStream* stream) {
- ArrowArrayStreamMove(stream, &stream_);
- DCHECK(!ArrowArrayStreamIsReleased(&stream_));
+ explicit ArrayStreamReader(StreamType* stream,
+ const DeviceMemoryMapper mapper =
DefaultDeviceMemoryMapper)
+ : mapper_{std::move(mapper)} {
+ StreamTraits::MoveFunc(stream, &stream_);
+ DCHECK(!StreamTraits::IsReleasedFunc(&stream_));
}
~ArrayStreamReader() { ReleaseStream(); }
void ReleaseStream() {
- if (!ArrowArrayStreamIsReleased(&stream_)) {
- ArrowArrayStreamRelease(&stream_);
- }
- DCHECK(ArrowArrayStreamIsReleased(&stream_));
+ // all our trait release funcs check IsReleased so we don't
+ // need to repeat it here
+ StreamTraits::ReleaseFunc(&stream_);
+ DCHECK(StreamTraits::IsReleasedFunc(&stream_));
}
protected:
- Status ReadNextArrayInternal(struct ArrowArray* array) {
- ArrowArrayMarkReleased(array);
+ Status ReadNextArrayInternal(ArrayType* array) {
+ ArrayTraits::MarkReleased(array);
Status status = StatusFromCError(stream_.get_next(&stream_, array));
- if (!status.ok() && !ArrowArrayIsReleased(array)) {
- ArrowArrayRelease(array);
+ if (!status.ok() && !ArrayTraits::IsReleasedFunc(array)) {
+ ArrayTraits::ReleaseFunc(array);
Review Comment:
The comment above suggests we don't need to check whether the array is
released either
```suggestion
if (!status.ok()) {
ArrayTraits::ReleaseFunc(array);
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2250,101 +2335,138 @@ class ArrayStreamReader {
return {code, last_error ? std::string(last_error) : ""};
}
+ DeviceAllocationType get_device_type() const {
+ if constexpr (std::is_same_v<ArrayType, struct ArrowDeviceArray>) {
Review Comment:
```suggestion
if constexpr (IsDevice) {
```
##########
cpp/src/arrow/record_batch.h:
##########
@@ -260,6 +270,18 @@ class ARROW_EXPORT RecordBatch {
/// \return Status
virtual Status ValidateFull() const;
+ /// \brief EXPERIMENTAL: Return a top-level sync event object for this
record batch
+ ///
+ /// If all of the data for this record batch is in CPU memory, then this
+ /// will return null. If the data for this batch is
+ /// on a device, then if synchronization is needed before accessing the
+ /// data the returned sync event will allow for it.
+ ///
+ /// \return null or a Device::SyncEvent
+ virtual const std::shared_ptr<Device::SyncEvent>& GetSyncEvent() const;
+
+ virtual DeviceAllocationType device_type() const;
Review Comment:
Same here, now that we always specify the device type in the constructor
##########
cpp/src/arrow/c/bridge_test.cc:
##########
@@ -4746,4 +4750,531 @@ TEST_F(TestArrayStreamRoundtrip,
ChunkedArrayRoundtripEmpty) {
});
}
+////////////////////////////////////////////////////////////////////////////
+// Array device stream export tests
+
+class TestArrayDeviceStreamExport : public BaseArrayStreamTest {
+ public:
+ void AssertStreamSchema(struct ArrowDeviceArrayStream* c_stream,
+ const Schema& expected) {
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream->get_schema(c_stream, &c_schema));
+
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
+ }
+
+ void AssertStreamEnd(struct ArrowDeviceArrayStream* c_stream) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_TRUE(ArrowDeviceArrayIsReleased(&c_array));
+ }
+
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream,
+ const RecordBatch& expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto batch,
+ ImportDeviceRecordBatch(&c_array, expected.schema(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream, const Array&
expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto array,
+ ImportDeviceArray(&c_array, expected.type(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertArraysEqual(expected, *array);
+ }
+
+ static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+ const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+ arrow::BufferVector buffers;
+ for (const auto& buf : data.buffers) {
+ if (buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+ buffers.push_back(dest);
+ } else {
+ buffers.push_back(nullptr);
+ }
+ }
+
+ arrow::ArrayDataVector children;
+ for (const auto& child : data.child_data) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+ children.push_back(dest);
+ }
+
+ return ArrayData::Make(data.type, data.length, buffers, children,
data.null_count,
+ data.offset);
+ }
+
+ static Result<std::shared_ptr<Array>> ToDevice(const
std::shared_ptr<MemoryManager>& mm,
+ const ArrayData& data) {
+ ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+ return MakeArray(result);
+ }
+};
+
+TEST_F(TestArrayDeviceStreamExport, Empty) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader,
+ RecordBatchReader::Make(batches, schema,
+
static_cast<DeviceAllocationType>(kMyDeviceType)));
+
+ struct ArrowDeviceArrayStream c_stream;
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamSchema(&c_stream, *schema);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayDeviceStreamExport, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ AssertStreamSchema(&c_stream, *schema);
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamNext(&c_stream, *batches[0]);
+ AssertStreamNext(&c_stream, *batches[1]);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream);
+}
+
+TEST_F(TestArrayDeviceStreamExport, ArrayLifetime) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+ struct ArrowSchema c_schema;
+ struct ArrowDeviceArray c_array0, c_array1;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ {
+ DeviceArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+ AssertStreamEnd(&c_stream);
+ }
+
+ DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
+ }
+
+ ASSERT_EQ(kMyDeviceType, c_array0.device_type);
+ ASSERT_EQ(kMyDeviceType, c_array1.device_type);
+
+ ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+ ASSERT_OK_AND_ASSIGN(
+ auto batch,
+ ImportDeviceRecordBatch(&c_array1, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[1], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+ ASSERT_OK_AND_ASSIGN(
+ batch,
+ ImportDeviceRecordBatch(&c_array0, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[0], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+}
+
+TEST_F(TestArrayDeviceStreamExport, Errors) {
+ auto reader =
+ std::make_shared<FailingRecordBatchReader>(Status::Invalid("some example
error"));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
+ }
+
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(EINVAL, c_stream.get_next(&c_stream, &c_array));
+}
+
TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExportEmpty) {
  // Exporting a zero-chunk ChunkedArray must yield a stream that reports
  // its device type and schema, then immediately signals end-of-stream.
  ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({}, int32()));

  struct ArrowDeviceArrayStream c_stream;
  struct ArrowSchema c_schema;

  ASSERT_OK(ExportDeviceChunkedArray(
      chunked_array, static_cast<DeviceAllocationType>(kMyDeviceType), &c_stream));
  DeviceArrayStreamExportGuard outer_guard(&c_stream);

  {
    // The inner guard releases the stream once it has been drained; the
    // outer guard then finds it already released and becomes a no-op.
    DeviceArrayStreamExportGuard inner_guard(&c_stream);
    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));

    ASSERT_EQ(kMyDeviceType, c_stream.device_type);
    ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
    AssertStreamEnd(&c_stream);
  }

  {
    SchemaExportGuard schema_guard(&c_schema);
    ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
    AssertTypeEqual(*chunked_array->type(), *got_type);
  }
}
+
TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExport) {
  // Export a two-chunk ChunkedArray staged on the fake device and verify the
  // resulting ArrowDeviceArrayStream carries the device type, the schema,
  // and exactly one ArrowDeviceArray per chunk.
  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
  auto mm = device->default_memory_manager();

  ASSERT_OK_AND_ASSIGN(auto arr1,
                       ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data()));
  ASSERT_EQ(device->device_type(), arr1->device_type());
  ASSERT_OK_AND_ASSIGN(auto arr2,
                       ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data()));
  ASSERT_EQ(device->device_type(), arr2->device_type());

  ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({arr1, arr2}));

  struct ArrowDeviceArrayStream c_stream;
  struct ArrowSchema c_schema;
  struct ArrowDeviceArray c_array0, c_array1;

  ASSERT_OK(ExportDeviceChunkedArray(chunked_array, device->device_type(), &c_stream));
  DeviceArrayStreamExportGuard guard(&c_stream);

  {
    // Inner guard releases the stream after it is fully drained; the outer
    // guard then sees an already-released stream and does nothing.
    DeviceArrayStreamExportGuard guard(&c_stream);
    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
    ASSERT_EQ(kMyDeviceType, c_stream.device_type);

    // Drain: schema first, then one array per chunk, then end-of-stream.
    ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
    ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
    ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
    AssertStreamEnd(&c_stream);
  }

  DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);

  {
    SchemaExportGuard schema_guard(&c_schema);
    ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
    AssertTypeEqual(*chunked_array->type(), *got_type);
  }

  ASSERT_EQ(kMyDeviceType, c_array0.device_type);
  ASSERT_EQ(kMyDeviceType, c_array1.device_type);

  // The exported arrays still hold their buffers alive — allocation tracked
  // by the fixture's pool has not dropped back to the original level.
  ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
  ASSERT_OK_AND_ASSIGN(auto array,
                       ImportDeviceArray(&c_array0, chunked_array->type(),
                                         TestDeviceArrayRoundtrip::DeviceMapper));
  ASSERT_EQ(device->device_type(), array->device_type());
  AssertArraysEqual(*chunked_array->chunk(0), *array);
  ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array1, chunked_array->type(),
                                                TestDeviceArrayRoundtrip::DeviceMapper));
  ASSERT_EQ(device->device_type(), array->device_type());
  AssertArraysEqual(*chunked_array->chunk(1), *array);
}
+
+////////////////////////////////////////////////////////////////////////////
+// Array device stream roundtrip tests
+
// Helpers for roundtripping device-aware data through the
// ArrowDeviceArrayStream C interface: export a RecordBatchReader or
// ChunkedArray, import it back, and check both the data and the
// ownership/lifetime behavior of the exported stream.
class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest {
 public:
  // Deep-copy `data` (buffers and child data, recursively) into memory
  // owned by `mm`, preserving type, length, null count and offset.
  static Result<std::shared_ptr<ArrayData>> ToDeviceData(
      const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
    arrow::BufferVector buffers;
    for (const auto& buf : data.buffers) {
      if (buf) {
        ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
        buffers.push_back(dest);
      } else {
        // Absent buffers (e.g. no validity bitmap) stay absent.
        buffers.push_back(nullptr);
      }
    }

    arrow::ArrayDataVector children;
    for (const auto& child : data.child_data) {
      ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
      children.push_back(dest);
    }

    return ArrayData::Make(data.type, data.length, buffers, children, data.null_count,
                           data.offset);
  }

  // Convenience wrapper: copy `data` to `mm`'s memory and wrap it as an Array.
  static Result<std::shared_ptr<Array>> ToDevice(const std::shared_ptr<MemoryManager>& mm,
                                                 const ArrayData& data) {
    ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
    return MakeArray(result);
  }

  // Export *reader into *c_stream, then replace *reader with a reader
  // imported back from that stream.
  void Roundtrip(std::shared_ptr<RecordBatchReader>* reader,
                 struct ArrowDeviceArrayStream* c_stream) {
    ASSERT_OK(ExportDeviceRecordBatchReader(*reader, c_stream));
    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(c_stream));

    ASSERT_OK_AND_ASSIGN(
        auto got_reader,
        ImportDeviceRecordBatchReader(c_stream, TestDeviceArrayRoundtrip::DeviceMapper));
    *reader = std::move(got_reader);
  }

  // Export `reader`, import it back, run `check_func` on the imported
  // reader, and verify via a weak_ptr that the original reader lives
  // exactly as long as the imported one.
  void Roundtrip(
      std::shared_ptr<RecordBatchReader> reader,
      std::function<void(const std::shared_ptr<RecordBatchReader>&)> check_func) {
    ArrowDeviceArrayStream c_stream;

    // NOTE: ReleaseCallback<> is not immediately usable with ArrowDeviceArrayStream
    // because get_next and get_schema need the original private_data.
    std::weak_ptr<RecordBatchReader> weak_reader(reader);
    ASSERT_EQ(weak_reader.use_count(), 1);  // Expiration check will fail otherwise

    ASSERT_OK(ExportDeviceRecordBatchReader(std::move(reader), &c_stream));
    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));

    {
      ASSERT_OK_AND_ASSIGN(auto new_reader,
                           ImportDeviceRecordBatchReader(
                               &c_stream, TestDeviceArrayRoundtrip::DeviceMapper));
      // stream was moved into new_reader, so the C struct is now released
      ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));
      ASSERT_FALSE(weak_reader.expired());

      check_func(new_reader);
    }
    // Stream was released when `new_reader` was destroyed
    ASSERT_TRUE(weak_reader.expired());
  }

  // Same as above, for a ChunkedArray: use_count bookkeeping verifies the
  // stream's copy is dropped by the import while the caller's copy survives.
  void Roundtrip(std::shared_ptr<ChunkedArray> src,
                 std::function<void(const std::shared_ptr<ChunkedArray>&)> check_func) {
    ArrowDeviceArrayStream c_stream;

    // One original copy to compare the result, one copy held by the stream
    std::weak_ptr<ChunkedArray> weak_src(src);
    int64_t initial_use_count = weak_src.use_count();

    ASSERT_OK(ExportDeviceChunkedArray(
        std::move(src), static_cast<DeviceAllocationType>(kMyDeviceType), &c_stream));
    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
    ASSERT_EQ(kMyDeviceType, c_stream.device_type);

    {
      ASSERT_OK_AND_ASSIGN(
          auto dst,
          ImportDeviceChunkedArray(&c_stream, TestDeviceArrayRoundtrip::DeviceMapper));
      // Stream was moved, consumed, and released
      ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));

      // Stream was released by ImportDeviceChunkedArray but original copy remains
      ASSERT_EQ(weak_src.use_count(), initial_use_count - 1);

      check_func(dst);
    }
  }

  // Expect one more batch from `reader`, equal to `expected` and still
  // tagged with the fake device's allocation type.
  void AssertReaderNext(const std::shared_ptr<RecordBatchReader>& reader,
                        const RecordBatch& expected) {
    ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
    ASSERT_NE(batch, nullptr);
    ASSERT_EQ(static_cast<DeviceAllocationType>(kMyDeviceType), batch->device_type());
    AssertBatchesEqual(expected, *batch);
  }

  // Expect end-of-stream: Next() succeeds and yields a null batch.
  void AssertReaderEnd(const std::shared_ptr<RecordBatchReader>& reader) {
    ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
    ASSERT_EQ(batch, nullptr);
  }

  // After Close(), further reads must fail with Invalid(".. already been closed").
  void AssertReaderClosed(const std::shared_ptr<RecordBatchReader>& reader) {
    ASSERT_THAT(reader->Next(),
                Raises(StatusCode::Invalid, ::testing::HasSubstr("already been closed")));
  }

  // Close the reader, then confirm it reports the closed state.
  void AssertReaderClose(const std::shared_ptr<RecordBatchReader>& reader) {
    ASSERT_OK(reader->Close());
    AssertReaderClosed(reader);
  }
};
+
+TEST_F(TestArrayDeviceStreamRoundtrip, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertSchemaEqual(*orig_schema, *reader->schema(),
/*check_metadata=*/true);
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderNext(reader, *batches[1]);
+ AssertReaderEnd(reader);
+ AssertReaderEnd(reader);
+ AssertReaderClose(reader);
+ });
+}
+
+TEST_F(TestArrayDeviceStreamRoundtrip, CloseEarly) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderClose(reader);
+ });
+}
+
+TEST_F(TestArrayDeviceStreamRoundtrip, Errors) {
+ auto reader = std::make_shared<FailingRecordBatchReader>(
+ Status::Invalid("roundtrip error example"));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ auto status = reader->Next().status();
+ ASSERT_RAISES(Invalid, status);
+ ASSERT_THAT(status.message(), ::testing::HasSubstr("roundtrip error
example"));
Review Comment:
```suggestion
EXPECT_THAT(reader->Next(), Raises(StatusCode::Invalid,
::testing::HasSubstr("roundtrip error example")));
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]