felipecrv commented on code in PR #40807:
URL: https://github.com/apache/arrow/pull/40807#discussion_r1583750847


##########
cpp/src/arrow/array/data.cc:
##########
@@ -224,6 +224,41 @@ int64_t ArrayData::ComputeLogicalNullCount() const {
   return ArraySpan(*this).ComputeLogicalNullCount();
 }
 
+DeviceAllocationType ArrayData::device_type() const {
+  // we're using 0 as a sentinel value for NOT YET ASSIGNED
+  // there is explicitly no constant DeviceAllocationType to represent
+  // the "UNASSIGNED" case as it is invalid for data to not have an
+  // assigned device type. If it's still 0 at the end, then we return
+  // CPU as the allocation device type
+  int type = 0;
+  for (const auto& buf : buffers) {
+    if (!buf) continue;
+    if (type == 0) {
+      type = static_cast<int>(buf->device_type());

Review Comment:
   Was there pushback against adding a `kUNDEFINED` value to `DeviceAllocationType`?



##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2250,101 +2340,126 @@ class ArrayStreamReader {
     return {code, last_error ? std::string(last_error) : ""};
   }
 
+  DeviceAllocationType get_device_type() const {
+    if constexpr (std::is_same_v<ArrayType, struct ArrowDeviceArray>) {
+      return static_cast<DeviceAllocationType>(stream_.device_type);
+    } else {
+      return DeviceAllocationType::kCPU;
+    }
+  }
+
  private:
-  mutable struct ArrowArrayStream stream_;
+  mutable StreamType stream_;
+  const DeviceMemoryMapper mapper_;
 };
 
-class ArrayStreamBatchReader : public RecordBatchReader, public 
ArrayStreamReader {
+template <typename StreamTraits, typename ArrayTraits>
+class ArrayStreamBatchReader : public RecordBatchReader,
+                               public ArrayStreamReader<StreamTraits, 
ArrayTraits> {
+  using StreamType = typename StreamTraits::CType;
+  using ArrayType = typename ArrayTraits::CType;
+
  public:
-  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream)
-      : ArrayStreamReader(stream) {}
+  explicit ArrayStreamBatchReader(
+      StreamType* stream, const DeviceMemoryMapper& mapper = 
DefaultDeviceMemoryMapper)
+      : ArrayStreamReader<StreamTraits, ArrayTraits>(stream, mapper) {}
 
   Status Init() {
-    ARROW_ASSIGN_OR_RAISE(schema_, ReadSchema());
+    ARROW_ASSIGN_OR_RAISE(schema_, this->ReadSchema());
     return Status::OK();
   }
 
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
   Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
-    ARROW_RETURN_NOT_OK(CheckNotReleased());
+    ARROW_RETURN_NOT_OK(this->CheckNotReleased());
 
-    struct ArrowArray c_array;
-    ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array));
+    ArrayType c_array;
+    ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array));
 
-    if (ArrowArrayIsReleased(&c_array)) {
+    if (ArrayTraits::IsReleasedFunc(&c_array)) {
       // End of stream
       batch->reset();
       return Status::OK();
     } else {
-      return ImportRecordBatch(&c_array, schema_).Value(batch);
+      return this->ImportRecordBatchInternal(&c_array, schema_).Value(batch);
     }
   }
 
   Status Close() override {
-    ReleaseStream();
+    this->ReleaseStream();
     return Status::OK();
   }
 
+  DeviceAllocationType device_type() const override { return 
this->get_device_type(); }
+
  private:
   std::shared_ptr<Schema> schema_;
 };
 
-class ArrayStreamArrayReader : public ArrayStreamReader {
+template <typename StreamTraits, typename ArrayTraits>
+class ArrayStreamArrayReader : public ArrayStreamReader<StreamTraits, 
ArrayTraits> {
+  using StreamType = typename StreamTraits::CType;
+  using ArrayType = typename ArrayTraits::CType;
+
  public:
-  explicit ArrayStreamArrayReader(struct ArrowArrayStream* stream)
-      : ArrayStreamReader(stream) {}
+  explicit ArrayStreamArrayReader(
+      StreamType* stream, const DeviceMemoryMapper& mapper = 
DefaultDeviceMemoryMapper)

Review Comment:
   I would recommend avoiding default parameter values on internal abstractions. 
Not having them forces us to think about the correct value to pass at every 
call site.



##########
cpp/src/arrow/record_batch.cc:
##########
@@ -167,30 +188,41 @@ class SimpleRecordBatch : public RecordBatch {
     return RecordBatch::Validate();
   }
 
+  std::shared_ptr<Device::SyncEvent> GetSyncEvent() const override { return 
sync_event_; }

Review Comment:
   Are there any overrides that would have to create a new `Device::SyncEvent` on 
every call? If not, this should return a `const ...&` to avoid the non-trivial cost 
of incrementing a ref-count on every call.



##########
cpp/src/arrow/c/bridge.h:
##########
@@ -321,6 +321,31 @@ ARROW_EXPORT
 Status ExportChunkedArray(std::shared_ptr<ChunkedArray> chunked_array,
                           struct ArrowArrayStream* out);
 
+/// \brief Export C++ RecordBatchReader using the C device stream interface
+///
+/// The resulting ArrowDeviceArrayStream struct keeps the record batch reader
+/// alive until its release callback is called by the consumer. The device
+/// type is determined by calling device_type() on the RecordBatchReader.
+///
+/// \param[in] reader RecordBatchReader object to export
+/// \param[out] out C struct to export the stream to
+ARROW_EXPORT
+Status ExportDeviceRecordBatchReader(std::shared_ptr<RecordBatchReader> reader,
+                                     struct ArrowDeviceArrayStream* out);
+
+/// \brief Export C++ ChunkedArray using the c device data interface format.

Review Comment:
   ```suggestion
   /// \brief Export C++ ChunkedArray using the C device data interface format.
   ```



##########
cpp/src/arrow/record_batch.cc:
##########
@@ -167,30 +188,41 @@ class SimpleRecordBatch : public RecordBatch {
     return RecordBatch::Validate();
   }
 
+  std::shared_ptr<Device::SyncEvent> GetSyncEvent() const override { return 
sync_event_; }
+
+  DeviceAllocationType device_type() const override { return device_type_; }
+
  private:
   std::vector<std::shared_ptr<ArrayData>> columns_;
 
   // Caching boxed array data
   mutable std::vector<std::shared_ptr<Array>> boxed_columns_;
+
+  // the type of device that the buffers for columns are allocated on.
+  // all columns should be on the same type of device.
+  DeviceAllocationType device_type_;
+  std::shared_ptr<Device::SyncEvent> sync_event_;
 };
 
 RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t 
num_rows)
     : schema_(schema), num_rows_(num_rows) {}
 
 std::shared_ptr<RecordBatch> RecordBatch::Make(
     std::shared_ptr<Schema> schema, int64_t num_rows,
-    std::vector<std::shared_ptr<Array>> columns) {
+    std::vector<std::shared_ptr<Array>> columns,
+    std::shared_ptr<Device::SyncEvent> sync_event) {
   DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
   return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
-                                             std::move(columns));
+                                             std::move(columns), sync_event);
 }
 
 std::shared_ptr<RecordBatch> RecordBatch::Make(
     std::shared_ptr<Schema> schema, int64_t num_rows,
-    std::vector<std::shared_ptr<ArrayData>> columns) {
+    std::vector<std::shared_ptr<ArrayData>> columns, DeviceAllocationType 
device_type,
+    std::shared_ptr<Device::SyncEvent> sync_event) {
   DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
   return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
-                                             std::move(columns));
+                                             std::move(columns), device_type, 
sync_event);

Review Comment:
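   Use `std::move(sync_event)` here, since `sync_event` is taken by value and not used afterwards; this avoids an extra atomic ref-count increment. The same applies to the `Array` overload of `RecordBatch::Make` above.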



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to