This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new aab7d81aee GH-43631: [C++] Add C++ implementation of Async C Data 
Interface (#44495)
aab7d81aee is described below

commit aab7d81aeec1e8f106bdda953cdeb00f7f78b355
Author: Matt Topol <zotthewiz...@gmail.com>
AuthorDate: Mon Nov 11 23:20:13 2024 +0100

    GH-43631: [C++] Add C++ implementation of Async C Data Interface (#44495)
    
    
    
    ### Rationale for this change
    Building on #43632, which created the Async C Data structures, this adds
    functions to `bridge.h`/`bridge.cc` that implement helpers for managing the
    Async C Data interfaces.
    
    ### What changes are included in this PR?
    Two functions added to bridge.h:
    
    1. `CreateAsyncDeviceStreamHandler` populates an
    `ArrowAsyncDeviceStreamHandler` and uses the given `Executor` to provide a
    future that resolves to an `AsyncRecordBatchGenerator`, which produces record
    batches as they are pushed asynchronously. The `ArrowAsyncDeviceStreamHandler`
    can then be passed to any asynchronous producer.
    2. `ExportAsyncRecordBatchReader` takes a schema and a record batch generator,
    along with an `ArrowAsyncDeviceStreamHandler` whose callbacks it calls to push
    data as it becomes available from the generator.
    
    ### Are these changes tested?
    Unit tests are added (a round-trip test, a null-schema test and an
    error-propagation test; more tests to be added)
    
    ### Are there any user-facing changes?
    No
    
    * GitHub Issue: #43631
    
    Lead-authored-by: Matt Topol <zotthewiz...@gmail.com>
    Co-authored-by: David Li <li.david...@gmail.com>
    Co-authored-by: Benjamin Kietzman <bengil...@gmail.com>
    Signed-off-by: Matt Topol <zotthewiz...@gmail.com>
---
 cpp/src/arrow/c/abi.h                       |   5 +-
 cpp/src/arrow/c/bridge.cc                   | 348 ++++++++++++++++++++++++++++
 cpp/src/arrow/c/bridge.h                    |  80 +++++++
 cpp/src/arrow/c/bridge_test.cc              | 110 +++++++++
 cpp/src/arrow/record_batch.h                |   6 +
 cpp/src/arrow/type_fwd.h                    |   2 +
 docs/source/format/CDeviceDataInterface.rst |   8 +
 7 files changed, 558 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/c/abi.h b/cpp/src/arrow/c/abi.h
index e44933a6af..ae632f2dbd 100644
--- a/cpp/src/arrow/c/abi.h
+++ b/cpp/src/arrow/c/abi.h
@@ -287,7 +287,7 @@ struct ArrowAsyncTask {
   // calling this, and so it must be released separately.
   //
   // It is only valid to call this method exactly once.
-  int (*extract_data)(struct ArrowArrayTask* self, struct ArrowDeviceArray* 
out);
+  int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* 
out);
 
   // opaque task-specific data
   void* private_data;
@@ -298,6 +298,9 @@ struct ArrowAsyncTask {
 // control on the asynchronous stream processing. This object must be owned by 
the
 // producer who creates it, and thus is responsible for cleaning it up.
 struct ArrowAsyncProducer {
+  // The device type that this stream produces data on.
+  ArrowDeviceType device_type;
+
   // A consumer must call this function to start receiving on_next_task calls.
   //
   // It *must* be valid to call this synchronously from within `on_next_task` 
or
diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc
index 0976a5cb61..f848b34115 100644
--- a/cpp/src/arrow/c/bridge.cc
+++ b/cpp/src/arrow/c/bridge.cc
@@ -19,8 +19,11 @@
 
 #include <algorithm>
 #include <cerrno>
+#include <condition_variable>
 #include <cstring>
 #include <memory>
+#include <mutex>
+#include <queue>
 #include <string>
 #include <string_view>
 #include <utility>
@@ -37,8 +40,10 @@
 #include "arrow/result.h"
 #include "arrow/stl_allocator.h"
 #include "arrow/type_traits.h"
+#include "arrow/util/async_generator.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/util/checked_cast.h"
+#include "arrow/util/future.h"
 #include "arrow/util/key_value_metadata.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
@@ -2511,4 +2516,347 @@ Result<std::shared_ptr<ChunkedArray>> 
ImportDeviceChunkedArray(
   return ImportChunked</*IsDevice=*/true>(stream, mapper);
 }
 
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+  /// \brief A task received through on_next_task, paired with its decoded metadata.
+  struct TaskWithMetadata {
+    ArrowAsyncTask task_;
+    std::shared_ptr<KeyValueMetadata> metadata_;
+  };
+
+  /// \brief State shared by the handler callbacks (invoked by the producer) and the
+  /// consumer-side iterator.
+  ///
+  /// `schema_`, `producer_` and `device_type_` are written by on_schema before the
+  /// iterator future is marked finished. `batches_`, `end_of_stream_` and `error_`
+  /// are guarded by `mutex_`; changes to them are signalled through `cv_`.
+  struct State {
+    State(uint64_t queue_size, DeviceMemoryMapper mapper)
+        : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+    /// \brief Block until a task, end-of-stream or an error is available, then
+    /// extract and import the next record batch.
+    Result<RecordBatchWithMetadata> next() {
+      TaskWithMetadata task;
+      {
+        std::unique_lock<std::mutex> lock(mutex_);
+        cv_.wait(lock,
+                 [&] { return !error_.ok() || !batches_.empty() || end_of_stream_; });
+        if (!error_.ok()) {
+          return error_;
+        }
+
+        if (batches_.empty() && end_of_stream_) {
+          return IterationEnd<RecordBatchWithMetadata>();
+        }
+
+        task = std::move(batches_.front());
+        batches_.pop();
+      }
+
+      // Replace the dequeued task with a new request so the producer keeps
+      // feeding the queue.
+      if (producer_ != nullptr) {
+        producer_->request(producer_, 1);
+      }
+
+      ArrowDeviceArray out;
+      const int rc = task.task_.extract_data(&task.task_, &out);
+      if (rc != 0) {
+        std::unique_lock<std::mutex> lock(mutex_);
+        // Prefer the reason reported through on_error, but stop waiting once the
+        // stream has ended: otherwise a producer that finished without calling
+        // on_error would block this thread forever.
+        cv_.wait(lock, [&] { return !error_.ok() || end_of_stream_; });
+        if (error_.ok()) {
+          error_ = Status::UnknownError(
+              "Received error from ArrowAsyncTask::extract_data ", rc);
+        }
+        return error_;
+      }
+
+      ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, mapper_));
+      return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)};
+    }
+
+    const uint64_t queue_size_;
+    const DeviceMemoryMapper mapper_;
+    // Both set by on_schema before the iterator future completes.
+    ArrowAsyncProducer* producer_ = nullptr;
+    DeviceAllocationType device_type_ = DeviceAllocationType::kCPU;
+
+    std::mutex mutex_;
+    std::shared_ptr<Schema> schema_;
+    std::condition_variable cv_;
+    std::queue<TaskWithMetadata> batches_;
+    bool end_of_stream_ = false;
+    Status error_{Status::OK()};
+  };
+
+  AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper)
+      : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+  explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+      : state_{std::move(state)} {}
+
+  const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+  DeviceAllocationType device_type() const { return state_->device_type_; }
+
+  Result<RecordBatchWithMetadata> Next() { return state_->next(); }
+
+  /// \brief Install this iterator's callbacks and private data into `handler`.
+  ///
+  /// The returned future completes with the shared state once on_schema succeeds,
+  /// or with an error if the schema is rejected or on_error arrives first.
+  static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+      AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) {
+    auto iterator_fut = Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+    auto private_data = new PrivateData{iterator.state_};
+    private_data->fut_iterator_ = iterator_fut;
+
+    handler->private_data = private_data;
+    handler->on_schema = on_schema;
+    handler->on_next_task = on_next_task;
+    handler->on_error = on_error;
+    handler->release = release;
+    return iterator_fut;
+  }
+
+ private:
+  /// \brief Handler-owned data (deleted in `release`) linking the C callbacks to the
+  /// shared state and to the pending iterator future.
+  struct PrivateData {
+    explicit PrivateData(std::shared_ptr<State> state) : state_(std::move(state)) {}
+
+    std::shared_ptr<State> state_;
+    Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+    ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+  };
+
+  static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+                       struct ArrowSchema* stream_schema) {
+    auto* private_data = static_cast<PrivateData*>(self->private_data);
+    State& state = *private_data->state_;
+
+    // Hand `stream_schema` to ImportSchema (which takes ownership of it) before any
+    // other early return.
+    auto maybe_schema = ImportSchema(stream_schema);
+    if (!maybe_schema.ok()) {
+      private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+      return EINVAL;
+    }
+
+    // The producer is required to request tasks; fail cleanly instead of
+    // dereferencing a null pointer below.
+    if (self->producer == nullptr) {
+      private_data->fut_iterator_.MarkFinished(Status::Invalid(
+          "ArrowAsyncDeviceStreamHandler::producer must be set before on_schema"));
+      return EINVAL;
+    }
+
+    state.producer_ = self->producer;
+    state.device_type_ = static_cast<DeviceAllocationType>(self->producer->device_type);
+    state.schema_ = maybe_schema.MoveValueUnsafe();
+    private_data->fut_iterator_.MarkFinished(private_data->state_);
+    self->producer->request(self->producer, static_cast<int64_t>(state.queue_size_));
+    return 0;
+  }
+
+  static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask* task,
+                          const char* metadata) {
+    auto* private_data = static_cast<PrivateData*>(self->private_data);
+    State& state = *private_data->state_;
+
+    // A null task signals the end of the stream.
+    if (task == nullptr) {
+      {
+        std::lock_guard<std::mutex> lock(state.mutex_);
+        state.end_of_stream_ = true;
+      }
+      state.cv_.notify_one();
+      return 0;
+    }
+
+    std::shared_ptr<KeyValueMetadata> kvmetadata;
+    if (metadata != nullptr) {
+      auto maybe_decoded = DecodeMetadata(metadata);
+      if (!maybe_decoded.ok()) {
+        {
+          // next() reads error_ under the mutex, so it must be written under it too.
+          std::lock_guard<std::mutex> lock(state.mutex_);
+          state.error_ = std::move(maybe_decoded).status();
+        }
+        state.cv_.notify_one();
+        return EINVAL;
+      }
+
+      kvmetadata = std::move(maybe_decoded->metadata);
+    }
+
+    {
+      std::lock_guard<std::mutex> lock(state.mutex_);
+      state.batches_.push({*task, std::move(kvmetadata)});
+    }
+    state.cv_.notify_one();
+    return 0;
+  }
+
+  static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const char* message,
+                       const char* metadata) {
+    auto* private_data = static_cast<PrivateData*>(self->private_data);
+    std::string message_str, metadata_str;
+    if (message != nullptr) {
+      message_str = message;
+    }
+    if (metadata != nullptr) {
+      metadata_str = metadata;
+    }
+
+    Status error = Status::FromDetailAndArgs(
+        StatusCode::UnknownError,
+        std::make_shared<AsyncErrorDetail>(code, message_str, std::move(metadata_str)),
+        std::move(message_str));
+
+    // While the iterator future is still pending, the error fails it directly.
+    if (!private_data->fut_iterator_.is_finished()) {
+      private_data->fut_iterator_.MarkFinished(error);
+      return;
+    }
+
+    State& state = *private_data->state_;
+    {
+      std::lock_guard<std::mutex> lock(state.mutex_);
+      state.error_ = std::move(error);
+    }
+    state.cv_.notify_one();
+  }
+
+  static void release(ArrowAsyncDeviceStreamHandler* self) {
+    delete static_cast<PrivateData*>(self->private_data);
+  }
+
+  std::shared_ptr<State> state_;
+};
+
+struct AsyncProducer {
+  /// \brief State reachable from the exported ArrowAsyncProducer callbacks
+  /// (request/cancel) and from every exported task's extract_data.
+  struct State {
+    struct ArrowAsyncProducer producer_;
+
+    std::mutex mutex_;
+    std::condition_variable cv_;
+    uint64_t pending_requests_{0};
+    Status error_{Status::OK()};
+  };
+
+  /// \brief Wire up the producer struct and immediately deliver the schema through
+  /// handler->on_schema; a non-zero return is recorded as the producer error.
+  AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema,
+                struct ArrowAsyncDeviceStreamHandler* handler)
+      : handler_{handler}, state_{std::make_shared<State>()} {
+    state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type);
+    state_->producer_.private_data = static_cast<void*>(state_.get());
+    state_->producer_.request = AsyncProducer::request;
+    state_->producer_.cancel = AsyncProducer::cancel;
+    handler_->producer = &state_->producer_;
+
+    // The return code must be captured before comparing it; `int s = f() != 0`
+    // would only ever store 0 or 1.
+    if (int status = handler_->on_schema(handler_, schema); status != 0) {
+      state_->error_ =
+          Status::UnknownError("Received error from handler::on_schema ", status);
+    }
+  }
+
+  /// \brief Data owned by one exported ArrowAsyncTask; freed by extract_data.
+  struct PrivateTaskData {
+    PrivateTaskData(std::shared_ptr<State> producer, std::shared_ptr<RecordBatch> record)
+        : producer_{std::move(producer)}, record_(std::move(record)) {}
+
+    std::shared_ptr<State> producer_;
+    std::shared_ptr<RecordBatch> record_;
+    ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData);
+  };
+
+  /// \brief Visitor for one record batch: wait until the consumer has requested
+  /// data (or an error/cancellation occurred), then hand the batch over as a task.
+  Status operator()(const std::shared_ptr<RecordBatch>& record) {
+    std::unique_lock<std::mutex> lock(state_->mutex_);
+    state_->cv_.wait(lock, [this]() -> bool {
+      return !state_->error_.ok() || state_->pending_requests_ > 0;
+    });
+
+    if (!state_->error_.ok()) {
+      return state_->error_;
+    }
+
+    // The wait above guarantees at least one outstanding request here.
+    --state_->pending_requests_;
+    lock.unlock();
+
+    auto task_data = std::make_unique<PrivateTaskData>(state_, record);
+    ArrowAsyncTask task;
+    task.private_data = task_data.get();
+    task.extract_data = AsyncProducer::extract_data;
+
+    if (int status = handler_->on_next_task(handler_, &task, nullptr); status != 0) {
+      // The handler rejected the task, so its private data is still ours to free.
+      return Status::UnknownError("Received error from handler::on_next_task ", status);
+    }
+    // Ownership now travels with the task and is reclaimed in extract_data.
+    task_data.release();
+    return Status::OK();
+  }
+
+  static void request(struct ArrowAsyncProducer* producer, int64_t n) {
+    auto* self = static_cast<State*>(producer->private_data);
+    // Non-positive counts would wrap the unsigned counter around; ignore them.
+    if (n <= 0) {
+      return;
+    }
+    {
+      std::lock_guard<std::mutex> lock(self->mutex_);
+      if (!self->error_.ok()) {
+        return;
+      }
+      self->pending_requests_ += static_cast<uint64_t>(n);
+    }
+    self->cv_.notify_all();
+  }
+
+  static void cancel(struct ArrowAsyncProducer* producer) {
+    auto* self = static_cast<State*>(producer->private_data);
+    {
+      std::lock_guard<std::mutex> lock(self->mutex_);
+      if (!self->error_.ok()) {
+        return;
+      }
+      self->error_ = Status::Cancelled("Consumer requested cancellation");
+    }
+    self->cv_.notify_all();
+  }
+
+  /// \brief Export the task's record batch into `out`; the task's private data is
+  /// released on every path. Returns 0 on success, EINVAL if the export failed.
+  static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) {
+    std::unique_ptr<PrivateTaskData> private_data{
+        static_cast<PrivateTaskData*>(task->private_data)};
+    if (out == nullptr) {
+      // Nothing to export; the private data is still released above.
+      return 0;
+    }
+
+    auto status = ExportDeviceRecordBatch(*private_data->record_,
+                                          private_data->record_->GetSyncEvent(), out);
+    if (status.ok()) {
+      return 0;
+    }
+
+    State& producer_state = *private_data->producer_;
+    {
+      std::lock_guard<std::mutex> lock(producer_state.mutex_);
+      producer_state.error_ = std::move(status);
+    }
+    // Wake operator() if it is blocked waiting for requests so it sees the error.
+    producer_state.cv_.notify_all();
+    return EINVAL;
+  }
+
+  struct ArrowAsyncDeviceStreamHandler* handler_;
+  std::shared_ptr<State> state_;
+};
+
+}  // namespace
+
+Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
+    struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor,
+    uint64_t queue_size, DeviceMemoryMapper mapper) {
+  // The iterator only seeds the shared state; the handler keeps its own reference.
+  AsyncRecordBatchIterator consumer{queue_size, std::move(mapper)};
+
+  // Once on_schema has delivered the state, wrap it in a blocking iterator and
+  // drive that iterator from `executor` to obtain an async generator.
+  auto to_generator = [executor](std::shared_ptr<AsyncRecordBatchIterator::State> state)
+      -> Result<AsyncRecordBatchGenerator> {
+    std::shared_ptr<Schema> schema = state->schema_;
+    const DeviceAllocationType device_type = state->device_type_;
+    Iterator<RecordBatchWithMetadata> batches{AsyncRecordBatchIterator{std::move(state)}};
+    ARROW_ASSIGN_OR_RAISE(auto background,
+                          MakeBackgroundGenerator(std::move(batches), executor));
+    return AsyncRecordBatchGenerator{std::move(schema), device_type,
+                                     std::move(background)};
+  };
+
+  return AsyncRecordBatchIterator::Make(consumer, handler).Then(std::move(to_generator));
+}
+
+Future<> ExportAsyncRecordBatchReader(
+    std::shared_ptr<Schema> schema,
+    AsyncGenerator<std::shared_ptr<RecordBatch>> generator,
+    DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler) {
+  // Report a failure that happens before streaming starts, then give up the handler.
+  auto fail_early = [handler](const Status& st) -> Future<> {
+    handler->on_error(handler, EINVAL, st.message().c_str(), nullptr);
+    handler->release(handler);
+    return Future<>::MakeFinished(st);
+  };
+
+  if (schema == nullptr) {
+    return fail_early(Status::Invalid("Schema is null"));
+  }
+
+  struct ArrowSchema c_schema;
+  SchemaExportGuard guard(&c_schema);
+  if (Status export_status = ExportSchema(*schema, &c_schema); !export_status.ok()) {
+    return fail_early(export_status);
+  }
+
+  // Generator exhausted: signal end-of-stream with a null task, then release.
+  auto on_finished = [handler]() -> Status {
+    const int rc = handler->on_next_task(handler, nullptr, nullptr);
+    handler->release(handler);
+    if (rc == 0) {
+      return Status::OK();
+    }
+    return Status::UnknownError("Received error from handler::on_next_task ", rc);
+  };
+  // Any error while visiting is forwarded to the consumer before releasing.
+  auto on_failed = [handler](const Status& status) -> Status {
+    handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
+    handler->release(handler);
+    return status;
+  };
+
+  // Constructing the producer delivers the schema through handler->on_schema.
+  AsyncProducer producer{device_type, &c_schema, handler};
+  return VisitAsyncGenerator(std::move(generator), std::move(producer))
+      .Then(std::move(on_finished), std::move(on_failed));
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/c/bridge.h b/cpp/src/arrow/c/bridge.h
index 45367e4f93..78860e0650 100644
--- a/cpp/src/arrow/c/bridge.h
+++ b/cpp/src/arrow/c/bridge.h
@@ -26,6 +26,7 @@
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
+#include "arrow/util/async_generator_fwd.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/visibility.h"
 
@@ -406,4 +407,83 @@ Result<std::shared_ptr<ChunkedArray>> 
ImportDeviceChunkedArray(
 
 /// @}
 
+/// \defgroup c-async-stream-interface Functions for working with the async C 
data
+/// interface.
+///
+/// @{
+
+/// \brief EXPERIMENTAL: AsyncErrorDetail is a StatusDetail that carries the error
+/// code, message and metadata string from an asynchronous operation (for example,
+/// the values passed to an async handler's on_error callback).
+///
+/// NOTE(review): ErrorMetadata() is declared below but no definition appears in this
+/// change; confirm it is defined somewhere before relying on it.
+class AsyncErrorDetail : public StatusDetail {
+ public:
+  // Stores the errno-compatible code, the message and the raw metadata string.
+  AsyncErrorDetail(int code, std::string message, std::string metadata)
+      : code_(code), message_(std::move(message)), 
metadata_(std::move(metadata)) {}
+  const char* type_id() const override { return "AsyncErrorDetail"; }
+  // ToString returns only the stored error message (not the code or metadata).
+  std::string ToString() const override { return message_; }
+  // code is an errno-compatible error code
+  int code() const { return code_; }
+  // Returns the raw metadata string supplied at construction; presumably encoded
+  // like ArrowSchema key-value metadata -- TODO confirm with producers.
+  const std::string& ErrorMetadataString() const { return metadata_; }
+  // Presumably the decoded key-value form of ErrorMetadataString() (see NOTE above).
+  std::shared_ptr<KeyValueMetadata> ErrorMetadata() const;
+
+ private:
+  int code_{0};
+  std::string message_;
+  std::string metadata_;
+};
+
+/// \brief EXPERIMENTAL: Result of CreateAsyncDeviceStreamHandler: the stream's schema
+/// and device type, plus a generator yielding record batches (with any per-batch
+/// custom metadata) as the producer pushes them.
+struct AsyncRecordBatchGenerator {
+  /// Schema received through the handler's on_schema callback.
+  std::shared_ptr<Schema> schema;
+  /// Device type reported by the producer (ArrowAsyncProducer::device_type).
+  DeviceAllocationType device_type;
+  /// Asynchronously yields the record batches of the stream.
+  AsyncGenerator<RecordBatchWithMetadata> generator;
+};
+
+// Forward declaration of the executor type taken by CreateAsyncDeviceStreamHandler.
+namespace internal {
+class Executor;
+}
+
+/// \brief EXPERIMENTAL: Populate an ArrowAsyncDeviceStreamHandler and return a future
+/// AsyncRecordBatchGenerator fed by it
+///
+/// The ArrowAsyncDeviceStreamHandler struct has its callbacks and private data
+/// populated and is then meant to be passed to a producer, which calls those
+/// callbacks when data is ready. This inverts the traditional flow of control, so a
+/// corresponding AsyncRecordBatchGenerator is constructed to let the consumer
+/// retrieve data as it is pushed to the handler.
+///
+/// Once the schema arrives, `queue_size` tasks are requested from the producer, and
+/// one further task is requested each time a queued task is consumed.
+///
+/// \param[in,out] handler C struct to be populated
+/// \param[in] executor the executor on which the generator waits for and imports
+/// record batches
+/// \param[in] queue_size initial number of record batches to request for queueing
+/// \param[in] mapper mapping from device type and ID to memory manager
+/// \return Future that resolves to an AsyncRecordBatchGenerator once a schema is
+/// received, or to an error if the schema cannot be imported or on_error is called
+/// before the schema arrives.
+ARROW_EXPORT
+Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
+    struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* 
executor,
+    uint64_t queue_size = 5, DeviceMemoryMapper mapper = 
DefaultDeviceMemoryMapper);
+
+/// \brief EXPERIMENTAL: Export an AsyncGenerator of record batches using a provided
+/// handler
+///
+/// This function calls the callbacks on the consumer-provided async handler as record
+/// batches become available from the provided AsyncGenerator. It first calls
+/// on_schema with the provided schema, then serially visits each record batch from
+/// the generator and calls on_next_task for it. A batch is handed over only after
+/// the consumer has requested it through ArrowAsyncProducer::request. When the
+/// generator is exhausted, on_next_task is called with a null task. If an error
+/// occurs, on_error is called instead. In every case the handler's release callback
+/// is invoked when the export ends.
+///
+/// \param[in] schema the schema of the stream being exported; must not be null
+/// \param[in] generator a generator that asynchronously produces record batches
+/// \param[in] device_type the device type that the record batches will be located on
+/// \param[in] handler the handler whose callbacks to utilize as data is available
+/// \return Future that will resolve once the generator is exhausted or an error
+/// occurs; a null schema yields an already-finished Invalid future
+ARROW_EXPORT
+Future<> ExportAsyncRecordBatchReader(
+    std::shared_ptr<Schema> schema,
+    AsyncGenerator<std::shared_ptr<RecordBatch>> generator,
+    DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* 
handler);
+
+/// @}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc
index fdcb53ddbc..bc60b587cf 100644
--- a/cpp/src/arrow/c/bridge_test.cc
+++ b/cpp/src/arrow/c/bridge_test.cc
@@ -38,6 +38,7 @@
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/matchers.h"
 #include "arrow/testing/util.h"
+#include "arrow/util/async_generator.h"
 #include "arrow/util/binary_view_util.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/endian.h"
@@ -45,6 +46,7 @@
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/range.h"
+#include "arrow/util/thread_pool.h"
 
 // TODO(GH-37221): Remove these ifdef checks when compute dependency is removed
 #ifdef ARROW_COMPUTE
@@ -5311,4 +5313,112 @@ TEST_F(TestArrayDeviceStreamRoundtrip, 
ChunkedArrayRoundtripEmpty) {
   });
 }
 
+class TestAsyncDeviceArrayStreamRoundTrip : public BaseArrayStreamTest {
+ public:
+  // Recursively copies every buffer of `data` (and of its children) through `mm`,
+  // keeping null buffer slots null.
+  static Result<std::shared_ptr<ArrayData>> ToDeviceData(
+      const std::shared_ptr<MemoryManager>& mm, const ArrayData& data) {
+    arrow::BufferVector device_buffers;
+    for (const auto& src : data.buffers) {
+      std::shared_ptr<Buffer> copy;
+      if (src) {
+        ARROW_ASSIGN_OR_RAISE(copy, mm->CopyBuffer(src, mm));
+      }
+      device_buffers.push_back(std::move(copy));
+    }
+
+    arrow::ArrayDataVector device_children;
+    for (const auto& src_child : data.child_data) {
+      ARROW_ASSIGN_OR_RAISE(auto child_copy, ToDeviceData(mm, *src_child));
+      device_children.push_back(std::move(child_copy));
+    }
+
+    return ArrayData::Make(data.type, data.length, std::move(device_buffers),
+                           std::move(device_children), data.null_count, data.offset);
+  }
+
+  // Convenience wrapper around ToDeviceData() that returns a full Array.
+  static Result<std::shared_ptr<Array>> ToDevice(const std::shared_ptr<MemoryManager>& mm,
+                                                 const ArrayData& data) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> device_data, ToDeviceData(mm, data));
+    return MakeArray(device_data);
+  }
+};
+
+// Round-trips two device-resident batches: the exporter runs on the CPU thread pool
+// while the consumer collects everything from the resulting async generator.
+TEST_F(TestAsyncDeviceArrayStreamRoundTrip, Simple) {
+  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+  auto mm = device->default_memory_manager();
+
+  ASSERT_OK_AND_ASSIGN(auto arr1,
+                       ToDevice(mm, *ArrayFromJSON(int32(), "[1, 
2]")->data()));
+  ASSERT_EQ(device->device_type(), arr1->device_type());
+  ASSERT_OK_AND_ASSIGN(auto arr2,
+                       ToDevice(mm, *ArrayFromJSON(int32(), "[4, 5, 
null]")->data()));
+  ASSERT_EQ(device->device_type(), arr2->device_type());
+  auto orig_schema = arrow::schema({field("ints", int32())});
+  auto batches = MakeBatches(orig_schema, {arr1, arr2});
+
+  // The generator future cannot resolve until a producer calls on_schema.
+  struct ArrowAsyncDeviceStreamHandler handler;
+  auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, 
internal::GetCpuThreadPool(), 1,
+                                                
TestDeviceArrayRoundtrip::DeviceMapper);
+  ASSERT_FALSE(fut_gen.is_finished());
+
+  // Export from a pool thread: the producer blocks while waiting for the consumer's
+  // requests, so it must not run on the test thread.
+  ASSERT_OK_AND_ASSIGN(auto fut, internal::GetCpuThreadPool()->Submit([&]() {
+    return ExportAsyncRecordBatchReader(orig_schema, 
MakeVectorGenerator(batches),
+                                        device->device_type(), &handler);
+  }));
+
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto generator, fut_gen);
+  ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*orig_schema, *generator.schema));
+
+  auto collect_fut = CollectAsyncGenerator(generator.generator);
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto results, collect_fut);
+  ASSERT_FINISHES_OK(fut);
+  ASSERT_FINISHES_OK(fut_gen);
+
+  ASSERT_EQ(results.size(), 2);
+  AssertBatchesEqual(*results[0].batch, *batches[0]);
+  AssertBatchesEqual(*results[1].batch, *batches[1]);
+
+  // Let any remaining pool tasks finish before the stack-allocated handler goes away.
+  internal::GetCpuThreadPool()->WaitForIdle();
+}
+
+// A null schema fails the export with Invalid; the exporter reports it via on_error,
+// which fails the consumer's generator future with UnknownError.
+TEST_F(TestAsyncDeviceArrayStreamRoundTrip, NullSchema) {
+  struct ArrowAsyncDeviceStreamHandler handler;
+  auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, 
internal::GetCpuThreadPool(), 1,
+                                                
TestDeviceArrayRoundtrip::DeviceMapper);
+  ASSERT_FALSE(fut_gen.is_finished());
+
+  auto fut = ExportAsyncRecordBatchReader(nullptr, nullptr, 
DeviceAllocationType::kCPU,
+                                          &handler);
+  ASSERT_FINISHES_AND_RAISES(Invalid, fut);
+  ASSERT_FINISHES_AND_RAISES(UnknownError, fut_gen);
+}
+
+// An error from the source generator arrives after the schema: it must fail both the
+// consumer's collected stream and the export future.
+TEST_F(TestAsyncDeviceArrayStreamRoundTrip, PropagateError) {
+  std::shared_ptr<Device> device = std::make_shared<MyDevice>(1);
+  auto orig_schema = arrow::schema({field("ints", int32())});
+
+  struct ArrowAsyncDeviceStreamHandler handler;
+  auto fut_gen = CreateAsyncDeviceStreamHandler(&handler, 
internal::GetCpuThreadPool(), 1,
+                                                
TestDeviceArrayRoundtrip::DeviceMapper);
+  ASSERT_FALSE(fut_gen.is_finished());
+
+  ASSERT_OK_AND_ASSIGN(auto fut, internal::GetCpuThreadPool()->Submit([&]() {
+    return ExportAsyncRecordBatchReader(
+        orig_schema,
+        MakeFailingGenerator<std::shared_ptr<RecordBatch>>(
+            Status::UnknownError("expected failure")),
+        device->device_type(), &handler);
+  }));
+
+  // The schema is still delivered before the generator fails.
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto generator, fut_gen);
+  ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(*orig_schema, *generator.schema));
+
+  auto collect_fut = CollectAsyncGenerator(generator.generator);
+  ASSERT_FINISHES_AND_RAISES(UnknownError, collect_fut);
+  ASSERT_FINISHES_AND_RAISES(UnknownError, fut);
+
+  internal::GetCpuThreadPool()->WaitForIdle();
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index edbefc1c77..06cb621e98 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -309,6 +309,12 @@ struct ARROW_EXPORT RecordBatchWithMetadata {
   std::shared_ptr<KeyValueMetadata> custom_metadata;
 };
 
+template <>
+struct IterationTraits<RecordBatchWithMetadata> {
+  /// End-of-iteration sentinel: both the batch and its metadata are null.
+  static RecordBatchWithMetadata End() {
+    RecordBatchWithMetadata sentinel{NULLPTR, NULLPTR};
+    return sentinel;
+  }
+  /// Any value whose batch is null is treated as the end of iteration.
+  static bool IsEnd(const RecordBatchWithMetadata& val) { return !val.batch; }
+};
+
 /// \brief Abstract interface for reading stream of record batches
 class ARROW_EXPORT RecordBatchReader {
  public:
diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h
index 69029b67ab..5a2fbde023 100644
--- a/cpp/src/arrow/type_fwd.h
+++ b/cpp/src/arrow/type_fwd.h
@@ -79,7 +79,9 @@ using ScalarVector = std::vector<std::shared_ptr<Scalar>>;
 
 class ChunkedArray;
 class RecordBatch;
+struct RecordBatchWithMetadata;
 class RecordBatchReader;
+// Result type of CreateAsyncDeviceStreamHandler (defined in arrow/c/bridge.h).
+struct AsyncRecordBatchGenerator;
 class Table;
 
 struct Datum;
diff --git a/docs/source/format/CDeviceDataInterface.rst 
b/docs/source/format/CDeviceDataInterface.rst
index fbb2012c30..19412c605c 100644
--- a/docs/source/format/CDeviceDataInterface.rst
+++ b/docs/source/format/CDeviceDataInterface.rst
@@ -695,6 +695,8 @@ The C device async stream interface consists of three 
``struct`` definitions:
     };
 
     struct ArrowAsyncProducer {
+      ArrowDeviceType device_type;
+
       void (*request)(struct ArrowAsyncProducer* self, int64_t n);
       void (*cancel)(struct ArrowAsyncProducer* self);
 
@@ -869,6 +871,12 @@ The ArrowAsyncProducer structure
 
 This producer-provided and managed object has the following fields:
 
+.. c:member:: ArrowDeviceType ArrowAsyncProducer.device_type
+
+  *Mandatory.* The device type that this producer will provide data on. All
+  ``ArrowDeviceArray`` structs that are produced by this producer should have 
the
+  same device type as is set here.
+
 .. c:member:: void (*ArrowAsyncProducer.request)(struct ArrowAsyncProducer*, 
uint64_t)
 
   *Mandatory.* This function must be called by a consumer to start receiving 
calls to

Reply via email to