This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 8169d6e719 GH-40078: [C++] Import/Export ArrowDeviceArrayStream (#40807)
8169d6e719 is described below
commit 8169d6e719453acd0e7ca1b6f784d800cca4f113
Author: Matt Topol <[email protected]>
AuthorDate: Tue May 21 15:40:16 2024 -0400
GH-40078: [C++] Import/Export ArrowDeviceArrayStream (#40807)
### Rationale for this change
The original PRs for adding support for importing and exporting the new C
Device interface (#36488 / #36489) only added support for the Arrays
themselves, not for the stream structure. We should support both.
### What changes are included in this PR?
Adding parallel functions for Import/Export of streams that accept
`ArrowDeviceArrayStream`.
### Are these changes tested?
Tests are still in progress; I wanted to get this up for review while I write
them.
### Are there any user-facing changes?
No, only new functions have been added.
* GitHub Issue: #40078
Lead-authored-by: Matt Topol <[email protected]>
Co-authored-by: Felipe Oliveira Carvalho <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
---
cpp/src/arrow/array/array_base.h | 8 +
cpp/src/arrow/array/array_test.cc | 5 +
cpp/src/arrow/array/data.cc | 36 +++
cpp/src/arrow/array/data.h | 21 ++
cpp/src/arrow/array/util.cc | 2 +-
cpp/src/arrow/c/bridge.cc | 278 +++++++++++++++-----
cpp/src/arrow/c/bridge.h | 61 +++++
cpp/src/arrow/c/bridge_test.cc | 516 ++++++++++++++++++++++++++++++++++++++
cpp/src/arrow/c/helpers.h | 49 ++++
cpp/src/arrow/c/util_internal.h | 22 ++
cpp/src/arrow/record_batch.cc | 107 ++++++--
cpp/src/arrow/record_batch.h | 43 +++-
python/pyarrow/tests/test_cffi.py | 2 +-
13 files changed, 1051 insertions(+), 99 deletions(-)
diff --git a/cpp/src/arrow/array/array_base.h b/cpp/src/arrow/array/array_base.h
index 6411aebf80..716ae07220 100644
--- a/cpp/src/arrow/array/array_base.h
+++ b/cpp/src/arrow/array/array_base.h
@@ -224,6 +224,14 @@ class ARROW_EXPORT Array {
/// \return Status
Status ValidateFull() const;
+ /// \brief Return the device_type that this array's data is allocated on
+ ///
+ /// This just delegates to calling device_type on the underlying ArrayData
+ /// object which backs this Array.
+ ///
+ /// \return DeviceAllocationType
+ DeviceAllocationType device_type() const { return data_->device_type(); }
+
protected:
Array() = default;
ARROW_DEFAULT_MOVE_AND_ASSIGN(Array);
diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index 7e25ad61fa..32806d9d2e 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -478,6 +478,7 @@ TEST_F(TestArray, TestMakeArrayOfNull) {
ASSERT_EQ(array->type(), type);
ASSERT_OK(array->ValidateFull());
ASSERT_EQ(array->length(), length);
+ ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU);
if (is_union(type->id())) {
ASSERT_EQ(array->null_count(), 0);
ASSERT_EQ(array->ComputeLogicalNullCount(), length);
@@ -719,6 +720,7 @@ TEST_F(TestArray, TestMakeArrayFromScalar) {
ASSERT_OK(array->ValidateFull());
ASSERT_EQ(array->length(), length);
ASSERT_EQ(array->null_count(), 0);
+ ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU);
// test case for ARROW-13321
for (int64_t i : {int64_t{0}, length / 2, length - 1}) {
@@ -744,6 +746,7 @@ TEST_F(TestArray, TestMakeArrayFromScalarSliced) {
auto sliced = array->Slice(1, 4);
ASSERT_EQ(sliced->length(), 4);
ASSERT_EQ(sliced->null_count(), 0);
+ ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU);
ARROW_EXPECT_OK(sliced->ValidateFull());
}
}
@@ -758,6 +761,7 @@ TEST_F(TestArray, TestMakeArrayFromDictionaryScalar) {
ASSERT_OK(array->ValidateFull());
ASSERT_EQ(array->length(), 4);
ASSERT_EQ(array->null_count(), 0);
+ ASSERT_EQ(array->device_type(), DeviceAllocationType::kCPU);
for (int i = 0; i < 4; i++) {
ASSERT_OK_AND_ASSIGN(auto item, array->GetScalar(i));
@@ -797,6 +801,7 @@ TEST_F(TestArray, TestMakeEmptyArray) {
ASSERT_OK_AND_ASSIGN(auto array, MakeEmptyArray(type));
ASSERT_OK(array->ValidateFull());
ASSERT_EQ(array->length(), 0);
+
CheckSpanRoundTrip(*array);
}
}
diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc
index ac828a9c35..76a4352139 100644
--- a/cpp/src/arrow/array/data.cc
+++ b/cpp/src/arrow/array/data.cc
@@ -224,6 +224,42 @@ int64_t ArrayData::ComputeLogicalNullCount() const {
return ArraySpan(*this).ComputeLogicalNullCount();
}
+DeviceAllocationType ArrayData::device_type() const {
+ // we're using 0 as a sentinel value for NOT YET ASSIGNED
+ // there is explicitly no constant DeviceAllocationType to represent
+ // the "UNASSIGNED" case as it is invalid for data to not have an
+ // assigned device type. If it's still 0 at the end, then we return
+ // CPU as the allocation device type
+ int type = 0;
+ for (const auto& buf : buffers) {
+ if (!buf) continue;
+ if (type == 0) {
+ type = static_cast<int>(buf->device_type());
+ } else {
+ DCHECK_EQ(type, static_cast<int>(buf->device_type()));
+ }
+ }
+
+ for (const auto& child : child_data) {
+ if (!child) continue;
+ if (type == 0) {
+ type = static_cast<int>(child->device_type());
+ } else {
+ DCHECK_EQ(type, static_cast<int>(child->device_type()));
+ }
+ }
+
+ if (dictionary) {
+ if (type == 0) {
+ type = static_cast<int>(dictionary->device_type());
+ } else {
+ DCHECK_EQ(type, static_cast<int>(dictionary->device_type()));
+ }
+ }
+
+ return type == 0 ? DeviceAllocationType::kCPU :
static_cast<DeviceAllocationType>(type);
+}
+
// ----------------------------------------------------------------------
// Methods for ArraySpan
diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h
index beec29789a..0c49f36229 100644
--- a/cpp/src/arrow/array/data.h
+++ b/cpp/src/arrow/array/data.h
@@ -101,6 +101,11 @@ struct ARROW_EXPORT ArrayData {
int64_t null_count = kUnknownNullCount, int64_t offset = 0)
: ArrayData(std::move(type), length, null_count, offset) {
this->buffers = std::move(buffers);
+#ifndef NDEBUG
+ // in debug mode, call the `device_type` function to trigger
+ // the DCHECKs that validate all the buffers are on the same device
+ ARROW_UNUSED(this->device_type());
+#endif
}
ArrayData(std::shared_ptr<DataType> type, int64_t length,
@@ -110,6 +115,12 @@ struct ARROW_EXPORT ArrayData {
: ArrayData(std::move(type), length, null_count, offset) {
this->buffers = std::move(buffers);
this->child_data = std::move(child_data);
+#ifndef NDEBUG
+ // in debug mode, call the `device_type` function to trigger
+ // the DCHECKs that validate all the buffers (including children)
+ // are on the same device
+ ARROW_UNUSED(this->device_type());
+#endif
}
static std::shared_ptr<ArrayData> Make(std::shared_ptr<DataType> type,
int64_t length,
@@ -358,6 +369,16 @@ struct ARROW_EXPORT ArrayData {
/// \see GetNullCount
int64_t ComputeLogicalNullCount() const;
+ /// \brief Returns the device_type of the underlying buffers and children
+ ///
+ /// If there are no buffers in this ArrayData object, it just returns
+ /// DeviceAllocationType::kCPU as a default. We also assume that all buffers
+ /// should be allocated on the same device type and perform DCHECKs to confirm
+ /// this in debug mode.
+ ///
+ /// \return DeviceAllocationType
+ DeviceAllocationType device_type() const;
+
std::shared_ptr<DataType> type;
int64_t length = 0;
mutable std::atomic<int64_t> null_count{0};
diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc
index bdba92c9a1..41cd6a1c0b 100644
--- a/cpp/src/arrow/array/util.cc
+++ b/cpp/src/arrow/array/util.cc
@@ -548,7 +548,7 @@ class NullArrayFactory {
}
Status Visit(const StructType& type) {
- for (int i = 0; i < type_->num_fields(); ++i) {
+ for (int i = 0; i < type.num_fields(); ++i) {
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i,
length_));
}
return Status::OK();
diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 8a530b3798..8c5e3637b6 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -1448,6 +1448,7 @@ namespace {
// The ArrowArray is released on destruction.
struct ImportedArrayData {
struct ArrowArray array_;
+ DeviceAllocationType device_type_;
std::shared_ptr<Device::SyncEvent> device_sync_;
ImportedArrayData() {
@@ -1514,6 +1515,7 @@ struct ArrayImporter {
recursion_level_ = 0;
import_ = std::make_shared<ImportedArrayData>();
c_struct_ = &import_->array_;
+ import_->device_type_ = device_type_;
ArrowArrayMove(src, c_struct_);
return DoImport();
}
@@ -1541,7 +1543,8 @@ struct ArrayImporter {
"cannot be imported as RecordBatch");
}
return RecordBatch::Make(std::move(schema), data_->length,
- std::move(data_->child_data));
+ std::move(data_->child_data),
import_->device_type_,
+ import_->device_sync_);
}
Status ImportChild(const ArrayImporter* parent, struct ArrowArray* src) {
@@ -2041,6 +2044,23 @@ Status ExportStreamNext(const std::shared_ptr<RecordBatchReader>& src, int64_t i
}
}
+// the int64_t i input here is unused, but exists simply to allow utilizing the
+// overload of this with the version for ChunkedArrays. If we removed the int64_t
+// from the signature despite it being unused, we wouldn't be able to leverage the
+// overloading in the templated exporters.
+Status ExportStreamNext(const std::shared_ptr<RecordBatchReader>& src, int64_t
i,
+ struct ArrowDeviceArray* out_array) {
+ std::shared_ptr<RecordBatch> batch;
+ RETURN_NOT_OK(src->ReadNext(&batch));
+ if (batch == nullptr) {
+ // End of stream
+ ArrowArrayMarkReleased(&out_array->array);
+ return Status::OK();
+ } else {
+ return ExportDeviceRecordBatch(*batch, batch->GetSyncEvent(), out_array);
+ }
+}
+
Status ExportStreamNext(const std::shared_ptr<ChunkedArray>& src, int64_t i,
struct ArrowArray* out_array) {
if (i >= src->num_chunks()) {
@@ -2052,8 +2072,27 @@ Status ExportStreamNext(const std::shared_ptr<ChunkedArray>& src, int64_t i,
}
}
-template <typename T>
+Status ExportStreamNext(const std::shared_ptr<ChunkedArray>& src, int64_t i,
+ struct ArrowDeviceArray* out_array) {
+ if (i >= src->num_chunks()) {
+ // End of stream
+ ArrowArrayMarkReleased(&out_array->array);
+ return Status::OK();
+ } else {
+ return ExportDeviceArray(*src->chunk(static_cast<int>(i)), nullptr,
out_array);
+ }
+}
+
+template <typename T, bool IsDevice>
class ExportedArrayStream {
+ using StreamTraits =
+ std::conditional_t<IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>;
+ using StreamType = typename StreamTraits::CType;
+ using ArrayTraits = std::conditional_t<IsDevice,
internal::ArrayDeviceExportTraits,
+ internal::ArrayExportTraits>;
+ using ArrayType = typename ArrayTraits::CType;
+
public:
struct PrivateData {
explicit PrivateData(std::shared_ptr<T> reader)
@@ -2067,13 +2106,13 @@ class ExportedArrayStream {
ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
};
- explicit ExportedArrayStream(struct ArrowArrayStream* stream) :
stream_(stream) {}
+ explicit ExportedArrayStream(StreamType* stream) : stream_(stream) {}
Status GetSchema(struct ArrowSchema* out_schema) {
return ExportStreamSchema(reader(), out_schema);
}
- Status GetNext(struct ArrowArray* out_array) {
+ Status GetNext(ArrayType* out_array) {
return ExportStreamNext(reader(), next_batch_num(), out_array);
}
@@ -2083,38 +2122,35 @@ class ExportedArrayStream {
}
void Release() {
- if (ArrowArrayStreamIsReleased(stream_)) {
+ if (StreamTraits::IsReleasedFunc(stream_)) {
return;
}
+
DCHECK_NE(private_data(), nullptr);
delete private_data();
- ArrowArrayStreamMarkReleased(stream_);
+ StreamTraits::MarkReleased(stream_);
}
// C-compatible callbacks
- static int StaticGetSchema(struct ArrowArrayStream* stream,
- struct ArrowSchema* out_schema) {
+ static int StaticGetSchema(StreamType* stream, struct ArrowSchema*
out_schema) {
ExportedArrayStream self{stream};
return self.ToCError(self.GetSchema(out_schema));
}
- static int StaticGetNext(struct ArrowArrayStream* stream,
- struct ArrowArray* out_array) {
+ static int StaticGetNext(StreamType* stream, ArrayType* out_array) {
ExportedArrayStream self{stream};
return self.ToCError(self.GetNext(out_array));
}
- static void StaticRelease(struct ArrowArrayStream* stream) {
- ExportedArrayStream{stream}.Release();
- }
+ static void StaticRelease(StreamType* stream) {
ExportedArrayStream{stream}.Release(); }
- static const char* StaticGetLastError(struct ArrowArrayStream* stream) {
+ static const char* StaticGetLastError(StreamType* stream) {
return ExportedArrayStream{stream}.GetLastError();
}
- static Status Make(std::shared_ptr<T> reader, struct ArrowArrayStream* out) {
+ static Status Make(std::shared_ptr<T> reader, StreamType* out) {
out->get_schema = ExportedArrayStream::StaticGetSchema;
out->get_next = ExportedArrayStream::StaticGetNext;
out->get_last_error = ExportedArrayStream::StaticGetLastError;
@@ -2150,19 +2186,36 @@ class ExportedArrayStream {
int64_t next_batch_num() { return private_data()->batch_num_++; }
- struct ArrowArrayStream* stream_;
+ StreamType* stream_;
};
} // namespace
Status ExportRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
struct ArrowArrayStream* out) {
- return ExportedArrayStream<RecordBatchReader>::Make(std::move(reader), out);
+ memset(out, 0, sizeof(struct ArrowArrayStream));
+ return ExportedArrayStream<RecordBatchReader,
false>::Make(std::move(reader), out);
}
Status ExportChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
struct ArrowArrayStream* out) {
- return ExportedArrayStream<ChunkedArray>::Make(std::move(chunked_array),
out);
+ memset(out, 0, sizeof(struct ArrowArrayStream));
+ return ExportedArrayStream<ChunkedArray,
false>::Make(std::move(chunked_array), out);
+}
+
+Status ExportDeviceRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+ struct ArrowDeviceArrayStream* out) {
+ memset(out, 0, sizeof(struct ArrowDeviceArrayStream));
+ out->device_type = static_cast<ArrowDeviceType>(reader->device_type());
+ return ExportedArrayStream<RecordBatchReader, true>::Make(std::move(reader),
out);
+}
+
+Status ExportDeviceChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
+ DeviceAllocationType device_type,
+ struct ArrowDeviceArrayStream* out) {
+ memset(out, 0, sizeof(struct ArrowDeviceArrayStream));
+ out->device_type = static_cast<ArrowDeviceType>(device_type);
+ return ExportedArrayStream<ChunkedArray,
true>::Make(std::move(chunked_array), out);
}
//////////////////////////////////////////////////////////////////////////
@@ -2170,33 +2223,65 @@ Status ExportChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
namespace {
+template <bool IsDevice>
class ArrayStreamReader {
+ protected:
+ using StreamTraits =
+ std::conditional_t<IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>;
+ using StreamType = typename StreamTraits::CType;
+ using ArrayTraits = std::conditional_t<IsDevice,
internal::ArrayDeviceExportTraits,
+ internal::ArrayExportTraits>;
+ using ArrayType = typename ArrayTraits::CType;
+
public:
- explicit ArrayStreamReader(struct ArrowArrayStream* stream) {
- ArrowArrayStreamMove(stream, &stream_);
- DCHECK(!ArrowArrayStreamIsReleased(&stream_));
+ explicit ArrayStreamReader(StreamType* stream,
+ const DeviceMemoryMapper mapper =
DefaultDeviceMemoryMapper)
+ : mapper_{std::move(mapper)} {
+ StreamTraits::MoveFunc(stream, &stream_);
+ DCHECK(!StreamTraits::IsReleasedFunc(&stream_));
}
~ArrayStreamReader() { ReleaseStream(); }
void ReleaseStream() {
- if (!ArrowArrayStreamIsReleased(&stream_)) {
- ArrowArrayStreamRelease(&stream_);
- }
- DCHECK(ArrowArrayStreamIsReleased(&stream_));
+ // all our trait release funcs check IsReleased so we don't
+ // need to repeat it here
+ StreamTraits::ReleaseFunc(&stream_);
+ DCHECK(StreamTraits::IsReleasedFunc(&stream_));
}
protected:
- Status ReadNextArrayInternal(struct ArrowArray* array) {
- ArrowArrayMarkReleased(array);
+ Status ReadNextArrayInternal(ArrayType* array) {
+ ArrayTraits::MarkReleased(array);
Status status = StatusFromCError(stream_.get_next(&stream_, array));
- if (!status.ok() && !ArrowArrayIsReleased(array)) {
- ArrowArrayRelease(array);
+ if (!status.ok()) {
+ ArrayTraits::ReleaseFunc(array);
}
return status;
}
+ Result<std::shared_ptr<RecordBatch>> ImportRecordBatchInternal(
+ struct ArrowArray* array, std::shared_ptr<Schema> schema) {
+ return ImportRecordBatch(array, schema);
+ }
+
+ Result<std::shared_ptr<RecordBatch>> ImportRecordBatchInternal(
+ struct ArrowDeviceArray* array, std::shared_ptr<Schema> schema) {
+ return ImportDeviceRecordBatch(array, schema, mapper_);
+ }
+
+ Result<std::shared_ptr<Array>> ImportArrayInternal(
+ struct ArrowArray* array, std::shared_ptr<arrow::DataType> type) {
+ return ImportArray(array, type);
+ }
+
+ Result<std::shared_ptr<Array>> ImportArrayInternal(
+ struct ArrowDeviceArray* array, std::shared_ptr<arrow::DataType> type) {
+ return ImportDeviceArray(array, type, mapper_);
+ }
+
Result<std::shared_ptr<Schema>> ReadSchema() {
struct ArrowSchema c_schema = {};
ARROW_RETURN_NOT_OK(
@@ -2214,19 +2299,19 @@ class ArrayStreamReader {
}
Status CheckNotReleased() {
- if (ArrowArrayStreamIsReleased(&stream_)) {
+ if (StreamTraits::IsReleasedFunc(&stream_)) {
return Status::Invalid(
"Attempt to read from a stream that has already been closed");
- } else {
- return Status::OK();
}
+
+ return Status::OK();
}
Status StatusFromCError(int errno_like) const {
return StatusFromCError(&stream_, errno_like);
}
- static Status StatusFromCError(struct ArrowArrayStream* stream, int
errno_like) {
+ static Status StatusFromCError(StreamType* stream, int errno_like) {
if (ARROW_PREDICT_TRUE(errno_like == 0)) {
return Status::OK();
}
@@ -2250,70 +2335,102 @@ class ArrayStreamReader {
return {code, last_error ? std::string(last_error) : ""};
}
+ DeviceAllocationType get_device_type() const {
+ if constexpr (IsDevice) {
+ return static_cast<DeviceAllocationType>(stream_.device_type);
+ } else {
+ return DeviceAllocationType::kCPU;
+ }
+ }
+
private:
- mutable struct ArrowArrayStream stream_;
+ mutable StreamType stream_;
+ const DeviceMemoryMapper mapper_;
};
-class ArrayStreamBatchReader : public RecordBatchReader, public
ArrayStreamReader {
+template <bool IsDevice>
+class ArrayStreamBatchReader : public RecordBatchReader,
+ public ArrayStreamReader<IsDevice> {
+ using StreamTraits =
+ std::conditional_t<IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>;
+ using StreamType = typename StreamTraits::CType;
+ using ArrayTraits = std::conditional_t<IsDevice,
internal::ArrayDeviceExportTraits,
+ internal::ArrayExportTraits>;
+ using ArrayType = typename ArrayTraits::CType;
+
public:
- explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream)
- : ArrayStreamReader(stream) {}
+ explicit ArrayStreamBatchReader(
+ StreamType* stream, const DeviceMemoryMapper& mapper =
DefaultDeviceMemoryMapper)
+ : ArrayStreamReader<IsDevice>(stream, mapper) {}
Status Init() {
- ARROW_ASSIGN_OR_RAISE(schema_, ReadSchema());
+ ARROW_ASSIGN_OR_RAISE(schema_, this->ReadSchema());
return Status::OK();
}
std::shared_ptr<Schema> schema() const override { return schema_; }
Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
- ARROW_RETURN_NOT_OK(CheckNotReleased());
+ ARROW_RETURN_NOT_OK(this->CheckNotReleased());
- struct ArrowArray c_array;
- ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array));
+ ArrayType c_array;
+ ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array));
- if (ArrowArrayIsReleased(&c_array)) {
+ if (ArrayTraits::IsReleasedFunc(&c_array)) {
// End of stream
batch->reset();
return Status::OK();
} else {
- return ImportRecordBatch(&c_array, schema_).Value(batch);
+ return this->ImportRecordBatchInternal(&c_array, schema_).Value(batch);
}
}
Status Close() override {
- ReleaseStream();
+ this->ReleaseStream();
return Status::OK();
}
+ DeviceAllocationType device_type() const override { return
this->get_device_type(); }
+
private:
std::shared_ptr<Schema> schema_;
};
-class ArrayStreamArrayReader : public ArrayStreamReader {
+template <bool IsDevice>
+class ArrayStreamArrayReader : public ArrayStreamReader<IsDevice> {
+ using StreamTraits =
+ std::conditional_t<IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>;
+ using StreamType = typename StreamTraits::CType;
+ using ArrayTraits = std::conditional_t<IsDevice,
internal::ArrayDeviceExportTraits,
+ internal::ArrayExportTraits>;
+ using ArrayType = typename ArrayTraits::CType;
+
public:
- explicit ArrayStreamArrayReader(struct ArrowArrayStream* stream)
- : ArrayStreamReader(stream) {}
+ explicit ArrayStreamArrayReader(
+ StreamType* stream, const DeviceMemoryMapper& mapper =
DefaultDeviceMemoryMapper)
+ : ArrayStreamReader<IsDevice>(stream, mapper) {}
Status Init() {
- ARROW_ASSIGN_OR_RAISE(field_, ReadField());
+ ARROW_ASSIGN_OR_RAISE(field_, this->ReadField());
return Status::OK();
}
std::shared_ptr<DataType> data_type() const { return field_->type(); }
Status ReadNext(std::shared_ptr<Array>* array) {
- ARROW_RETURN_NOT_OK(CheckNotReleased());
+ ARROW_RETURN_NOT_OK(this->CheckNotReleased());
- struct ArrowArray c_array;
- ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array));
+ ArrayType c_array;
+ ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array));
- if (ArrowArrayIsReleased(&c_array)) {
+ if (ArrayTraits::IsReleasedFunc(&c_array)) {
// End of stream
array->reset();
return Status::OK();
} else {
- return ImportArray(&c_array, field_->type()).Value(array);
+ return this->ImportArrayInternal(&c_array, field_->type()).Value(array);
}
}
@@ -2321,30 +2438,35 @@ class ArrayStreamArrayReader : public ArrayStreamReader {
std::shared_ptr<Field> field_;
};
-} // namespace
-
-Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
- struct ArrowArrayStream* stream) {
- if (ArrowArrayStreamIsReleased(stream)) {
- return Status::Invalid("Cannot import released ArrowArrayStream");
+template <bool IsDevice, typename StreamTraits = std::conditional_t<
+ IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>>
+Result<std::shared_ptr<RecordBatchReader>> ImportReader(
+ typename StreamTraits::CType* stream,
+ const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) {
+ if (StreamTraits::IsReleasedFunc(stream)) {
+ return Status::Invalid("Cannot import released Arrow Stream");
}
- auto reader = std::make_shared<ArrayStreamBatchReader>(stream);
+ auto reader = std::make_shared<ArrayStreamBatchReader<IsDevice>>(stream,
mapper);
ARROW_RETURN_NOT_OK(reader->Init());
return reader;
}
-Result<std::shared_ptr<ChunkedArray>> ImportChunkedArray(
- struct ArrowArrayStream* stream) {
- if (ArrowArrayStreamIsReleased(stream)) {
- return Status::Invalid("Cannot import released ArrowArrayStream");
+template <bool IsDevice, typename StreamTraits = std::conditional_t<
+ IsDevice, internal::ArrayDeviceStreamExportTraits,
+ internal::ArrayStreamExportTraits>>
+Result<std::shared_ptr<ChunkedArray>> ImportChunked(
+ typename StreamTraits::CType* stream,
+ const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper) {
+ if (StreamTraits::IsReleasedFunc(stream)) {
+ return Status::Invalid("Cannot import released Arrow Stream");
}
- auto reader = std::make_shared<ArrayStreamArrayReader>(stream);
+ auto reader = std::make_shared<ArrayStreamArrayReader<IsDevice>>(stream,
mapper);
ARROW_RETURN_NOT_OK(reader->Init());
- std::shared_ptr<DataType> data_type = reader->data_type();
-
+ auto data_type = reader->data_type();
ArrayVector chunks;
std::shared_ptr<Array> chunk;
while (true) {
@@ -2360,4 +2482,26 @@ Result<std::shared_ptr<ChunkedArray>> ImportChunkedArray(
return ChunkedArray::Make(std::move(chunks), std::move(data_type));
}
+} // namespace
+
+Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
+ struct ArrowArrayStream* stream) {
+ return ImportReader</*IsDevice=*/false>(stream);
+}
+
+Result<std::shared_ptr<RecordBatchReader>> ImportDeviceRecordBatchReader(
+ struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) {
+ return ImportReader</*IsDevice=*/true>(stream, mapper);
+}
+
+Result<std::shared_ptr<ChunkedArray>> ImportChunkedArray(
+ struct ArrowArrayStream* stream) {
+ return ImportChunked</*IsDevice=*/false>(stream);
+}
+
+Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
+ struct ArrowDeviceArrayStream* stream, const DeviceMemoryMapper& mapper) {
+ return ImportChunked</*IsDevice=*/true>(stream, mapper);
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h
index 74a302be4c..45367e4f93 100644
--- a/cpp/src/arrow/c/bridge.h
+++ b/cpp/src/arrow/c/bridge.h
@@ -321,6 +321,31 @@ ARROW_EXPORT
Status ExportChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
struct ArrowArrayStream* out);
+/// \brief Export C++ RecordBatchReader using the C device stream interface
+///
+/// The resulting ArrowDeviceArrayStream struct keeps the record batch reader
+/// alive until its release callback is called by the consumer. The device
+/// type is determined by calling device_type() on the RecordBatchReader.
+///
+/// \param[in] reader RecordBatchReader object to export
+/// \param[out] out C struct to export the stream to
+ARROW_EXPORT
+Status ExportDeviceRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+ struct ArrowDeviceArrayStream* out);
+
+/// \brief Export C++ ChunkedArray using the C device data interface format.
+///
+/// The resulting ArrowDeviceArrayStream keeps the chunked array data and buffers
+/// alive until its release callback is called by the consumer.
+///
+/// \param[in] chunked_array ChunkedArray object to export
+/// \param[in] device_type the device type the data is located on
+/// \param[out] out C struct to export the stream to
+ARROW_EXPORT
+Status ExportDeviceChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
+ DeviceAllocationType device_type,
+ struct ArrowDeviceArrayStream* out);
+
/// \brief Import C++ RecordBatchReader from the C stream interface.
///
/// The ArrowArrayStream struct has its contents moved to a private object
@@ -343,6 +368,42 @@ Result<std::shared_ptr<RecordBatchReader>> ImportRecordBatchReader(
ARROW_EXPORT
Result<std::shared_ptr<ChunkedArray>> ImportChunkedArray(struct
ArrowArrayStream* stream);
+/// \brief Import C++ RecordBatchReader from the C device stream interface
+///
+/// The ArrowDeviceArrayStream struct has its contents moved to a private object
+/// held alive by the resulting record batch reader.
+///
+/// \note If there was a required sync event, sync events are accessible by individual
+/// buffers of columns. We are not yet bubbling the sync events from the buffers up to
+/// the `GetSyncEvent` method of an imported RecordBatch. This will be added in a future
+/// update.
+///
+/// \param[in,out] stream C device stream interface struct
+/// \param[in] mapper mapping from device type and ID to memory manager
+/// \return Imported RecordBatchReader object
+ARROW_EXPORT
+Result<std::shared_ptr<RecordBatchReader>> ImportDeviceRecordBatchReader(
+ struct ArrowDeviceArrayStream* stream,
+ const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
+
+/// \brief Import C++ ChunkedArray from the C device stream interface
+///
+/// The ArrowDeviceArrayStream struct has its contents moved to a private object,
+/// is consumed in its entirety, and released before returning all chunks as a
+/// ChunkedArray.
+///
+/// \note Any chunks that require synchronization for their device memory will have
+/// the SyncEvent objects available by checking the individual buffers of each chunk.
+/// These SyncEvents should be checked before accessing the data in those buffers.
+///
+/// \param[in,out] stream C device stream interface struct
+/// \param[in] mapper mapping from device type and ID to memory manager
+/// \return Imported ChunkedArray object
+ARROW_EXPORT
+Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
+ struct ArrowDeviceArrayStream* stream,
+ const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper);
+
/// @}
} // namespace arrow
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index d64fe67acc..0ecfb5a957 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -53,11 +53,15 @@
namespace arrow {
+using internal::ArrayDeviceExportTraits;
+using internal::ArrayDeviceStreamExportTraits;
using internal::ArrayExportGuard;
using internal::ArrayExportTraits;
using internal::ArrayStreamExportGuard;
using internal::ArrayStreamExportTraits;
using internal::checked_cast;
+using internal::DeviceArrayExportGuard;
+using internal::DeviceArrayStreamExportGuard;
using internal::SchemaExportGuard;
using internal::SchemaExportTraits;
using internal::Zip;
@@ -4746,4 +4750,516 @@ TEST_F(TestArrayStreamRoundtrip,
ChunkedArrayRoundtripEmpty) {
});
}
+////////////////////////////////////////////////////////////////////////////
+// Array device stream export tests
+
+// Fixture for the device stream export tests. Its helpers drive an exported
+// ArrowDeviceArrayStream directly through the struct's callbacks.
+class TestArrayDeviceStreamExport : public BaseArrayStreamTest {
+ public:
+ // Calls get_schema on the stream, imports the returned C schema and checks
+ // that it equals `expected`, metadata included.
+ void AssertStreamSchema(struct ArrowDeviceArrayStream* c_stream,
+ const Schema& expected) {
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream->get_schema(c_stream, &c_schema));
+
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
+ }
+
+ // Checks that get_next succeeds (returns 0) and hands back an array that is
+ // already marked released.
+ void AssertStreamEnd(struct ArrowDeviceArrayStream* c_stream) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_TRUE(ArrowDeviceArrayIsReleased(&c_array));
+ }
+
+ // Pulls the next array from the stream, imports it as a RecordBatch using
+ // `expected`'s schema and TestDeviceArrayRoundtrip::DeviceMapper, and
+ // compares the result with `expected`.
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream,
+ const RecordBatch& expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto batch,
+ ImportDeviceRecordBatch(&c_array, expected.schema(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ // Same as above, but imports the next array as an Array of `expected`'s type.
+ void AssertStreamNext(struct ArrowDeviceArrayStream* c_stream, const Array&
expected) {
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(0, c_stream->get_next(c_stream, &c_array));
+
+ DeviceArrayExportGuard guard(&c_array);
+ ASSERT_FALSE(ArrowDeviceArrayIsReleased(&c_array));
+
+ ASSERT_OK_AND_ASSIGN(auto array,
+ ImportDeviceArray(&c_array, expected.type(),
+
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertArraysEqual(expected, *array);
+ }
+
+ // Builds a new ArrayData whose buffers are copies of `data`'s buffers made
+ // with mm->CopyBuffer. Null buffer slots stay null, children are converted
+ // recursively, and type, length, null_count and offset are carried over.
+ static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+ const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+ arrow::BufferVector buffers;
+ for (const auto& buf : data.buffers) {
+ if (buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+ buffers.push_back(dest);
+ } else {
+ buffers.push_back(nullptr);
+ }
+ }
+
+ arrow::ArrayDataVector children;
+ for (const auto& child : data.child_data) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+ children.push_back(dest);
+ }
+
+ return ArrayData::Make(data.type, data.length, buffers, children,
data.null_count,
+ data.offset);
+ }
+
+ // Convenience wrapper: ToDeviceData followed by MakeArray.
+ static Result<std::shared_ptr<Array>> ToDevice(const
std::shared_ptr<MemoryManager>& mm,
+ const ArrayData& data) {
+ ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+ return MakeArray(result);
+ }
+};
+
+// Exports a reader with no batches that is tagged with kMyDeviceType and
+// checks the device type, the schema and the end-of-stream behaviour of the
+// resulting C stream.
+TEST_F(TestArrayDeviceStreamExport, Empty) {
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader,
+ RecordBatchReader::Make(batches, schema,
+
static_cast<DeviceAllocationType>(kMyDeviceType)));
+
+ struct ArrowDeviceArrayStream c_stream;
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamSchema(&c_stream, *schema);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream); // asking again after the end must still succeed
+}
+
+// Exports a reader over two single-column batches whose buffers were copied
+// through MyDevice's default memory manager, then reads the C stream back
+// batch by batch and compares with the originals.
+TEST_F(TestArrayDeviceStreamExport, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ AssertStreamSchema(&c_stream, *schema);
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+ AssertStreamNext(&c_stream, *batches[0]);
+ AssertStreamNext(&c_stream, *batches[1]);
+ AssertStreamEnd(&c_stream);
+ AssertStreamEnd(&c_stream); // end-of-stream is reported on every further call
+}
+
+// Fetches the schema and both arrays from the exported stream, lets the
+// stream guard release the stream, and only then imports the schema and the
+// arrays: they must remain usable after the stream itself is gone.
+TEST_F(TestArrayDeviceStreamExport, ArrayLifetime) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ RecordBatchReader::Make(batches, schema,
device->device_type()));
+
+ struct ArrowDeviceArrayStream c_stream;
+ struct ArrowSchema c_schema;
+ struct ArrowDeviceArray c_array0, c_array1;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ {
+ // The guard releases the stream when this scope ends.
+ DeviceArrayStreamExportGuard guard(&c_stream);
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+ ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+ AssertStreamEnd(&c_stream);
+ }
+
+ DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
+ }
+
+ ASSERT_EQ(kMyDeviceType, c_array0.device_type);
+ ASSERT_EQ(kMyDeviceType, c_array1.device_type);
+
+ // NOTE(review): pool_ and orig_allocated_ come from the base fixture, which is
+ // not visible here; presumably this checks that array memory is still held.
+ ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+ // The arrays are imported in reverse order (second batch first).
+ ASSERT_OK_AND_ASSIGN(
+ auto batch,
+ ImportDeviceRecordBatch(&c_array1, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[1], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+ ASSERT_OK_AND_ASSIGN(
+ batch,
+ ImportDeviceRecordBatch(&c_array0, schema,
TestDeviceArrayRoundtrip::DeviceMapper));
+ AssertBatchesEqual(*batches[0], *batch);
+ ASSERT_EQ(device->device_type(), batch->device_type());
+}
+
+// Exports a FailingRecordBatchReader built from Status::Invalid and checks
+// that get_schema still succeeds with an empty schema while get_next reports
+// EINVAL.
+TEST_F(TestArrayDeviceStreamExport, Errors) {
+ auto reader =
+ std::make_shared<FailingRecordBatchReader>(Status::Invalid("some example
error"));
+
+ struct ArrowDeviceArrayStream c_stream;
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(reader, &c_stream));
+ DeviceArrayStreamExportGuard guard(&c_stream);
+
+ struct ArrowSchema c_schema;
+ ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+ ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
+ {
+ SchemaExportGuard schema_guard(&c_schema);
+ ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
+ AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
+ }
+
+ struct ArrowDeviceArray c_array;
+ ASSERT_EQ(EINVAL, c_stream.get_next(&c_stream, &c_array)); // Invalid status surfaces as EINVAL
+}
+
+// Exports an empty int32 ChunkedArray as a device stream tagged with
+// kMyDeviceType. Checks that the stream reports that device type and ends
+// immediately, and that the schema fetched from it can still be imported
+// after the stream has been released.
+TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExportEmpty) {
+  ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({}, int32()));
+
+  struct ArrowDeviceArrayStream c_stream;
+  struct ArrowSchema c_schema;
+
+  ASSERT_OK(ExportDeviceChunkedArray(
+      chunked_array, static_cast<DeviceAllocationType>(kMyDeviceType), &c_stream));
+
+  {
+    // Sole owner of the stream: it is released when this scope ends, so the
+    // schema import below runs against an already-released stream. (A second,
+    // outer guard would be shadowed by this one and would find nothing left
+    // to release.)
+    DeviceArrayStreamExportGuard guard(&c_stream);
+    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+    ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+    ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+    AssertStreamEnd(&c_stream);
+  }
+
+  {
+    SchemaExportGuard schema_guard(&c_schema);
+    ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
+    AssertTypeEqual(*chunked_array->type(), *got_type);
+  }
+}
+
+// Exports a two-chunk ChunkedArray whose buffers were copied through
+// MyDevice's default memory manager. Fetches the schema and both chunks from
+// the C stream, releases the stream, and then checks that the schema and the
+// chunks can still be imported and match the source.
+TEST_F(TestArrayDeviceStreamExport, ChunkedArrayExport) {
+  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+  auto mm = device->default_memory_manager();
+
+  ASSERT_OK_AND_ASSIGN(auto arr1,
+                       ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data()));
+  ASSERT_EQ(device->device_type(), arr1->device_type());
+  ASSERT_OK_AND_ASSIGN(auto arr2,
+                       ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data()));
+  ASSERT_EQ(device->device_type(), arr2->device_type());
+
+  ASSERT_OK_AND_ASSIGN(auto chunked_array, ChunkedArray::Make({arr1, arr2}));
+
+  struct ArrowDeviceArrayStream c_stream;
+  struct ArrowSchema c_schema;
+  struct ArrowDeviceArray c_array0, c_array1;
+
+  ASSERT_OK(ExportDeviceChunkedArray(chunked_array, device->device_type(), &c_stream));
+
+  {
+    // Sole owner of the stream: it is released when this scope ends, before
+    // the schema and the chunks are imported. (A second, outer guard would be
+    // shadowed by this one and would find nothing left to release.)
+    DeviceArrayStreamExportGuard guard(&c_stream);
+    ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+    ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+    ASSERT_EQ(0, c_stream.get_schema(&c_stream, &c_schema));
+    ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array0));
+    ASSERT_EQ(0, c_stream.get_next(&c_stream, &c_array1));
+    AssertStreamEnd(&c_stream);
+  }
+
+  DeviceArrayExportGuard guard0(&c_array0), guard1(&c_array1);
+
+  {
+    SchemaExportGuard schema_guard(&c_schema);
+    ASSERT_OK_AND_ASSIGN(auto got_type, ImportType(&c_schema));
+    AssertTypeEqual(*chunked_array->type(), *got_type);
+  }
+
+  ASSERT_EQ(kMyDeviceType, c_array0.device_type);
+  ASSERT_EQ(kMyDeviceType, c_array1.device_type);
+
+  // NOTE(review): pool_ and orig_allocated_ come from the base fixture, which
+  // is not visible here; presumably this checks that chunk memory is still held.
+  ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
+  ASSERT_OK_AND_ASSIGN(auto array,
+                       ImportDeviceArray(&c_array0, chunked_array->type(),
+                                         TestDeviceArrayRoundtrip::DeviceMapper));
+  ASSERT_EQ(device->device_type(), array->device_type());
+  AssertArraysEqual(*chunked_array->chunk(0), *array);
+  ASSERT_OK_AND_ASSIGN(array, ImportDeviceArray(&c_array1, chunked_array->type(),
+                                                TestDeviceArrayRoundtrip::DeviceMapper));
+  ASSERT_EQ(device->device_type(), array->device_type());
+  AssertArraysEqual(*chunked_array->chunk(1), *array);
+}
+
+////////////////////////////////////////////////////////////////////////////
+// Array device stream roundtrip tests
+
+// Fixture for export-then-import tests of the C device stream interface.
+class TestArrayDeviceStreamRoundtrip : public BaseArrayStreamTest {
+ public:
+ // Builds a new ArrayData whose buffers are copies of `data`'s buffers made
+ // with mm->CopyBuffer. Null buffer slots stay null, children are converted
+ // recursively, and type, length, null_count and offset are carried over.
+ static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+ const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+ arrow::BufferVector buffers;
+ for (const auto& buf : data.buffers) {
+ if (buf) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm));
+ buffers.push_back(dest);
+ } else {
+ buffers.push_back(nullptr);
+ }
+ }
+
+ arrow::ArrayDataVector children;
+ for (const auto& child : data.child_data) {
+ ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child));
+ children.push_back(dest);
+ }
+
+ return ArrayData::Make(data.type, data.length, buffers, children,
data.null_count,
+ data.offset);
+ }
+
+ // Convenience wrapper: ToDeviceData followed by MakeArray.
+ static Result<std::shared_ptr<Array>> ToDevice(const
std::shared_ptr<MemoryManager>& mm,
+ const ArrayData& data) {
+ ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data));
+ return MakeArray(result);
+ }
+
+ // Exports `*reader` into `c_stream`, imports it back, and replaces `*reader`
+ // with the imported reader.
+ void Roundtrip(std::shared_ptr<RecordBatchReader>* reader,
+ struct ArrowDeviceArrayStream* c_stream) {
+ ASSERT_OK(ExportDeviceRecordBatchReader(*reader, c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(c_stream));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto got_reader,
+ ImportDeviceRecordBatchReader(c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ *reader = std::move(got_reader);
+ }
+
+ // Exports `reader`, imports it back and passes the imported reader to
+ // `check_func`. A weak_ptr to the original reader is used to verify that it
+ // stays alive while the imported reader exists and is destroyed afterwards.
+ void Roundtrip(
+ std::shared_ptr<RecordBatchReader> reader,
+ std::function<void(const std::shared_ptr<RecordBatchReader>&)>
check_func) {
+ ArrowDeviceArrayStream c_stream;
+
+ // NOTE: ReleaseCallback<> is not immediately usable with
+ // ArrowDeviceArrayStream because get_next and get_schema need the original
+ // private_data.
+ std::weak_ptr<RecordBatchReader> weak_reader(reader);
+ ASSERT_EQ(weak_reader.use_count(), 1); // Expiration check will fail otherwise
+
+ ASSERT_OK(ExportDeviceRecordBatchReader(std::move(reader), &c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+ {
+ ASSERT_OK_AND_ASSIGN(auto new_reader,
+ ImportDeviceRecordBatchReader(
+ &c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ // stream was moved
+ ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_FALSE(weak_reader.expired());
+
+ check_func(new_reader);
+ }
+ // Stream was released when `new_reader` was destroyed
+ ASSERT_TRUE(weak_reader.expired());
+ }
+
+ // Exports `src` as a device stream tagged with kMyDeviceType, imports it
+ // back as a ChunkedArray and passes the result to `check_func`.
+ void Roundtrip(std::shared_ptr<ChunkedArray> src,
+ std::function<void(const std::shared_ptr<ChunkedArray>&)>
check_func) {
+ ArrowDeviceArrayStream c_stream;
+
+ // One original copy to compare the result, one copy held by the stream
+ std::weak_ptr<ChunkedArray> weak_src(src);
+ int64_t initial_use_count = weak_src.use_count();
+
+ ASSERT_OK(ExportDeviceChunkedArray(
+ std::move(src), static_cast<DeviceAllocationType>(kMyDeviceType),
&c_stream));
+ ASSERT_FALSE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+ ASSERT_EQ(kMyDeviceType, c_stream.device_type);
+
+ {
+ ASSERT_OK_AND_ASSIGN(
+ auto dst,
+ ImportDeviceChunkedArray(&c_stream,
TestDeviceArrayRoundtrip::DeviceMapper));
+ // Stream was moved, consumed, and released
+ ASSERT_TRUE(ArrowDeviceArrayStreamIsReleased(&c_stream));
+
+ // Stream was released by ImportDeviceChunkedArray but original copy
+ // remains
+ ASSERT_EQ(weak_src.use_count(), initial_use_count - 1);
+
+ check_func(dst);
+ }
+ }
+
+ // Reads the next batch and checks that it is non-null, reports kMyDeviceType
+ // as its device type, and equals `expected`.
+ void AssertReaderNext(const std::shared_ptr<RecordBatchReader>& reader,
+ const RecordBatch& expected) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_NE(batch, nullptr);
+ ASSERT_EQ(static_cast<DeviceAllocationType>(kMyDeviceType),
batch->device_type());
+ AssertBatchesEqual(expected, *batch);
+ }
+
+ // Checks that Next() succeeds and yields a null batch.
+ void AssertReaderEnd(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_OK_AND_ASSIGN(auto batch, reader->Next());
+ ASSERT_EQ(batch, nullptr);
+ }
+
+ // Checks that Next() fails with an Invalid status.
+ void AssertReaderClosed(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_THAT(reader->Next(),
+ Raises(StatusCode::Invalid, ::testing::HasSubstr("already been
closed")));
+ }
+
+ // Closes the reader and checks that it then behaves as closed.
+ void AssertReaderClose(const std::shared_ptr<RecordBatchReader>& reader) {
+ ASSERT_OK(reader->Close());
+ AssertReaderClosed(reader);
+ }
+};
+
+// Round-trips a reader over two batches through the C device stream and
+// checks the schema, both batches, a repeated end-of-stream and closing of the
+// imported reader.
+TEST_F(TestArrayDeviceStreamRoundtrip, Simple) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertSchemaEqual(*orig_schema, *reader->schema(),
/*check_metadata=*/true);
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderNext(reader, *batches[1]);
+ AssertReaderEnd(reader);
+ AssertReaderEnd(reader);
+ AssertReaderClose(reader);
+ });
+}
+
+// Round-trips a two-batch reader but closes the imported reader after reading
+// only the first batch.
+TEST_F(TestArrayDeviceStreamRoundtrip, CloseEarly) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+ auto orig_schema = arrow::schema({field("ints", int32())});
+ auto batches = MakeBatches(orig_schema, {arr1, arr2});
+ ASSERT_OK_AND_ASSIGN(
+ auto reader, RecordBatchReader::Make(batches, orig_schema,
device->device_type()));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ AssertReaderNext(reader, *batches[0]);
+ AssertReaderClose(reader);
+ });
+}
+
+// Round-trips a FailingRecordBatchReader and checks that Next() on the
+// imported reader raises Invalid with the original message text.
+TEST_F(TestArrayDeviceStreamRoundtrip, Errors) {
+ auto reader = std::make_shared<FailingRecordBatchReader>(
+ Status::Invalid("roundtrip error example"));
+
+ Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>&
reader) {
+ EXPECT_THAT(reader->Next(), Raises(StatusCode::Invalid,
+ ::testing::HasSubstr("roundtrip error
example")));
+ });
+}
+
+// Imports a hand-built stream whose get_schema callback fails with EIO.
+// The import must raise IOError carrying the get_last_error text, and the
+// stream's release callback must have been invoked.
+TEST_F(TestArrayDeviceStreamRoundtrip, SchemaError) {
+ struct ArrowDeviceArrayStream stream = {};
+ stream.get_last_error = [](struct ArrowDeviceArrayStream* stream) {
+ return "Expected error";
+ };
+ stream.get_schema = [](struct ArrowDeviceArrayStream* stream,
+ struct ArrowSchema* schema) { return EIO; };
+ stream.get_next = [](struct ArrowDeviceArrayStream* stream,
+ struct ArrowDeviceArray* array) { return EINVAL; };
+ // release records that it ran through the bool behind private_data, then
+ // zeroes the struct (which also clears the release pointer).
+ stream.release = [](struct ArrowDeviceArrayStream* stream) {
+ *static_cast<bool*>(stream->private_data) = true;
+ std::memset(stream, 0, sizeof(*stream));
+ };
+ bool released = false;
+ stream.private_data = &released;
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(IOError, ::testing::HasSubstr("Expected
error"),
+ ImportDeviceRecordBatchReader(&stream));
+ ASSERT_TRUE(released);
+}
+
+// Round-trips a two-chunk ChunkedArray through the C device stream and checks
+// that the type and contents of the result match the source.
+TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtrip) {
+ std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+ auto mm = device->default_memory_manager();
+
+ ASSERT_OK_AND_ASSIGN(auto arr1,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[1,
2]")->data()));
+ ASSERT_EQ(device->device_type(), arr1->device_type());
+ ASSERT_OK_AND_ASSIGN(auto arr2,
+ ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5,
null]")->data()));
+ ASSERT_EQ(device->device_type(), arr2->device_type());
+
+ ASSERT_OK_AND_ASSIGN(auto src, ChunkedArray::Make({arr1, arr2}));
+
+ // `src` is passed by copy so it stays available for the comparison below.
+ Roundtrip(src, [&](const std::shared_ptr<ChunkedArray>& dst) {
+ AssertTypeEqual(*dst->type(), *src->type());
+ AssertChunkedEqual(*dst, *src);
+ });
+}
+
+// Round-trips a ChunkedArray with no chunks (type int32) and checks that the
+// type and contents of the result match the source.
+TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtripEmpty) {
+ ASSERT_OK_AND_ASSIGN(auto src, ChunkedArray::Make({}, int32()));
+
+ Roundtrip(src, [&](const std::shared_ptr<ChunkedArray>& dst) {
+ AssertTypeEqual(*dst->type(), *src->type());
+ AssertChunkedEqual(*dst, *src);
+ });
+}
+
} // namespace arrow
diff --git a/cpp/src/arrow/c/helpers.h b/cpp/src/arrow/c/helpers.h
index a24f272fea..6e4df17f43 100644
--- a/cpp/src/arrow/c/helpers.h
+++ b/cpp/src/arrow/c/helpers.h
@@ -17,6 +17,7 @@
#pragma once
+#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -70,9 +71,17 @@ inline int ArrowArrayIsReleased(const struct ArrowArray*
array) {
return array->release == NULL;
}
+/// Query whether the C device array is released, as reported by its
+/// embedded ArrowArray
+inline int ArrowDeviceArrayIsReleased(const struct ArrowDeviceArray* array) {
+ return ArrowArrayIsReleased(&array->array);
+}
+
/// Mark the C array released (for use in release callbacks)
inline void ArrowArrayMarkReleased(struct ArrowArray* array) { array->release
= NULL; }
+/// Mark the C device array released by marking its embedded ArrowArray
+/// released
+inline void ArrowDeviceArrayMarkReleased(struct ArrowDeviceArray* array) {
+ ArrowArrayMarkReleased(&array->array);
+}
+
/// Move the C array from `src` to `dest`
///
/// Note `dest` must *not* point to a valid array already, otherwise there
@@ -84,6 +93,14 @@ inline void ArrowArrayMove(struct ArrowArray* src, struct
ArrowArray* dest) {
ArrowArrayMarkReleased(src);
}
+/// Move the C device array from `src` to `dest`
+///
+/// The whole struct is copied bytewise into `dest` and `src` is then marked
+/// released. `src` must be distinct from `dest` and must not already be
+/// released (both are checked with assert).
+inline void ArrowDeviceArrayMove(struct ArrowDeviceArray* src,
+ struct ArrowDeviceArray* dest) {
+ assert(dest != src);
+ assert(!ArrowDeviceArrayIsReleased(src));
+ memcpy(dest, src, sizeof(struct ArrowDeviceArray));
+ ArrowDeviceArrayMarkReleased(src);
+}
+
/// Release the C array, if necessary, by calling its release callback
inline void ArrowArrayRelease(struct ArrowArray* array) {
if (!ArrowArrayIsReleased(array)) {
@@ -93,16 +110,32 @@ inline void ArrowArrayRelease(struct ArrowArray* array) {
}
}
+/// Release the C device array, if necessary, by calling the release callback
+/// of its embedded ArrowArray
+inline void ArrowDeviceArrayRelease(struct ArrowDeviceArray* array) {
+ if (!ArrowDeviceArrayIsReleased(array)) {
+ array->array.release(&array->array);
+ ARROW_C_ASSERT(ArrowDeviceArrayIsReleased(array),
+ "ArrowDeviceArrayRelease did not cleanup release callback");
+ }
+}
+
/// Query whether the C array stream is released
inline int ArrowArrayStreamIsReleased(const struct ArrowArrayStream* stream) {
return stream->release == NULL;
}
+/// Query whether the C device array stream is released (its release callback
+/// is NULL)
+inline int ArrowDeviceArrayStreamIsReleased(const struct
ArrowDeviceArrayStream* stream) {
+ return stream->release == NULL;
+}
+
/// Mark the C array stream released (for use in release callbacks)
inline void ArrowArrayStreamMarkReleased(struct ArrowArrayStream* stream) {
stream->release = NULL;
}
+/// Mark the C device array stream released by clearing its release callback
+inline void ArrowDeviceArrayStreamMarkReleased(struct ArrowDeviceArrayStream*
stream) {
+ stream->release = NULL;
+}
+
/// Move the C array stream from `src` to `dest`
///
/// Note `dest` must *not* point to a valid stream already, otherwise there
@@ -115,6 +148,14 @@ inline void ArrowArrayStreamMove(struct ArrowArrayStream*
src,
ArrowArrayStreamMarkReleased(src);
}
+/// Move the C device array stream from `src` to `dest`
+///
+/// The whole struct is copied bytewise into `dest` and `src` is then marked
+/// released. `src` must be distinct from `dest` and must not already be
+/// released (both are checked with assert).
+inline void ArrowDeviceArrayStreamMove(struct ArrowDeviceArrayStream* src,
+ struct ArrowDeviceArrayStream* dest) {
+ assert(dest != src);
+ assert(!ArrowDeviceArrayStreamIsReleased(src));
+ memcpy(dest, src, sizeof(struct ArrowDeviceArrayStream));
+ ArrowDeviceArrayStreamMarkReleased(src);
+}
+
/// Release the C array stream, if necessary, by calling its release callback
inline void ArrowArrayStreamRelease(struct ArrowArrayStream* stream) {
if (!ArrowArrayStreamIsReleased(stream)) {
@@ -124,6 +165,14 @@ inline void ArrowArrayStreamRelease(struct
ArrowArrayStream* stream) {
}
}
+/// Release the C device array stream, if necessary, by calling its release
+/// callback
+inline void ArrowDeviceArrayStreamRelease(struct ArrowDeviceArrayStream*
stream) {
+ if (!ArrowDeviceArrayStreamIsReleased(stream)) {
+ stream->release(stream);
+ ARROW_C_ASSERT(ArrowDeviceArrayStreamIsReleased(stream),
+ "ArrowDeviceArrayStreamRelease did not cleanup release
callback");
+ }
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/cpp/src/arrow/c/util_internal.h b/cpp/src/arrow/c/util_internal.h
index 6a33be9b0d..dc0e25710e 100644
--- a/cpp/src/arrow/c/util_internal.h
+++ b/cpp/src/arrow/c/util_internal.h
@@ -32,12 +32,32 @@ struct ArrayExportTraits {
typedef struct ArrowArray CType;
static constexpr auto IsReleasedFunc = &ArrowArrayIsReleased;
static constexpr auto ReleaseFunc = &ArrowArrayRelease;
+ static constexpr auto MoveFunc = &ArrowArrayMove;
+ static constexpr auto MarkReleased = &ArrowArrayMarkReleased;
+};
+
+// Traits binding struct ArrowDeviceArray to its C helper functions
+// (is-released, release, move, mark-released).
+struct ArrayDeviceExportTraits {
+ typedef struct ArrowDeviceArray CType;
+ static constexpr auto IsReleasedFunc = &ArrowDeviceArrayIsReleased;
+ static constexpr auto ReleaseFunc = &ArrowDeviceArrayRelease;
+ static constexpr auto MoveFunc = &ArrowDeviceArrayMove;
+ static constexpr auto MarkReleased = &ArrowDeviceArrayMarkReleased;
};
struct ArrayStreamExportTraits {
typedef struct ArrowArrayStream CType;
static constexpr auto IsReleasedFunc = &ArrowArrayStreamIsReleased;
static constexpr auto ReleaseFunc = &ArrowArrayStreamRelease;
+ static constexpr auto MoveFunc = &ArrowArrayStreamMove;
+ static constexpr auto MarkReleased = &ArrowArrayStreamMarkReleased;
+};
+
+// Traits binding struct ArrowDeviceArrayStream to its C helper functions
+// (is-released, release, move, mark-released).
+struct ArrayDeviceStreamExportTraits {
+ typedef struct ArrowDeviceArrayStream CType;
+ static constexpr auto IsReleasedFunc = &ArrowDeviceArrayStreamIsReleased;
+ static constexpr auto ReleaseFunc = &ArrowDeviceArrayStreamRelease;
+ static constexpr auto MoveFunc = &ArrowDeviceArrayStreamMove;
+ static constexpr auto MarkReleased = &ArrowDeviceArrayStreamMarkReleased;
};
// A RAII-style object to release a C Array / Schema struct at block scope
exit.
@@ -79,7 +99,9 @@ class ExportGuard {
using SchemaExportGuard = ExportGuard<SchemaExportTraits>;
using ArrayExportGuard = ExportGuard<ArrayExportTraits>;
+using DeviceArrayExportGuard = ExportGuard<ArrayDeviceExportTraits>;
using ArrayStreamExportGuard = ExportGuard<ArrayStreamExportTraits>;
+using DeviceArrayStreamExportGuard =
ExportGuard<ArrayDeviceStreamExportTraits>;
} // namespace internal
} // namespace arrow
diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index 8521d500f5..351f72f523 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -59,17 +59,31 @@ int RecordBatch::num_columns() const { return
schema_->num_fields(); }
class SimpleRecordBatch : public RecordBatch {
public:
SimpleRecordBatch(std::shared_ptr<Schema> schema, int64_t num_rows,
- std::vector<std::shared_ptr<Array>> columns)
- : RecordBatch(std::move(schema), num_rows),
boxed_columns_(std::move(columns)) {
+ std::vector<std::shared_ptr<Array>> columns,
+ std::shared_ptr<Device::SyncEvent> sync_event = nullptr)
+ : RecordBatch(std::move(schema), num_rows),
+ boxed_columns_(std::move(columns)),
+ device_type_(DeviceAllocationType::kCPU),
+ sync_event_(std::move(sync_event)) {
+ if (boxed_columns_.size() > 0) {
+ device_type_ = boxed_columns_[0]->device_type();
+ }
+
columns_.resize(boxed_columns_.size());
for (size_t i = 0; i < columns_.size(); ++i) {
columns_[i] = boxed_columns_[i]->data();
+ DCHECK_EQ(device_type_, columns_[i]->device_type());
}
}
SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows,
- std::vector<std::shared_ptr<ArrayData>> columns)
- : RecordBatch(std::move(schema), num_rows), columns_(std::move(columns))
{
+ std::vector<std::shared_ptr<ArrayData>> columns,
+ DeviceAllocationType device_type =
DeviceAllocationType::kCPU,
+ std::shared_ptr<Device::SyncEvent> sync_event = nullptr)
+ : RecordBatch(std::move(schema), num_rows),
+ columns_(std::move(columns)),
+ device_type_(device_type),
+ sync_event_(std::move(sync_event)) {
boxed_columns_.resize(schema_->num_fields());
}
@@ -99,6 +113,7 @@ class SimpleRecordBatch : public RecordBatch {
const std::shared_ptr<Array>& column) const override {
ARROW_CHECK(field != nullptr);
ARROW_CHECK(column != nullptr);
+ ARROW_CHECK(column->device_type() == device_type_);
if (!field->type()->Equals(column->type())) {
return Status::TypeError("Column data type ", field->type()->name(),
@@ -113,7 +128,8 @@ class SimpleRecordBatch : public RecordBatch {
ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field));
return RecordBatch::Make(std::move(new_schema), num_rows_,
- internal::AddVectorElement(columns_, i,
column->data()));
+ internal::AddVectorElement(columns_, i,
column->data()),
+ device_type_, sync_event_);
}
Result<std::shared_ptr<RecordBatch>> SetColumn(
@@ -121,6 +137,7 @@ class SimpleRecordBatch : public RecordBatch {
const std::shared_ptr<Array>& column) const override {
ARROW_CHECK(field != nullptr);
ARROW_CHECK(column != nullptr);
+ ARROW_CHECK(column->device_type() == device_type_);
if (!field->type()->Equals(column->type())) {
return Status::TypeError("Column data type ", field->type()->name(),
@@ -135,19 +152,22 @@ class SimpleRecordBatch : public RecordBatch {
ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->SetField(i, field));
return RecordBatch::Make(std::move(new_schema), num_rows_,
- internal::ReplaceVectorElement(columns_, i,
column->data()));
+ internal::ReplaceVectorElement(columns_, i,
column->data()),
+ device_type_, sync_event_);
}
Result<std::shared_ptr<RecordBatch>> RemoveColumn(int i) const override {
ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i));
return RecordBatch::Make(std::move(new_schema), num_rows_,
- internal::DeleteVectorElement(columns_, i));
+ internal::DeleteVectorElement(columns_, i),
device_type_,
+ sync_event_);
}
std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
auto new_schema = schema_->WithMetadata(metadata);
- return RecordBatch::Make(std::move(new_schema), num_rows_, columns_);
+ return RecordBatch::Make(std::move(new_schema), num_rows_, columns_,
device_type_,
+ sync_event_);
}
std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const
override {
@@ -157,7 +177,8 @@ class SimpleRecordBatch : public RecordBatch {
arrays.emplace_back(field->Slice(offset, length));
}
int64_t num_rows = std::min(num_rows_ - offset, length);
- return std::make_shared<SimpleRecordBatch>(schema_, num_rows,
std::move(arrays));
+ return std::make_shared<SimpleRecordBatch>(schema_, num_rows,
std::move(arrays),
+ device_type_, sync_event_);
}
Status Validate() const override {
@@ -167,11 +188,22 @@ class SimpleRecordBatch : public RecordBatch {
return RecordBatch::Validate();
}
+ const std::shared_ptr<Device::SyncEvent>& GetSyncEvent() const override {
+ return sync_event_;
+ }
+
+ DeviceAllocationType device_type() const override { return device_type_; }
+
private:
std::vector<std::shared_ptr<ArrayData>> columns_;
// Caching boxed array data
mutable std::vector<std::shared_ptr<Array>> boxed_columns_;
+
+ // the type of device that the buffers for columns are allocated on.
+ // all columns should be on the same type of device.
+ DeviceAllocationType device_type_;
+ std::shared_ptr<Device::SyncEvent> sync_event_;
};
RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t
num_rows)
@@ -179,18 +211,21 @@ RecordBatch::RecordBatch(const std::shared_ptr<Schema>&
schema, int64_t num_rows
std::shared_ptr<RecordBatch> RecordBatch::Make(
std::shared_ptr<Schema> schema, int64_t num_rows,
- std::vector<std::shared_ptr<Array>> columns) {
+ std::vector<std::shared_ptr<Array>> columns,
+ std::shared_ptr<Device::SyncEvent> sync_event) {
DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
- std::move(columns));
+ std::move(columns),
std::move(sync_event));
}
std::shared_ptr<RecordBatch> RecordBatch::Make(
std::shared_ptr<Schema> schema, int64_t num_rows,
- std::vector<std::shared_ptr<ArrayData>> columns) {
+ std::vector<std::shared_ptr<ArrayData>> columns, DeviceAllocationType
device_type,
+ std::shared_ptr<Device::SyncEvent> sync_event) {
DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
- std::move(columns));
+ std::move(columns), device_type,
+ std::move(sync_event));
}
Result<std::shared_ptr<RecordBatch>> RecordBatch::MakeEmpty(
@@ -466,6 +501,10 @@ bool RecordBatch::Equals(const RecordBatch& other, bool
check_metadata,
return false;
}
+ if (device_type() != other.device_type()) {
+ return false;
+ }
+
for (int i = 0; i < num_columns(); ++i) {
if (!column(i)->Equals(other.column(i), opts)) {
return false;
@@ -480,6 +519,10 @@ bool RecordBatch::ApproxEquals(const RecordBatch& other,
const EqualOptions& opt
return false;
}
+ if (device_type() != other.device_type()) {
+ return false;
+ }
+
for (int i = 0; i < num_columns(); ++i) {
if (!column(i)->ApproxEquals(other.column(i), opts)) {
return false;
@@ -505,7 +548,7 @@ Result<std::shared_ptr<RecordBatch>>
RecordBatch::ReplaceSchema(
", did not match new schema field type: ", replace_type->ToString());
}
}
- return RecordBatch::Make(std::move(schema), num_rows(), columns());
+ return RecordBatch::Make(std::move(schema), num_rows(), columns(),
GetSyncEvent());
}
std::vector<std::string> RecordBatch::ColumnNames() const {
@@ -534,7 +577,7 @@ Result<std::shared_ptr<RecordBatch>>
RecordBatch::RenameColumns(
}
return RecordBatch::Make(::arrow::schema(std::move(fields)), num_rows(),
- std::move(columns));
+ std::move(columns), GetSyncEvent());
}
Result<std::shared_ptr<RecordBatch>> RecordBatch::SelectColumns(
@@ -555,7 +598,8 @@ Result<std::shared_ptr<RecordBatch>>
RecordBatch::SelectColumns(
auto new_schema =
std::make_shared<arrow::Schema>(std::move(fields), schema()->metadata());
- return RecordBatch::Make(std::move(new_schema), num_rows(),
std::move(columns));
+ return RecordBatch::Make(std::move(new_schema), num_rows(),
std::move(columns),
+ GetSyncEvent());
}
std::shared_ptr<RecordBatch> RecordBatch::Slice(int64_t offset) const {
@@ -647,12 +691,16 @@ Result<std::shared_ptr<Table>>
RecordBatchReader::ToTable() {
class SimpleRecordBatchReader : public RecordBatchReader {
public:
SimpleRecordBatchReader(Iterator<std::shared_ptr<RecordBatch>> it,
- std::shared_ptr<Schema> schema)
- : schema_(std::move(schema)), it_(std::move(it)) {}
+ std::shared_ptr<Schema> schema,
+ DeviceAllocationType device_type =
DeviceAllocationType::kCPU)
+ : schema_(std::move(schema)), it_(std::move(it)),
device_type_(device_type) {}
SimpleRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,
- std::shared_ptr<Schema> schema)
- : schema_(std::move(schema)),
it_(MakeVectorIterator(std::move(batches))) {}
+ std::shared_ptr<Schema> schema,
+ DeviceAllocationType device_type =
DeviceAllocationType::kCPU)
+ : schema_(std::move(schema)),
+ it_(MakeVectorIterator(std::move(batches))),
+ device_type_(device_type) {}
Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
return it_.Next().Value(batch);
@@ -660,13 +708,17 @@ class SimpleRecordBatchReader : public RecordBatchReader {
std::shared_ptr<Schema> schema() const override { return schema_; }
+ DeviceAllocationType device_type() const override { return device_type_; }
+
protected:
std::shared_ptr<Schema> schema_;
Iterator<std::shared_ptr<RecordBatch>> it_;
+ DeviceAllocationType device_type_;
};
Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make(
- std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema) {
+ std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema,
+ DeviceAllocationType device_type) {
if (schema == nullptr) {
if (batches.size() == 0 || batches[0] == nullptr) {
return Status::Invalid("Cannot infer schema from empty vector or
nullptr");
@@ -675,16 +727,19 @@ Result<std::shared_ptr<RecordBatchReader>>
RecordBatchReader::Make(
schema = batches[0]->schema();
}
- return std::make_shared<SimpleRecordBatchReader>(std::move(batches),
std::move(schema));
+ return std::make_shared<SimpleRecordBatchReader>(std::move(batches),
std::move(schema),
+ device_type);
}
Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::MakeFromIterator(
- Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema) {
+ Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema,
+ DeviceAllocationType device_type) {
if (schema == nullptr) {
return Status::Invalid("Schema cannot be nullptr");
}
- return std::make_shared<SimpleRecordBatchReader>(std::move(batches),
std::move(schema));
+ return std::make_shared<SimpleRecordBatchReader>(std::move(batches),
std::move(schema),
+ device_type);
}
RecordBatchReader::~RecordBatchReader() {
@@ -701,6 +756,10 @@ Result<std::shared_ptr<RecordBatch>>
ConcatenateRecordBatches(
int cols = batches[0]->num_columns();
auto schema = batches[0]->schema();
for (size_t i = 0; i < batches.size(); ++i) {
+ if (auto sync = batches[i]->GetSyncEvent()) {
+ ARROW_RETURN_NOT_OK(sync->Wait());
+ }
+
length += batches[i]->num_rows();
if (!schema->Equals(batches[i]->schema())) {
return Status::Invalid(
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index cd647a88ab..b03cbf2251 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -23,6 +23,7 @@
#include <vector>
#include "arrow/compare.h"
+#include "arrow/device.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type_fwd.h"
@@ -45,9 +46,12 @@ class ARROW_EXPORT RecordBatch {
/// \param[in] num_rows length of fields in the record batch. Each array
/// should have the same length as num_rows
/// \param[in] columns the record batch fields as vector of arrays
- static std::shared_ptr<RecordBatch> Make(std::shared_ptr<Schema> schema,
- int64_t num_rows,
- std::vector<std::shared_ptr<Array>>
columns);
+ /// \param[in] sync_event optional synchronization event for non-CPU device
+ /// memory used by buffers
+ static std::shared_ptr<RecordBatch> Make(
+ std::shared_ptr<Schema> schema, int64_t num_rows,
+ std::vector<std::shared_ptr<Array>> columns,
+ std::shared_ptr<Device::SyncEvent> sync_event = NULLPTR);
/// \brief Construct record batch from vector of internal data structures
/// \since 0.5.0
@@ -58,9 +62,15 @@ class ARROW_EXPORT RecordBatch {
/// \param num_rows the number of semantic rows in the record batch. This
/// should be equal to the length of each field
/// \param columns the data for the batch's columns
+ /// \param device_type the type of the device that the Arrow columns are
+ /// allocated on
+ /// \param sync_event optional synchronization event for non-CPU device
+ /// memory used by buffers
static std::shared_ptr<RecordBatch> Make(
std::shared_ptr<Schema> schema, int64_t num_rows,
- std::vector<std::shared_ptr<ArrayData>> columns);
+ std::vector<std::shared_ptr<ArrayData>> columns,
+ DeviceAllocationType device_type = DeviceAllocationType::kCPU,
+ std::shared_ptr<Device::SyncEvent> sync_event = NULLPTR);
/// \brief Create an empty RecordBatch of a given schema
///
@@ -260,6 +270,18 @@ class ARROW_EXPORT RecordBatch {
/// \return Status
virtual Status ValidateFull() const;
+ /// \brief EXPERIMENTAL: Return a top-level sync event object for this record batch
+ ///
+ /// If all of the data for this record batch is in CPU memory, this
+ /// returns null. If the data is on a non-CPU device and synchronization
+ /// is needed before accessing it, the returned sync event can be
+ /// waited on to perform that synchronization.
+ ///
+ /// \return null or a Device::SyncEvent
+ virtual const std::shared_ptr<Device::SyncEvent>& GetSyncEvent() const = 0;
+
+ virtual DeviceAllocationType device_type() const = 0;
+
protected:
RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows);
@@ -306,6 +328,11 @@ class ARROW_EXPORT RecordBatchReader {
/// \brief finalize reader
virtual Status Close() { return Status::OK(); }
+ /// \brief EXPERIMENTAL: Get the device type for record batches this reader produces
+ ///
+ /// The default implementation returns DeviceAllocationType::kCPU.
+ virtual DeviceAllocationType device_type() const { return
DeviceAllocationType::kCPU; }
+
class RecordBatchReaderIterator {
public:
using iterator_category = std::input_iterator_tag;
@@ -379,15 +406,19 @@ class ARROW_EXPORT RecordBatchReader {
/// \param[in] batches the vector of RecordBatch to read from
/// \param[in] schema schema to conform to. Will be inferred from the first
/// element if not provided.
+ /// \param[in] device_type the type of device that the batches are allocated on
static Result<std::shared_ptr<RecordBatchReader>> Make(
- RecordBatchVector batches, std::shared_ptr<Schema> schema = NULLPTR);
+ RecordBatchVector batches, std::shared_ptr<Schema> schema = NULLPTR,
+ DeviceAllocationType device_type = DeviceAllocationType::kCPU);
/// \brief Create a RecordBatchReader from an Iterator of RecordBatch.
///
/// \param[in] batches an iterator of RecordBatch to read from.
/// \param[in] schema schema that each record batch in iterator will conform to.
+ /// \param[in] device_type the type of device that the batches are allocated on
static Result<std::shared_ptr<RecordBatchReader>> MakeFromIterator(
- Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema);
+ Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema,
+ DeviceAllocationType device_type = DeviceAllocationType::kCPU);
};
/// \brief Concatenate record batches
diff --git a/python/pyarrow/tests/test_cffi.py
b/python/pyarrow/tests/test_cffi.py
index 5bf41c3c14..45a3db9b66 100644
--- a/python/pyarrow/tests/test_cffi.py
+++ b/python/pyarrow/tests/test_cffi.py
@@ -45,7 +45,7 @@ assert_array_released = pytest.raises(
ValueError, match="Cannot import released ArrowArray")
assert_stream_released = pytest.raises(
- ValueError, match="Cannot import released ArrowArrayStream")
+ ValueError, match="Cannot import released Arrow Stream")
def PyCapsule_IsValid(capsule, name):