zeroshade commented on code in PR #40807:
URL: https://github.com/apache/arrow/pull/40807#discussion_r1583821785


##########
cpp/src/arrow/c/bridge.cc:
##########
@@ -2250,101 +2340,126 @@ class ArrayStreamReader {
     return {code, last_error ? std::string(last_error) : ""};
   }
 
+  DeviceAllocationType get_device_type() const {
+    if constexpr (std::is_same_v<ArrayType, struct ArrowDeviceArray>) {
+      return static_cast<DeviceAllocationType>(stream_.device_type);
+    } else {
+      return DeviceAllocationType::kCPU;
+    }
+  }
+
  private:
-  mutable struct ArrowArrayStream stream_;
+  mutable StreamType stream_;
+  const DeviceMemoryMapper mapper_;
 };
 
-class ArrayStreamBatchReader : public RecordBatchReader, public 
ArrayStreamReader {
+template <typename StreamTraits, typename ArrayTraits>
+class ArrayStreamBatchReader : public RecordBatchReader,
+                               public ArrayStreamReader<StreamTraits, 
ArrayTraits> {
+  using StreamType = typename StreamTraits::CType;
+  using ArrayType = typename ArrayTraits::CType;
+
  public:
-  explicit ArrayStreamBatchReader(struct ArrowArrayStream* stream)
-      : ArrayStreamReader(stream) {}
+  explicit ArrayStreamBatchReader(
+      StreamType* stream, const DeviceMemoryMapper& mapper = 
DefaultDeviceMemoryMapper)
+      : ArrayStreamReader<StreamTraits, ArrayTraits>(stream, mapper) {}
 
   Status Init() {
-    ARROW_ASSIGN_OR_RAISE(schema_, ReadSchema());
+    ARROW_ASSIGN_OR_RAISE(schema_, this->ReadSchema());
     return Status::OK();
   }
 
   std::shared_ptr<Schema> schema() const override { return schema_; }
 
   Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
-    ARROW_RETURN_NOT_OK(CheckNotReleased());
+    ARROW_RETURN_NOT_OK(this->CheckNotReleased());
 
-    struct ArrowArray c_array;
-    ARROW_RETURN_NOT_OK(ReadNextArrayInternal(&c_array));
+    ArrayType c_array;
+    ARROW_RETURN_NOT_OK(this->ReadNextArrayInternal(&c_array));
 
-    if (ArrowArrayIsReleased(&c_array)) {
+    if (ArrayTraits::IsReleasedFunc(&c_array)) {
       // End of stream
       batch->reset();
       return Status::OK();
     } else {
-      return ImportRecordBatch(&c_array, schema_).Value(batch);
+      return this->ImportRecordBatchInternal(&c_array, schema_).Value(batch);
     }
   }
 
   Status Close() override {
-    ReleaseStream();
+    this->ReleaseStream();
     return Status::OK();
   }
 
+  DeviceAllocationType device_type() const override { return 
this->get_device_type(); }
+
  private:
   std::shared_ptr<Schema> schema_;
 };
 
-class ArrayStreamArrayReader : public ArrayStreamReader {
+template <typename StreamTraits, typename ArrayTraits>
+class ArrayStreamArrayReader : public ArrayStreamReader<StreamTraits, 
ArrayTraits> {
+  using StreamType = typename StreamTraits::CType;
+  using ArrayType = typename ArrayTraits::CType;
+
  public:
-  explicit ArrayStreamArrayReader(struct ArrowArrayStream* stream)
-      : ArrayStreamReader(stream) {}
+  explicit ArrayStreamArrayReader(
+      StreamType* stream, const DeviceMemoryMapper& mapper = 
DefaultDeviceMemoryMapper)

Review Comment:
   In the vast majority of cases, `DefaultDeviceMemoryMapper` is the correct 
value to pass at any call site. I also don't want to break any existing 
callers / users of this function.
   
   It's only when you start introducing non-CPU data that you'd want to pass 
something else.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to