bkietz commented on code in PR #44495:
URL: https://github.com/apache/arrow/pull/44495#discussion_r1834704974
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
Review Comment:
Since this block mostly uses State's members, it should probably be a member
function of State
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
Review Comment:
Nit: instead of pair, could we use
```c++
struct TaskWithMetadata {
ArrowAsyncTask task;
std::shared_ptr<KeyValueMetadata> metadata;
};
```
and rename this
```suggestion
std::queue<TaskWithMetadata> tasks_;
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
Review Comment:
```suggestion
AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper)
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
Review Comment:
Nit
```suggestion
return IterationEnd<RecordBatchWithMetadata>();
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
Review Comment:
```suggestion
State(uint64_t queue_size, DeviceMemoryMapper mapper)
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
+ private_data->state_->cv_.notify_one();
+ return EINVAL;
+ }
+
+ kvmetadata = maybe_decoded->metadata;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->batches_.emplace(*task, std::move(kvmetadata));
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const
char* message,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ std::string message_str, metadata_str;
+ if (message != nullptr) {
+ message_str = message;
+ }
+ if (metadata != nullptr) {
+ metadata_str = metadata;
+ }
+
+ Status error = Status::FromDetailAndArgs(
+ StatusCode::UnknownError,
+ std::make_shared<AsyncErrorDetail>(code, message_str,
std::move(metadata_str)),
+ std::move(message_str));
+
+ if (!private_data->fut_iterator_.is_finished()) {
+ private_data->fut_iterator_.MarkFinished(error);
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->error_ = error;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ }
+
+ static void release(ArrowAsyncDeviceStreamHandler* self) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ delete private_data;
+ }
+
+ std::shared_ptr<State> state_;
+};
+
+struct AsyncProducer {
+ struct State {
+ struct ArrowAsyncProducer producer_;
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ uint64_t pending_requests_{0};
+ Status error_{Status::OK()};
+ };
+
+ AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema,
+ struct ArrowAsyncDeviceStreamHandler* handler)
+ : handler_{handler}, state_{std::make_shared<State>()} {
+ state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type);
+ state_->producer_.private_data = reinterpret_cast<void*>(state_.get());
+ state_->producer_.request = AsyncProducer::request;
+ state_->producer_.cancel = AsyncProducer::cancel;
+ handler_->producer = &state_->producer_;
+
+ if (int status = handler_->on_schema(handler_, schema) != 0) {
+ state_->error_ =
+ Status::UnknownError("Received error from handler::on_schema ",
status);
+ }
+ }
+
+ struct PrivateTaskData {
+ PrivateTaskData(std::shared_ptr<State> producer,
std::shared_ptr<RecordBatch> record)
+ : producer_{std::move(producer)}, record_(std::move(record)) {}
+
+ std::shared_ptr<State> producer_;
+ std::shared_ptr<RecordBatch> record_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData);
+ };
+
+ Status operator()(const std::shared_ptr<RecordBatch>& record) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ if (state_->pending_requests_ == 0) {
+ state_->cv_.wait(lock, [this]() -> bool {
+ return !state_->error_.ok() || state_->pending_requests_ > 0;
+ });
+ }
+
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->pending_requests_ > 0) {
+ state_->pending_requests_--;
+ lock.unlock();
+
+ ArrowAsyncTask task;
+ task.private_data = new PrivateTaskData{state_, record};
+ task.extract_data = AsyncProducer::extract_data;
+
+ if (int status = handler_->on_next_task(handler_, &task, nullptr) != 0) {
+ delete reinterpret_cast<PrivateTaskData*>(task.private_data);
+ return Status::UnknownError("Received error from handler::on_next_task
", status);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ static void request(struct ArrowAsyncProducer* producer, int64_t n) {
+ auto* self = reinterpret_cast<State*>(producer->private_data);
+ {
+ std::lock_guard<std::mutex> lock(self->mutex_);
+ if (!self->error_.ok()) {
+ return;
+ }
+ self->pending_requests_ += n;
+ }
+ self->cv_.notify_all();
+ }
+
+ static void cancel(struct ArrowAsyncProducer* producer) {
+ auto* self = reinterpret_cast<State*>(producer->private_data);
+ {
+ std::lock_guard<std::mutex> lock(self->mutex_);
+ if (!self->error_.ok()) {
+ return;
+ }
+ self->error_ = Status::Cancelled("Consumer requested cancellation");
+ }
+ self->cv_.notify_all();
+ }
+
+ static int extract_data(struct ArrowAsyncTask* task, struct
ArrowDeviceArray* out) {
+ auto private_data = reinterpret_cast<PrivateTaskData*>(task->private_data);
+ int ret = 0;
+ if (out != nullptr) {
+ auto status = ExportDeviceRecordBatch(*private_data->record_,
+
private_data->record_->GetSyncEvent(), out);
+ if (!status.ok()) {
+ std::lock_guard<std::mutex> lock(private_data->producer_->mutex_);
+ private_data->producer_->error_ = status;
+ }
+ }
+
+ delete private_data;
+ return ret;
+ }
+
+ struct ArrowAsyncDeviceStreamHandler* handler_;
+ std::shared_ptr<State> state_;
+};
+
+} // namespace
+
+Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
+ struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor*
executor,
+ uint64_t queue_size, DeviceMemoryMapper mapper) {
+ auto iterator =
+ std::make_shared<AsyncRecordBatchIterator>(queue_size,
std::move(mapper));
+ return AsyncRecordBatchIterator::Make(*iterator, handler)
+ .Then([executor](std::shared_ptr<AsyncRecordBatchIterator::State> state)
+ -> Result<AsyncRecordBatchGenerator> {
+ AsyncRecordBatchGenerator gen{state->schema_, state->device_type_,
nullptr};
+ auto it =
+
Iterator<RecordBatchWithMetadata>(AsyncRecordBatchIterator{std::move(state)});
+ ARROW_ASSIGN_OR_RAISE(gen.generator,
+ MakeBackgroundGenerator(std::move(it),
executor));
+ return gen;
+ });
+}
+
+Future<> ExportAsyncRecordBatchReader(
+ std::shared_ptr<Schema> schema,
+ AsyncGenerator<std::shared_ptr<RecordBatch>> generator,
+ DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler*
handler) {
+ if (!schema) {
+ handler->on_error(handler, EINVAL, "Schema is null", nullptr);
+ handler->release(handler);
+ return Future<>::MakeFinished(Status::Invalid("Schema is null"));
+ }
+
+ struct ArrowSchema c_schema;
+ SchemaExportGuard guard(&c_schema);
+
+ auto status = ExportSchema(*schema, &c_schema);
+ if (!status.ok()) {
+ handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
+ handler->release(handler);
+ return Future<>::MakeFinished(status);
+ }
+
+ return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema,
handler})
+ .Then(
+ [handler]() -> Future<> {
Review Comment:
I think this should be able to just return Status since we don't have any
true continuations in here
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
Review Comment:
```suggestion
private_data->state_->error_ = std::move(maybe_decoded).status();
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
+ private_data->state_->cv_.notify_one();
+ return EINVAL;
+ }
+
+ kvmetadata = maybe_decoded->metadata;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->batches_.emplace(*task, std::move(kvmetadata));
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const
char* message,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ std::string message_str, metadata_str;
+ if (message != nullptr) {
+ message_str = message;
+ }
+ if (metadata != nullptr) {
+ metadata_str = metadata;
+ }
+
+ Status error = Status::FromDetailAndArgs(
+ StatusCode::UnknownError,
+ std::make_shared<AsyncErrorDetail>(code, message_str,
std::move(metadata_str)),
+ std::move(message_str));
+
+ if (!private_data->fut_iterator_.is_finished()) {
+ private_data->fut_iterator_.MarkFinished(error);
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->error_ = error;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ }
+
+ static void release(ArrowAsyncDeviceStreamHandler* self) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ delete private_data;
+ }
+
+ std::shared_ptr<State> state_;
+};
+
+struct AsyncProducer {
+ struct State {
+ struct ArrowAsyncProducer producer_;
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ uint64_t pending_requests_{0};
+ Status error_{Status::OK()};
+ };
+
+ AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema,
+ struct ArrowAsyncDeviceStreamHandler* handler)
+ : handler_{handler}, state_{std::make_shared<State>()} {
+ state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type);
+ state_->producer_.private_data = reinterpret_cast<void*>(state_.get());
+ state_->producer_.request = AsyncProducer::request;
+ state_->producer_.cancel = AsyncProducer::cancel;
+ handler_->producer = &state_->producer_;
+
+ if (int status = handler_->on_schema(handler_, schema) != 0) {
+ state_->error_ =
+ Status::UnknownError("Received error from handler::on_schema ",
status);
+ }
+ }
+
+ struct PrivateTaskData {
+ PrivateTaskData(std::shared_ptr<State> producer,
std::shared_ptr<RecordBatch> record)
+ : producer_{std::move(producer)}, record_(std::move(record)) {}
+
+ std::shared_ptr<State> producer_;
+ std::shared_ptr<RecordBatch> record_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData);
+ };
+
+ Status operator()(const std::shared_ptr<RecordBatch>& record) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ if (state_->pending_requests_ == 0) {
+ state_->cv_.wait(lock, [this]() -> bool {
+ return !state_->error_.ok() || state_->pending_requests_ > 0;
+ });
+ }
+
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->pending_requests_ > 0) {
+ state_->pending_requests_--;
+ lock.unlock();
+
+ ArrowAsyncTask task;
+ task.private_data = new PrivateTaskData{state_, record};
+ task.extract_data = AsyncProducer::extract_data;
+
+ if (int status = handler_->on_next_task(handler_, &task, nullptr) != 0) {
+ delete reinterpret_cast<PrivateTaskData*>(task.private_data);
+ return Status::UnknownError("Received error from handler::on_next_task
", status);
+ }
+ }
+
+ return Status::OK();
+ }
+
+ static void request(struct ArrowAsyncProducer* producer, int64_t n) {
+ auto* self = reinterpret_cast<State*>(producer->private_data);
+ {
+ std::lock_guard<std::mutex> lock(self->mutex_);
+ if (!self->error_.ok()) {
+ return;
+ }
+ self->pending_requests_ += n;
+ }
+ self->cv_.notify_all();
+ }
+
+ static void cancel(struct ArrowAsyncProducer* producer) {
+ auto* self = reinterpret_cast<State*>(producer->private_data);
+ {
+ std::lock_guard<std::mutex> lock(self->mutex_);
+ if (!self->error_.ok()) {
+ return;
+ }
+ self->error_ = Status::Cancelled("Consumer requested cancellation");
+ }
+ self->cv_.notify_all();
+ }
+
+ static int extract_data(struct ArrowAsyncTask* task, struct
ArrowDeviceArray* out) {
+ auto private_data = reinterpret_cast<PrivateTaskData*>(task->private_data);
Review Comment:
In lieu of `defer` :smile:
```suggestion
std::unique_ptr<PrivateTaskData>
private_data{reinterpret_cast<PrivateTaskData*>(task->private_data)};
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
+ private_data->state_->cv_.notify_one();
+ return EINVAL;
+ }
+
+ kvmetadata = maybe_decoded->metadata;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->batches_.emplace(*task, std::move(kvmetadata));
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const
char* message,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ std::string message_str, metadata_str;
+ if (message != nullptr) {
+ message_str = message;
+ }
+ if (metadata != nullptr) {
+ metadata_str = metadata;
+ }
+
+ Status error = Status::FromDetailAndArgs(
+ StatusCode::UnknownError,
+ std::make_shared<AsyncErrorDetail>(code, message_str,
std::move(metadata_str)),
+ std::move(message_str));
+
+ if (!private_data->fut_iterator_.is_finished()) {
+ private_data->fut_iterator_.MarkFinished(error);
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->error_ = error;
Review Comment:
```suggestion
private_data->state_->error_ = std::move(error);
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
Review Comment:
```suggestion
private_data->state_->schema_ = maybe_schema.MoveValueUnsafe();
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
+ private_data->state_->cv_.notify_one();
+ return EINVAL;
+ }
+
+ kvmetadata = maybe_decoded->metadata;
Review Comment:
```suggestion
kvmetadata = std::move(maybe_decoded->metadata);
```
##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2511,4 +2516,345 @@ Result<std::shared_ptr<ChunkedArray>>
ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}
+namespace {
+
+class AsyncRecordBatchIterator {
+ public:
+ struct State {
+ State(uint64_t queue_size, const DeviceMemoryMapper mapper)
+ : queue_size_{queue_size}, mapper_{std::move(mapper)} {}
+
+ const uint64_t queue_size_;
+ const DeviceMemoryMapper mapper_;
+ ArrowAsyncProducer* producer_;
+ DeviceAllocationType device_type_;
+
+ std::mutex mutex_;
+ std::shared_ptr<Schema> schema_;
+ std::condition_variable cv_;
+ std::queue<std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>>>
batches_;
+ bool end_of_stream_ = false;
+ Status error_{Status::OK()};
+ };
+
+ AsyncRecordBatchIterator(uint64_t queue_size, const DeviceMemoryMapper
mapper)
+ : state_{std::make_shared<State>(queue_size, std::move(mapper))} {}
+
+ explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
+ : state_{std::move(state)} {}
+
+ const std::shared_ptr<Schema>& schema() const { return state_->schema_; }
+
+ DeviceAllocationType device_type() const { return state_->device_type_; }
+
+ Result<RecordBatchWithMetadata> Next() {
+ std::pair<ArrowAsyncTask, std::shared_ptr<KeyValueMetadata>> task;
+ {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] {
+ return !state_->error_.ok() || !state_->batches_.empty() ||
+ state_->end_of_stream_;
+ });
+ if (!state_->error_.ok()) {
+ return state_->error_;
+ }
+
+ if (state_->batches_.empty() && state_->end_of_stream_) {
+ return RecordBatchWithMetadata{nullptr, nullptr};
+ }
+
+ task = state_->batches_.front();
+ state_->batches_.pop();
+ }
+
+ state_->producer_->request(state_->producer_, 1);
+ ArrowDeviceArray out;
+ if (task.first.extract_data(&task.first, &out) != 0) {
+ std::unique_lock<std::mutex> lock(state_->mutex_);
+ state_->cv_.wait(lock, [&] { return !state_->error_.ok(); });
+ return state_->error_;
+ }
+
+ ARROW_ASSIGN_OR_RAISE(
+ auto batch, ImportDeviceRecordBatch(&out, state_->schema_,
state_->mapper_));
+ return RecordBatchWithMetadata{std::move(batch), std::move(task.second)};
+ }
+
+ static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
+ AsyncRecordBatchIterator& iterator, struct
ArrowAsyncDeviceStreamHandler* handler) {
+ auto iterator_fut =
Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();
+
+ auto private_data = new PrivateData{iterator.state_};
+ private_data->fut_iterator_ = iterator_fut;
+
+ handler->private_data = private_data;
+ handler->on_schema = on_schema;
+ handler->on_next_task = on_next_task;
+ handler->on_error = on_error;
+ handler->release = release;
+ return iterator_fut;
+ }
+
+ private:
+ struct PrivateData {
+ explicit PrivateData(std::shared_ptr<State> state) :
state_(std::move(state)) {}
+
+ std::shared_ptr<State> state_;
+ Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
+ ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
+ };
+
+ static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
+ struct ArrowSchema* stream_schema) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ if (self->producer != nullptr) {
+ private_data->state_->producer_ = self->producer;
+ private_data->state_->device_type_ =
+ static_cast<DeviceAllocationType>(self->producer->device_type);
+ }
+
+ auto maybe_schema = ImportSchema(stream_schema);
+ if (!maybe_schema.ok()) {
+ private_data->fut_iterator_.MarkFinished(maybe_schema.status());
+ return EINVAL;
+ }
+
+ auto schema = maybe_schema.MoveValueUnsafe();
+ private_data->state_->schema_ = schema;
+ private_data->fut_iterator_.MarkFinished(private_data->state_);
+ self->producer->request(self->producer,
+
static_cast<int64_t>(private_data->state_->queue_size_));
+ return 0;
+ }
+
+ static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask*
task,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+
+ if (task == nullptr) {
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->end_of_stream_ = true;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ std::shared_ptr<KeyValueMetadata> kvmetadata;
+ if (metadata != nullptr) {
+ auto maybe_decoded = DecodeMetadata(metadata);
+ if (!maybe_decoded.ok()) {
+ private_data->state_->error_ = maybe_decoded.status();
+ private_data->state_->cv_.notify_one();
+ return EINVAL;
+ }
+
+ kvmetadata = maybe_decoded->metadata;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->batches_.emplace(*task, std::move(kvmetadata));
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ return 0;
+ }
+
+ static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const
char* message,
+ const char* metadata) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ std::string message_str, metadata_str;
+ if (message != nullptr) {
+ message_str = message;
+ }
+ if (metadata != nullptr) {
+ metadata_str = metadata;
+ }
+
+ Status error = Status::FromDetailAndArgs(
+ StatusCode::UnknownError,
+ std::make_shared<AsyncErrorDetail>(code, message_str,
std::move(metadata_str)),
+ std::move(message_str));
+
+ if (!private_data->fut_iterator_.is_finished()) {
+ private_data->fut_iterator_.MarkFinished(error);
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
+ private_data->state_->error_ = error;
+ lock.unlock();
+ private_data->state_->cv_.notify_one();
+ }
+
+ static void release(ArrowAsyncDeviceStreamHandler* self) {
+ auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
+ delete private_data;
Review Comment:
nit
```suggestion
delete reinterpret_cast<PrivateData*>(self->private_data);
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]