This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push: new aab7d81aee GH-43631: [C++] Add C++ implementation of Async C Data Interface (#44495) aab7d81aee is described below commit aab7d81aeec1e8f106bdda953cdeb00f7f78b355 Author: Matt Topol <zotthewiz...@gmail.com> AuthorDate: Mon Nov 11 23:20:13 2024 +0100 GH-43631: [C++] Add C++ implementation of Async C Data Interface (#44495) ### Rationale for this change Building on #43632, which created the Async C Data Structures, this adds functions to `bridge.h`/`bridge.cc` to implement helpers for managing the Async C Data interfaces. ### What changes are included in this PR? Two functions added to bridge.h: 1. `CreateAsyncDeviceStreamHandler` populates an `ArrowAsyncDeviceStreamHandler` and uses an `Executor` to provide a future that resolves to an `AsyncRecordBatchGenerator`, which produces record batches as they are pushed asynchronously. The `ArrowAsyncDeviceStreamHandler` can then be passed to any asynchronous producer. 2. `ExportAsyncRecordBatchReader` takes a record batch generator and a schema, along with an `ArrowAsyncDeviceStreamHandler` to use for calling the callbacks to push data as it is available from the generator. ### Are these changes tested? Unit tests are added, covering a simple round trip, a null schema, and error propagation. ### Are there any user-facing changes? 
No * GitHub Issue: #43631 Lead-authored-by: Matt Topol <zotthewiz...@gmail.com> Co-authored-by: David Li <li.david...@gmail.com> Co-authored-by: Benjamin Kietzman <bengil...@gmail.com> Signed-off-by: Matt Topol <zotthewiz...@gmail.com> --- cpp/src/arrow/c/abi.h | 5 +- cpp/src/arrow/c/bridge.cc | 348 ++++++++++++++++++++++++++++ cpp/src/arrow/c/bridge.h | 80 +++++++ cpp/src/arrow/c/bridge_test.cc | 110 +++++++++ cpp/src/arrow/record_batch.h | 6 + cpp/src/arrow/type_fwd.h | 2 + docs/source/format/CDeviceDataInterface.rst | 8 + 7 files changed, 558 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/c/abi.h b/cpp/src/arrow/c/abi.h index e44933a6af..ae632f2dbd 100644 --- a/cpp/src/arrow/c/abi.h +++ b/cpp/src/arrow/c/abi.h @@ -287,7 +287,7 @@ struct ArrowAsyncTask { // calling this, and so it must be released separately. // // It is only valid to call this method exactly once. - int (*extract_data)(struct ArrowArrayTask* self, struct ArrowDeviceArray* out); + int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* out); // opaque task-specific data void* private_data; @@ -298,6 +298,9 @@ struct ArrowAsyncTask { // control on the asynchronous stream processing. This object must be owned by the // producer who creates it, and thus is responsible for cleaning it up. struct ArrowAsyncProducer { + // The device type that this stream produces data on. + ArrowDeviceType device_type; + // A consumer must call this function to start receiving on_next_task calls. 
// // It *must* be valid to call this synchronously from within `on_next_task` or diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index 0976a5cb61..f848b34115 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -19,8 +19,11 @@ #include <algorithm> #include <cerrno> +#include <condition_variable> #include <cstring> #include <memory> +#include <mutex> +#include <queue> #include <string> #include <string_view> #include <utility> @@ -37,8 +40,10 @@ #include "arrow/result.h" #include "arrow/stl_allocator.h" #include "arrow/type_traits.h" +#include "arrow/util/async_generator.h" #include "arrow/util/bit_util.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -2511,4 +2516,347 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray( return ImportChunked</*IsDevice=*/true>(stream, mapper); } +namespace { + +class AsyncRecordBatchIterator { + public: + struct TaskWithMetadata { + ArrowAsyncTask task_; + std::shared_ptr<KeyValueMetadata> metadata_; + }; + + struct State { + State(uint64_t queue_size, DeviceMemoryMapper mapper) + : queue_size_{queue_size}, mapper_{std::move(mapper)} {} + + Result<RecordBatchWithMetadata> next() { + TaskWithMetadata task; + { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, + [&] { return !error_.ok() || !batches_.empty() || end_of_stream_; }); + if (!error_.ok()) { + return error_; + } + + if (batches_.empty() && end_of_stream_) { + return IterationEnd<RecordBatchWithMetadata>(); + } + + task = std::move(batches_.front()); + batches_.pop(); + } + + producer_->request(producer_, 1); + ArrowDeviceArray out; + if (task.task_.extract_data(&task.task_, &out) != 0) { + std::unique_lock<std::mutex> lock(mutex_); + cv_.wait(lock, [&] { return !error_.ok() || end_of_stream_; }); + return error_.ok() ? Status::UnknownError("extract_data failed") : error_; + } + + ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, 
mapper_)); + return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)}; + } + + const uint64_t queue_size_; + const DeviceMemoryMapper mapper_; + ArrowAsyncProducer* producer_ = nullptr;  // set by on_schema + DeviceAllocationType device_type_; + + std::mutex mutex_; + std::shared_ptr<Schema> schema_; + std::condition_variable cv_; + std::queue<TaskWithMetadata> batches_; + bool end_of_stream_ = false; + Status error_{Status::OK()}; + }; + + AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper) + : state_{std::make_shared<State>(queue_size, std::move(mapper))} {} + + explicit AsyncRecordBatchIterator(std::shared_ptr<State> state) + : state_{std::move(state)} {} + + const std::shared_ptr<Schema>& schema() const { return state_->schema_; } + + DeviceAllocationType device_type() const { return state_->device_type_; } + + Result<RecordBatchWithMetadata> Next() { return state_->next(); } + + static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make( + AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) { + auto iterator_fut = Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make(); + + auto private_data = new PrivateData{iterator.state_}; + private_data->fut_iterator_ = iterator_fut; + + handler->private_data = private_data; + handler->on_schema = on_schema; + handler->on_next_task = on_next_task; + handler->on_error = on_error; + handler->release = release; + return iterator_fut; + } + + private: + struct PrivateData { + explicit PrivateData(std::shared_ptr<State> state) : state_(std::move(state)) {} + + std::shared_ptr<State> state_; + Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_; + ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData); + }; + + static int on_schema(struct ArrowAsyncDeviceStreamHandler* self, + struct ArrowSchema* stream_schema) { + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); + auto maybe_schema = ImportSchema(stream_schema); + if (!maybe_schema.ok()) { + private_data->fut_iterator_.MarkFinished(maybe_schema.status()); + return EINVAL; + } + if (self->producer == nullptr) {  // next() and the request() below use it + private_data->fut_iterator_.MarkFinished(Status::Invalid("producer is null")); + return EINVAL; + } + private_data->state_->producer_ = 
self->producer; + private_data->state_->device_type_ = + static_cast<DeviceAllocationType>(self->producer->device_type); + private_data->state_->schema_ = maybe_schema.MoveValueUnsafe(); + private_data->fut_iterator_.MarkFinished(private_data->state_); + self->producer->request(self->producer, + static_cast<int64_t>(private_data->state_->queue_size_)); + return 0; + } + + static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask* task, + const char* metadata) { + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); + + if (task == nullptr) { + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); + private_data->state_->end_of_stream_ = true; + lock.unlock(); + private_data->state_->cv_.notify_one(); + return 0; + } + + std::shared_ptr<KeyValueMetadata> kvmetadata; + if (metadata != nullptr) { + auto maybe_decoded = DecodeMetadata(metadata); + if (!maybe_decoded.ok()) { + std::lock_guard<std::mutex> lock(private_data->state_->mutex_); + private_data->state_->error_ = std::move(maybe_decoded).status(); + private_data->state_->cv_.notify_one(); + return EINVAL; + } + kvmetadata = std::move(maybe_decoded->metadata); + } + + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); + private_data->state_->batches_.push({*task, std::move(kvmetadata)}); + lock.unlock(); + private_data->state_->cv_.notify_one(); + return 0; + } + + static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const char* message, + const char* metadata) { + auto* private_data = reinterpret_cast<PrivateData*>(self->private_data); + std::string message_str, metadata_str; + if (message != nullptr) { + message_str = message; + } + if (metadata != nullptr) { + metadata_str = metadata; + } + + Status error = Status::FromDetailAndArgs( + StatusCode::UnknownError, + std::make_shared<AsyncErrorDetail>(code, message_str, 
std::move(metadata_str)), + std::move(message_str)); + + if (!private_data->fut_iterator_.is_finished()) { + private_data->fut_iterator_.MarkFinished(error); + return; + } + + std::unique_lock<std::mutex> lock(private_data->state_->mutex_); + private_data->state_->error_ = std::move(error); + lock.unlock(); + private_data->state_->cv_.notify_one(); + } + + static void release(ArrowAsyncDeviceStreamHandler* self) { + delete reinterpret_cast<PrivateData*>(self->private_data); + } + + std::shared_ptr<State> state_; +}; + +struct AsyncProducer { + struct State { + struct ArrowAsyncProducer producer_; + + std::mutex mutex_; + std::condition_variable cv_; + uint64_t pending_requests_{0}; + Status error_{Status::OK()}; + }; + + AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema, + struct ArrowAsyncDeviceStreamHandler* handler) + : handler_{handler}, state_{std::make_shared<State>()} { + state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type); + state_->producer_.private_data = reinterpret_cast<void*>(state_.get()); + state_->producer_.request = AsyncProducer::request; + state_->producer_.cancel = AsyncProducer::cancel; + handler_->producer = &state_->producer_; + + if (int status = handler_->on_schema(handler_, schema); status != 0) { + state_->error_ = + Status::UnknownError("Received error from handler::on_schema ", status); + } + } + + struct PrivateTaskData { + PrivateTaskData(std::shared_ptr<State> producer, std::shared_ptr<RecordBatch> record) + : producer_{std::move(producer)}, record_(std::move(record)) {} + + std::shared_ptr<State> producer_; + std::shared_ptr<RecordBatch> record_; + ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData); + }; + + Status operator()(const std::shared_ptr<RecordBatch>& record) { + std::unique_lock<std::mutex> lock(state_->mutex_); + if (state_->pending_requests_ == 0) { + state_->cv_.wait(lock, [this]() -> bool { + return !state_->error_.ok() || state_->pending_requests_ > 0; + }); + } + + if 
(!state_->error_.ok()) { + return state_->error_; + } + + if (state_->pending_requests_ > 0) { + state_->pending_requests_--; + lock.unlock(); + + ArrowAsyncTask task; + task.private_data = new PrivateTaskData{state_, record}; + task.extract_data = AsyncProducer::extract_data; + + if (int status = handler_->on_next_task(handler_, &task, nullptr); status != 0) { + delete reinterpret_cast<PrivateTaskData*>(task.private_data); + return Status::UnknownError("Received error from handler::on_next_task ", status); + } + } + + return Status::OK(); + } + + static void request(struct ArrowAsyncProducer* producer, int64_t n) { + auto* self = reinterpret_cast<State*>(producer->private_data); + { + std::lock_guard<std::mutex> lock(self->mutex_); + if (!self->error_.ok()) { + return; + } + self->pending_requests_ += n; + } + self->cv_.notify_all(); + } + + static void cancel(struct ArrowAsyncProducer* producer) { + auto* self = reinterpret_cast<State*>(producer->private_data); + { + std::lock_guard<std::mutex> lock(self->mutex_); + if (!self->error_.ok()) { + return; + } + self->error_ = Status::Cancelled("Consumer requested cancellation"); + } + self->cv_.notify_all(); + } + + static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) { + std::unique_ptr<PrivateTaskData> private_data{ + reinterpret_cast<PrivateTaskData*>(task->private_data)}; + int ret = 0; + if (out != nullptr) { + auto status = ExportDeviceRecordBatch(*private_data->record_, + private_data->record_->GetSyncEvent(), out); + if (!status.ok()) { + std::lock_guard<std::mutex> lock(private_data->producer_->mutex_); + private_data->producer_->error_ = status; + ret = EINVAL;  // `out` was not populated; report failure + } + } + return ret; + } + + struct ArrowAsyncDeviceStreamHandler* handler_; + std::shared_ptr<State> state_; +}; + +} // namespace + +Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler( + struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor, + uint64_t queue_size, DeviceMemoryMapper mapper) { + auto 
iterator = + std::make_shared<AsyncRecordBatchIterator>(queue_size, std::move(mapper)); + return AsyncRecordBatchIterator::Make(*iterator, handler) + .Then([executor](std::shared_ptr<AsyncRecordBatchIterator::State> state) + -> Result<AsyncRecordBatchGenerator> { + AsyncRecordBatchGenerator gen{state->schema_, state->device_type_, nullptr}; + auto it = + Iterator<RecordBatchWithMetadata>(AsyncRecordBatchIterator{std::move(state)}); + ARROW_ASSIGN_OR_RAISE(gen.generator, + MakeBackgroundGenerator(std::move(it), executor)); + return gen; + }); +} + +Future<> ExportAsyncRecordBatchReader( + std::shared_ptr<Schema> schema, + AsyncGenerator<std::shared_ptr<RecordBatch>> generator, + DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler) { + if (!schema) { + handler->on_error(handler, EINVAL, "Schema is null", nullptr); + handler->release(handler); + return Future<>::MakeFinished(Status::Invalid("Schema is null")); + } + + struct ArrowSchema c_schema; + SchemaExportGuard guard(&c_schema); + + auto status = ExportSchema(*schema, &c_schema); + if (!status.ok()) { + handler->on_error(handler, EINVAL, status.message().c_str(), nullptr); + handler->release(handler); + return Future<>::MakeFinished(status); + } + + return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema, handler}) + .Then( + [handler]() -> Status { + int status = handler->on_next_task(handler, nullptr, nullptr); + handler->release(handler); + if (status != 0) { + return Status::UnknownError("Received error from handler::on_next_task ", + status); + } + return Status::OK(); + }, + [handler](const Status status) -> Status { + handler->on_error(handler, EINVAL, status.message().c_str(), nullptr); + handler->release(handler); + return status; + }); +} + } // namespace arrow diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h index 45367e4f93..78860e0650 100644 --- a/cpp/src/arrow/c/bridge.h +++ b/cpp/src/arrow/c/bridge.h @@ -26,6 +26,7 @@ #include 
"arrow/result.h" #include "arrow/status.h" #include "arrow/type_fwd.h" +#include "arrow/util/async_generator_fwd.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -406,4 +407,83 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray( /// @} +/// \defgroup c-async-stream-interface Functions for working with the async C data +/// interface. +/// +/// @{ + +/// \brief EXPERIMENTAL: AsyncErrorDetail is a StatusDetail that contains an error code +/// and message from an asynchronous operation. +class AsyncErrorDetail : public StatusDetail { + public: + AsyncErrorDetail(int code, std::string message, std::string metadata) + : code_(code), message_(std::move(message)), metadata_(std::move(metadata)) {} + const char* type_id() const override { return "AsyncErrorDetail"; } + // ToString just returns the error message that was returned with the error + std::string ToString() const override { return message_; } + // code is an errno-compatible error code + int code() const { return code_; } + // returns any metadata that was returned with the error, likely in a + // key-value format similar to ArrowSchema metadata + const std::string& ErrorMetadataString() const { return metadata_; } + std::shared_ptr<KeyValueMetadata> ErrorMetadata() const; + + private: + int code_{0}; + std::string message_; + std::string metadata_; +}; + +struct AsyncRecordBatchGenerator { + std::shared_ptr<Schema> schema; + DeviceAllocationType device_type; + AsyncGenerator<RecordBatchWithMetadata> generator; +}; + +namespace internal { +class Executor; +} + +/// \brief EXPERIMENTAL: Create an AsyncRecordBatchGenerator and populate a corresponding +/// handler to pass to a producer +/// +/// The ArrowAsyncDeviceStreamHandler struct is intended to have its callbacks populated +/// and then be passed to a producer to call the appropriate callbacks when data is ready.
+/// This inverts the traditional flow of control, and so we construct a corresponding +/// AsyncRecordBatchGenerator to provide an interface for the consumer to retrieve data as +/// it is pushed to the handler. +/// +/// \param[in,out] handler C struct to be populated +/// \param[in] executor the executor to use for waiting and populating record batches +/// \param[in] queue_size initial number of record batches to request for queueing +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Future that resolves to either an error or AsyncRecordBatchGenerator once a +/// schema is available or an error is received. +ARROW_EXPORT +Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler( + struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor, + uint64_t queue_size = 5, DeviceMemoryMapper mapper = DefaultDeviceMemoryMapper); + +/// \brief EXPERIMENTAL: Export an AsyncGenerator of record batches using a provided +/// handler +/// +/// This function calls the callbacks on the consumer-provided async handler as record +/// batches become available from the AsyncGenerator which is provided. It will first call +/// on_schema using the provided schema, and then serially visit each record batch from +/// the generator, calling the on_next_task callback. If an error occurs, on_error will be +/// called appropriately. The handler is released once the export finishes.
+/// +/// \param[in] schema the schema of the stream being exported +/// \param[in] generator a generator that asynchronously produces record batches +/// \param[in] device_type the device type that the record batches will be located on +/// \param[in] handler the handler whose callbacks to utilize as data is available +/// \return Future that will resolve once the generator is exhausted or an error occurs +ARROW_EXPORT +Future<> ExportAsyncRecordBatchReader( + std::shared_ptr<Schema> schema, + AsyncGenerator<std::shared_ptr<RecordBatch>> generator, + DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler); + +/// @} + } // namespace arrow diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index fdcb53ddbc..bc60b587cf 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -38,6 +38,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/util.h" +#include "arrow/util/async_generator.h" #include "arrow/util/binary_view_util.h" #include "arrow/util/checked_cast.h" #include "arrow/util/endian.h" @@ -45,6 +46,7 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/range.h" +#include "arrow/util/thread_pool.h" // TODO(GH-37221): Remove these ifdef checks when compute dependency is removed #ifdef ARROW_COMPUTE @@ -5311,4 +5313,112 @@ TEST_F(TestArrayDeviceStreamRoundtrip, ChunkedArrayRoundtripEmpty) { }); } +class TestAsyncDeviceArrayStreamRoundTrip : public BaseArrayStreamTest { + public: + static Result<std::shared_ptr<ArrayData>> ToDeviceData( + const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) { + arrow::BufferVector buffers; + for (const auto& buf : data.buffers) { + if (buf) { + ARROW_ASSIGN_OR_RAISE(auto dest, mm->CopyBuffer(buf, mm)); + buffers.push_back(dest); + } else { + buffers.push_back(nullptr); + } + } + + arrow::ArrayDataVector children; + for (const auto& child : data.child_data) { +
ARROW_ASSIGN_OR_RAISE(auto dest, ToDeviceData(mm, *child)); + children.push_back(dest); + } + + return ArrayData::Make(data.type, data.length, buffers, children, data.null_count, + data.offset); + } + + static Result<std::shared_ptr<Array>> ToDevice(const std::shared_ptr<MemoryManager>& mm, + const ArrayData& data) { + ARROW_ASSIGN_OR_RAISE(auto result, ToDeviceData(mm, data)); + return MakeArray(result); + } +}; + +TEST_F(TestAsyncDeviceArrayStreamRoundTrip, Simple) { + std::shared_ptr<Device> device = std::make_shared<MyDevice>(1); + auto mm = device->default_memory_manager(); + + ASSERT_OK_AND_ASSIGN(auto arr1, + ToDevice(mm, *ArrayFromJSON(int32(), "[1, 2]")->data())); + ASSERT_EQ(device->device_type(), arr1->device_type()); + ASSERT_OK_AND_ASSIGN(auto arr2, + ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, null]")->data())); + ASSERT_EQ(device->device_type(), arr2->device_type()); + auto orig_schema = arrow::schema({field("ints", int32())}); + auto batches = MakeBatches(orig_schema, {arr1, arr2}); + + struct ArrowAsyncDeviceStreamHandler handler; + auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, internal::GetCpuThreadPool(), 1, + TestDeviceArrayRoundtrip::DeviceMapper); + ASSERT_FALSE(fut_gen.is_finished()); + + ASSERT_OK_AND_ASSIGN(auto fut, internal::GetCpuThreadPool()->Submit([&]() { + return ExportAsyncRecordBatchReader(orig_schema, MakeVectorGenerator(batches), + device->device_type(), &handler); + })); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto generator, fut_gen); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*orig_schema, *generator.schema)); + + auto collect_fut = CollectAsyncGenerator(generator.generator); + ASSERT_FINISHES_OK_AND_ASSIGN(auto results, collect_fut); + ASSERT_FINISHES_OK(fut); + ASSERT_FINISHES_OK(fut_gen); + + ASSERT_EQ(results.size(), 2); + AssertBatchesEqual(*results[0].batch, *batches[0]); + AssertBatchesEqual(*results[1].batch, *batches[1]); + + internal::GetCpuThreadPool()->WaitForIdle(); +} +
+TEST_F(TestAsyncDeviceArrayStreamRoundTrip, NullSchema) { + struct ArrowAsyncDeviceStreamHandler handler; + auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, internal::GetCpuThreadPool(), 1, + TestDeviceArrayRoundtrip::DeviceMapper); + ASSERT_FALSE(fut_gen.is_finished()); + + auto fut = ExportAsyncRecordBatchReader(nullptr, nullptr, DeviceAllocationType::kCPU, + &handler); + ASSERT_FINISHES_AND_RAISES(Invalid, fut); + ASSERT_FINISHES_AND_RAISES(UnknownError, fut_gen); +} + +TEST_F(TestAsyncDeviceArrayStreamRoundTrip, PropagateError) { + std::shared_ptr<Device> device = std::make_shared<MyDevice>(1); + auto orig_schema = arrow::schema({field("ints", int32())}); + + struct ArrowAsyncDeviceStreamHandler handler; + auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, internal::GetCpuThreadPool(), 1, + TestDeviceArrayRoundtrip::DeviceMapper); + ASSERT_FALSE(fut_gen.is_finished()); + + ASSERT_OK_AND_ASSIGN(auto fut, internal::GetCpuThreadPool()->Submit([&]() { + return ExportAsyncRecordBatchReader( + orig_schema, + MakeFailingGenerator<std::shared_ptr<RecordBatch>>( + Status::UnknownError("expected failure")), + device->device_type(), &handler); + })); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto generator, fut_gen); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*orig_schema, *generator.schema)); + + auto collect_fut = CollectAsyncGenerator(generator.generator); + ASSERT_FINISHES_AND_RAISES(UnknownError, collect_fut); + ASSERT_FINISHES_AND_RAISES(UnknownError, fut); + + internal::GetCpuThreadPool()->WaitForIdle(); +} + } // namespace arrow diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index edbefc1c77..06cb621e98 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -309,6 +309,12 @@ struct ARROW_EXPORT RecordBatchWithMetadata { std::shared_ptr<KeyValueMetadata> custom_metadata; }; +template <> +struct IterationTraits<RecordBatchWithMetadata> { + static RecordBatchWithMetadata End() { return {NULLPTR, NULLPTR}; } + 
static bool IsEnd(const RecordBatchWithMetadata& val) { return val.batch == NULLPTR; } +}; + /// \brief Abstract interface for reading stream of record batches class ARROW_EXPORT RecordBatchReader { public: diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 69029b67ab..5a2fbde023 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -79,7 +79,9 @@ using ScalarVector = std::vector<std::shared_ptr<Scalar>>; class ChunkedArray; class RecordBatch; +struct RecordBatchWithMetadata; class RecordBatchReader; +class AsyncRecordBatchReader; class Table; struct Datum; diff --git a/docs/source/format/CDeviceDataInterface.rst b/docs/source/format/CDeviceDataInterface.rst index fbb2012c30..19412c605c 100644 --- a/docs/source/format/CDeviceDataInterface.rst +++ b/docs/source/format/CDeviceDataInterface.rst @@ -695,6 +695,8 @@ The C device async stream interface consists of three ``struct`` definitions: }; struct ArrowAsyncProducer { + ArrowDeviceType device_type; + void (*request)(struct ArrowAsyncProducer* self, int64_t n); void (*cancel)(struct ArrowAsyncProducer* self); @@ -869,6 +871,12 @@ The ArrowAsyncProducer structure This producer-provided and managed object has the following fields: +.. c:member:: ArrowDeviceType ArrowAsyncProducer.device_type + + *Mandatory.* The device type that this producer will provide data on. All + ``ArrowDeviceArray`` structs that are produced by this producer should have the + same device type as is set here. + .. c:member:: void (*ArrowAsyncProducer.request)(struct ArrowAsyncProducer*, uint64_t) *Mandatory.* This function must be called by a consumer to start receiving calls to